動態dp & 矩陣加速遞推

luqyou發表於2024-08-19

廣義矩陣乘法

我們定義兩個矩陣 \(A,B\) 在廣義矩陣乘法下的乘積為 \(C\),其中

\[C = \begin{bmatrix} \max\limits_{i=1}\limits^{m} A_{1,i} + B_{i,1} & \max\limits_{i=1}\limits^{m} A_{1,i} + B_{i,2} & \dots & \max\limits_{i=1}\limits^{m} A_{1,i} + B_{i,k} \\\ \sum\limits_{i=1}\limits^{m} A_{2,i} + B_{i,1} & \max\limits_{i=1}^{m} A_{2,i} + B_{i,2} & \dots & \max\limits_{i=1}\limits^{m} A_{2,i} + B_{i,k} \\\ \vdots & \vdots & \ddots & \vdots \\\ \max\limits_{i=1}\limits^{m} A_{n,i} + B_{i,1} & \max\limits_{i=1}\limits^{m} A_{n,i} + B_{i,2} & \dots & \max\limits_{i=1}\limits^{m} A_{n,i}+ B_{i,k}\end{bmatrix}\]

這麼定義矩陣乘法是為了改寫某些 DP 柿子。不難發現這個乘法依然具有結合律。

動態 dp

引入

有一個序列 \(a\),你可以在其中選擇一些數,但是你不能選擇相鄰的兩個數,求你能選出的數的總和最大是多少。

我們令 \(dp_{i,0/1}\) 為考慮前 \(i\) 個數,選或不選第 \(i\) 個數的最大和。

不難得到 \(dp_{i,0}=\max(dp_{i-1,0},dp_{i-1,1}),dp_{i,1}=dp_{i-1,0}+a_i\)

那麼這跟我們上面所說的矩陣乘法有什麼關係呢?

我們將 dp 式改寫一下:

\[dp_{i,0}=\max(dp_{i-1,0}+0,dp_{i-1,1}+0) \]

\[dp_{i,1}=\max(dp_{i-1,0}+a_i,dp_{i-1,1} - \infty) \]

現在是不是和上述的廣義矩陣乘法很像了?我們將 dp 繼續改寫為矩陣乘的形式:

\[\begin{bmatrix} dp_{i-1,0} & dp_{i-1,1} \end{bmatrix} \times \begin{bmatrix} 0 & a_i \\\ 0 & - \infty \end{bmatrix} = \begin{bmatrix} dp_{i,0} & dp_{i,1} \end{bmatrix} \]

由於矩陣乘具有結合律,所以我們現在可以將 dp 結果寫成一系列矩陣連乘的結果了!

但是我們這麼做卻不是為了最佳化時間複雜度,而是為了:

帶修

如果我們將引入題改一下,增加 \(m\) 次單點修改 \(a_i\) 的值,怎麼做?

我們只需要使用線段樹維護上述的矩陣乘法即可!!!

例題 Gym102644H String Mood Updates

首先我們考慮暴力 dp。

我們令 \(dp_{i,0/1}\) 代表讀完前 \(i\) 個字元後當前狀態為 \(0/1\) 的方案數。轉移有:

\[\begin{cases}dp_{i,0}=dp_{i-1,0}+dp_{i-1,1},dp_{i,1}=0 \space\space s_i=\texttt{S,D} \\\ dp_{i,0}=0,dp_{i,1}=dp_{i-1,0}+dp_{i-1,1} \space\space s_i=\texttt{H} \\\ dp_{i,0}=dp_{i-1,1},dp_{i,1}=dp_{i-1,0} \space\space s_i=\texttt{A,E,I,O,U} \\\ dp_{i,0}=20 \times dp_{i-1,0} + 7 \times dp_{i-1,1},dp_{i-1,1}=6 \times dp_{i-1,0} + 19 \times dp_{i-1,1} \space\space s_i=\texttt{?} \\\ dp_{i,0}=dp_{i-1,0},dp_{i,1}=dp_{i-1,1} \space\space \operatorname{otherwise}\end{cases} \]

帶修的話,根據這個狀態轉移,構建出轉移矩陣,並用線段樹維護即可。

#include<bits/stdc++.h>
#define int long long
using namespace std;
#define fi first
#define sc second
#define pii pair<int,int>
#define pb push_back
const int maxn=2e5+10;
const int mod=1e9+7;
int n,m,dp[maxn][2];
string s;
struct mat{
	int n,a[3][3];
	void init(int x){
		for(int i=1;i<=n;i++){
			for(int j=1;j<=n;j++) a[i][j]=x;
		}
	}
	void getI(){
		init(0);
		for(int i=1;i<=n;i++) a[i][i]=1;
	}
	mat operator *(mat x){
		mat ans;
		ans.n=n,ans.init(0);
		for(int i=1;i<=n;i++){
			for(int j=1;j<=n;j++){
				for(int k=1;k<=n;k++){
					ans.a[i][j]=(ans.a[i][j]+a[i][k]*x.a[k][j])%mod;
				}
			}
		}
		return ans;
	} 
	void output(){
		for(int i=1;i<=n;i++,cout<<endl){
			for(int j=1;j<=n;j++) cout<<a[i][j]<<" ";
		}
	}
};
struct node{
	int l,r;
	mat mt;
	node operator +(node x){
		if(l==-1) return x;
		if(x.l==-1) return (*this);
		node res;
		res.l=l,res.r=x.r,res.mt=(mt*x.mt);
		return res;
	}
	void debug(){
		cout<<l<<" "<<r<<endl;
		cout<<"MATRIX:"<<endl;
		mt.output();
		cout<<endl;
	}
}tr[maxn*4];
int ls(int u){
	return (u<<1);
}
int rs(int u){
	return (u<<1)|1;
}
bool ir(int L,int R,int l,int r){
	return (L<=l)&&(r<=R);
}
bool ofr(int L,int R,int l,int r){
	return (R<l)||(r<L);
}
void pushup(int u){
	if(tr[u].l==tr[u].r) return ;
	tr[u]=tr[ls(u)]+tr[rs(u)];
}
void build(int u,int l,int r){
	tr[u].l=l,tr[u].r=r,tr[u].mt.n=2;
	if(l==r){
		if(s[l]=='S'||s[l]=='D'){
			tr[u].mt.a[1][1]=1,tr[u].mt.a[1][2]=0;
			tr[u].mt.a[2][1]=1,tr[u].mt.a[2][2]=0;
		}
		else if(s[l]=='H'){
			tr[u].mt.a[1][1]=0,tr[u].mt.a[1][2]=1;
			tr[u].mt.a[2][1]=0,tr[u].mt.a[2][2]=1;
		}
		else if(s[l]=='A'||s[l]=='E'||s[l]=='I'||s[l]=='O'||s[l]=='U'){
			tr[u].mt.a[1][1]=0,tr[u].mt.a[1][2]=1;
			tr[u].mt.a[2][1]=1,tr[u].mt.a[2][2]=0;
		}
		else if(s[l]=='?'){
			tr[u].mt.a[1][1]=20,tr[u].mt.a[1][2]=6;
			tr[u].mt.a[2][1]=7,tr[u].mt.a[2][2]=19;
		}
		else{
			tr[u].mt.a[1][1]=1,tr[u].mt.a[1][2]=0;
			tr[u].mt.a[2][1]=0,tr[u].mt.a[2][2]=1;
		}
		return ;
	}
	int mid=(l+r)>>1;
	build(ls(u),l,mid),build(rs(u),mid+1,r),pushup(u);
}
void upd(int u,int x,char k){
	if(tr[u].l==tr[u].r){
		if(k=='S'||k=='D'){
			tr[u].mt.a[1][1]=1,tr[u].mt.a[1][2]=0;
			tr[u].mt.a[2][1]=1,tr[u].mt.a[2][2]=0;
		}
		else if(k=='H'){
			tr[u].mt.a[1][1]=0,tr[u].mt.a[1][2]=1;
			tr[u].mt.a[2][1]=0,tr[u].mt.a[2][2]=1;
		}
		else if(k=='A'||k=='E'||k=='I'||k=='O'||k=='U'){
			tr[u].mt.a[1][1]=0,tr[u].mt.a[1][2]=1;
			tr[u].mt.a[2][1]=1,tr[u].mt.a[2][2]=0;
		}
		else if(k=='?'){
			tr[u].mt.a[1][1]=20,tr[u].mt.a[1][2]=6;
			tr[u].mt.a[2][1]=7,tr[u].mt.a[2][2]=19;
		}
		else{
			tr[u].mt.a[1][1]=1,tr[u].mt.a[1][2]=0;
			tr[u].mt.a[2][1]=0,tr[u].mt.a[2][2]=1;
		}
		return ;
	}
	int mid=(tr[u].l+tr[u].r)>>1;
	if(x<=mid) upd(ls(u),x,k);
	else upd(rs(u),x,k);
	pushup(u);
}
void solve(){
	cin>>n>>m>>s,s=" "+s;
	build(1,1,n);
//	tr[1].mt.output();
	cout<<tr[1].mt.a[2][2]<<endl;
	while(m--){
		int x;
		char c;
		cin>>x>>c,upd(1,x,c);
//		tr[1].mt.output();
		cout<<tr[1].mt.a[2][2]<<endl;
//		for(int i=1;i<n*2;i++) tr[i].debug();
//		cout<<endl;
	}
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    int t=1;
//    cin>>t;
    while(t--) solve();
    return 0;
}
/*
Samples
input:

output:

THINGS TODO:
檢查freopen,尤其是字尾名
檢查空間
檢查除錯語句是否全部註釋
*/

相關文章