QOJ 6504. CCPC Final 2022 D Flower's Land 2题解
QOJ 6504. CCPC Final 2022 D Flower's Land 2题解
题意简述
给你一个只含 \(0,1,2\) 的序列,相邻两个相同的数字可以直接消掉。
询问包含两种
-
区间所有数 \(+1\) 并对 \(3\) 取模。
-
求一段区间能否用上述消除方式消完。
样例输入
8 9 01211012 2 4 5 2 3 6 1 6 8 1 6 8 2 3 6 2 1 8 1 1 1 1 7 7 2 1 8
样例输出 #1
Yes No Yes No Yes
提示
在我们做相邻两个能被消掉,判断一段区间能否被消掉时,常常用矩阵来考虑。
把每一种颜色用一种矩阵来表示,若当前位是偶数就设为这个矩阵,若当前位是奇数就设为这个矩阵的逆。
求解就把所有的矩阵乘起来,看最后结果矩阵是不是 \(I\) 。
为什么矩阵是正确的呢?因为矩阵满足结合律但不满足交换律。
这样就可以保证 \(1,2,3,1,2,3\) 会判断为错。
如果还没理解,下面再解释详细一点:
这是一段序列 \(0122221000\) 显然他是合法的。
在矩阵中,因为满足结合律,你先算中间那段 \(2222\) ,因为奇数和偶数个数相同,一定为 \(I\) ,相当于没有了,变成了 \(0110000\) 。一直向下就可以得到 \(I\)。
题解
我们用线段树来维护矩阵乘法,这很容易,具体就是加了以后如何在矩阵中体现出来。
因为只有 \(0,1,2\) ,我们把当前,\(+1\) 后, \(+2\) 后的矩阵都记录下来。这样就可以了。
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 5e5 + 10, mod = 998244353;
int n,q;
char x;
int a[N], opt, l, r;
inline ll mpow(ll x,int k){
ll ans = 1;
while(k){
if(k & 1) ans = ans * x % mod;
x = x * x % mod;
k >>= 1;
}
return ans;
}
struct Mar{
ll a[3][3];
inline Mar operator *(const Mar b)const{
Mar c;
for(int i = 1; i <= 2; ++i){
for(int j = 1; j <= 2; ++j){
c.a[i][j] = 0;
for(int k = 1;k <= 2; ++k){
c.a[i][j] = (c.a[i][j] + a[i][k] * b.a[k][j] % mod) % mod;
}
}
}
return c;
}
inline bool check(){
if(a[1][1] != 1) return 0;
if(a[1][2] != 0) return 0;
if(a[2][1] != 0) return 0;
if(a[2][2] != 1) return 0;
return 1;
}
inline Mar inv()const{
Mar c, b;
c.a[1][1] = 1;
c.a[1][2] = 0;
c.a[2][1] = 0;
c.a[2][2] = 1;
for(int i = 1; i <= 2; ++i)for(int j = 1; j <= 2; ++j) b.a[i][j] = a[i][j];
for(int i = 1; i <= 2; ++i){
for(int j = 1; j <= 2; ++j){
if(i == j) continue;
ll w = b.a[j][i] * mpow(b.a[i][i],mod - 2) % mod;
for(int k = 1; k <= 2; ++k){
b.a[j][k] = (b.a[j][k] - b.a[i][k] * w % mod + mod) % mod;
}
for(int k = 1; k <= 2; ++k){
c.a[j][k] = (c.a[j][k] - c.a[i][k] * w % mod + mod) % mod;
}
}
}
for(int i = 1; i <= 2; ++i){
for(int j = 1; j <= 2; ++j){
c.a[i][j] = c.a[i][j] * mpow(b.a[i][i],mod - 2) % mod;
}
}
return c;
}
inline void print(){
for(int i = 1; i <= 2; ++i){
for(int j = 1; j <= 2; ++j){
cout<<a[i][j]<<' ';
}
cout<<'\n';
}
}
}I;
struct node{
Mar now,nxt,nnt;
int tag;
}tr[N << 2];
Mar m[3], m_[3];
inline void pre(){
I.a[1][1] = 1,
I.a[1][2] = 0;
I.a[2][1] = 0;
I.a[2][2] = 1;
m[0].a[1][1] = 2,m[0].a[1][2] = 3;
m[0].a[2][1] = 5,m[0].a[2][2] = 7;
m[1].a[1][1] = 11,m[1].a[1][2] = 13;
m[1].a[2][1] = 17,m[1].a[2][2] = 19;
m[2].a[1][1] = 23,m[2].a[1][2] = 29;
m[2].a[2][1] = 31,m[2].a[2][2] = 37;
m_[0] = m[0].inv();
m_[1] = m[1].inv();
m_[2] = m[2].inv();
}
inline void input(){
cin>> n >> q;
for(int i = 1; i <= n; ++i){
cin>>x;
a[i] = x - '0';
}
}
inline void pd(int x){
cout<<"now:"<<'\n';
tr[x].now.print();
cout<<"nxt:"<<'\n';
tr[x].nxt.print();
cout<<"nnt:"<<'\n';
tr[x].nnt.print();
cout<<"tag:"<<'\n'<<tr[x].tag<<'\n';
}
inline void downdate(int x){
tr[x << 1].tag = (tr[x << 1].tag + tr[x].tag) % 3;
tr[x << 1 | 1].tag = (tr[x << 1 | 1].tag + tr[x].tag) % 3;
while(tr[x].tag > 0){
swap(tr[x << 1].now, tr[x << 1].nxt);
swap(tr[x << 1].nnt, tr[x << 1].nxt);
swap(tr[x << 1 | 1].now, tr[x << 1 | 1].nxt);
swap(tr[x << 1 | 1].nnt, tr[x << 1 | 1].nxt);
--tr[x].tag;
}
}
inline void pushup(int x){
tr[x].now = tr[x << 1].now * tr[x << 1 | 1].now;
tr[x].nxt = tr[x << 1].nxt * tr[x << 1 | 1].nxt;
tr[x].nnt = tr[x << 1].nnt * tr[x << 1 | 1].nnt;
}
inline void build(int x, int l, int r){
if(l == r){
if(l % 2){
tr[x].now = m_[a[l]];
tr[x].nxt = m_[(a[l] + 1) % 3];
tr[x].nnt = m_[(a[l] + 2) % 3];
}else{
tr[x].now = m[a[l]];
tr[x].nxt = m[(a[l] + 1) % 3];
tr[x].nnt = m[(a[l] + 2) % 3];
}
return ;
}
int mid = (l + r) >> 1;
build(x << 1, l, mid);
build(x << 1 | 1, mid + 1, r);
pushup(x);
}
inline void adtr(int x){
tr[x].tag = (tr[x].tag + 1) % 3;
swap(tr[x].now, tr[x].nxt);
swap(tr[x].nnt, tr[x].nxt);
}
inline void add(int x, int l, int r, int L, int R){
if(L <= l && r <= R){
adtr(x);
return ;
}
downdate(x);
int mid = (l + r) >> 1;
if(L <= mid) add(x << 1, l, mid, L, R);
if(R > mid) add(x << 1 | 1, mid + 1, r, L, R);
pushup(x);
}
inline Mar query(int x, int l, int r, int L, int R){
if(L <= l && r <= R){
// cout<<x<<'\n';
// tr[x].now.print();
return tr[x].now;
}
downdate(x);
int mid = (l + r) >> 1;
Mar ans = I;
if(L <= mid) ans = ans * query(x << 1, l, mid, L, R);
if(R > mid) ans = ans * query(x << 1 | 1,mid + 1, r, L, R);
return ans;
}
inline void op(){
build(1,1,n);
for(int i = 1; i <= q; ++i){
cin>> opt >> l >> r;
if(opt == 1){
add(1, 1, n, l, r);
}else if(opt == 2){
if(query(1, 1, n, l, r).check()){
cout<<"Yes"<<'\n';
}else{
cout<<"No"<<'\n';
}
}
}
}
int main(){
cin.tie(0)->sync_with_stdio(false);
pre();
input();
op();
return 0;
}