轉置原理小練習:Do Use FFT

ffffyc發表於2024-04-04

\(\text{Link}\)

題意

給定三個長為 \(n\) 的陣列 \(a_{0,\dots,n-1},b_{0,\dots,n-1},c_{0,\dots,n-1}\),對 \(\forall i\in[0,n-1]\) 求出:

\[d_i=\sum_{j=0}^{n-1}c_j\prod_{k=0}^i(a_j+b_k) \]

\(998244353\) 取模。

\(n\le 2.5\times 10^5\)

思路

\(a,b\) 看成常量,那麼 \(d\) 就是由 \(c\) 的線性變換得來,我們考慮其轉置:

\[c_j=\sum_{i=0}^{n-1}d_i\prod_{k=0}^i(a_j+b_k) \]

注意到 \(j\) 只出現一次,不妨令:

\[F(x)=\sum_{i=0}^{n-1}d_i\prod_{k=0}^i(x+b_k) \]

那麼顯然有:

\[c_j=F(a_j) \]

\(d\) 作為輸入,\(c\) 作為輸出,用分治 NTT 求出 \(F(x)\) 再多點求值得到 \(c\),便在 \(O(n\log^2n)\) 的時間複雜度內解決了轉置後的問題。由轉置原理,我們可以在同時間複雜度內求出原問題。

多點求值的轉置是老生常談了:

\[F(x)=\sum_{i=0}^{n-1}\frac{c_i}{1-a_ix} \]

分治 NTT 即可。

要寫出前面的分治 NTT 的轉置,我們把原演算法過程寫出來:

\[F_{l,r}(x)=\sum_{i=l}^rd_i\prod_{k=l}^i(x+b_k) \]

\[G_{l,r}(x)=\prod_{k=l}^r(x+b_k) \]

  1. 對於葉子結點:\(F_{i,i}=b_id_i+d_ix\)
  2. 對於非葉子結點:\(F_{l,r}=F_{l,mid}+F_{mid+1,r}\times G_{l,mid}\)\(G_{l,r}=G_{l,mid}\times G_{mid+1,r}\)

寫演算法轉置的基本步驟很簡單:

  1. 將流程翻轉;
  2. 將每一步基本運算轉置,其中最重要的就是將 \(a_i\) 乘以 \(v\) 加給 \(b_j\) 經轉置變為將 \(b_j\) 乘以 \(v\) 加給 \(a_j\);對於多項式也是如此:將 \(F\) 乘以 \(G\) 加給 \(H\) 經轉置變為將 \(H\) 轉置乘 \(G\) 加給 \(F\)

同時,需要注意分辨常量與變數,與輸入輸出無關的常量不需要參與轉置。不難發現 \(G\)\(c,d\) 均無關,在此演算法中屬於常量,故 \(G\) 的計算不需要轉置。

對於 \(F_{l,r}=F_{l,mid}+F_{mid+1,r}\times G_{l,mid}\),我們可以將其看成三步:\(F_{l,r}\gets 0\)\(F_{l,r}\gets F_{l,r}+F_{l,mid}\)\(F_{l,r}\gets F_{l,r}+F_{mid+1,r}\times G_{l,mid}\),於是該演算法的轉置也不難寫出:

  1. 對於非葉子結點:\(F_{l,mid}=F_{l,r}\)\(F_{mid+1,r}=F_{l,r}\times^T G_{l,mid}\),其中 \(F_{l,mid}\) 只需要保留 \(mid-l+1\) 次;
  2. 對於葉子結點:\(d_i=b_i[x^0]F_{i,i}+[x^1]F_{i,i}\)

於是在 \(O(n\log^2n)\) 時間複雜度內解決了原問題。

核心程式碼:

namespace MulTT{
	inline Poly MulT(const Poly &a,const Poly &b){
		Poly F=a,G=b;
		int n=a.size(),m=b.size();
		reverse(G.begin(),G.end());
		init(n);
		F.resize(lim),G.resize(lim);
		NTT(F,1),NTT(G,1);
		for(int i=0;i<lim;i++)
			G[i]=1ll*F[i]*G[i]%mod;
		NTT(G,-1);
		for(int i=m-1;i<n;i++)
			F[i-m+1]=G[i];
		F.resize(max(0,n-m+1));
		return F;
	}
}
using namespace MulTT;
#define PolyY vector<Poly>
inline PolyY operator*(const PolyY &a,const PolyY &b){
	int p=a[0].size(),q=b[0].size();
	PolyY F=a,G=b;
	init(p+q);
	for(int i=0;i<2;i++)
		F[i].resize(lim),G[i].resize(lim),
		NTT(F[i],1),NTT(G[i],1);
	for(int i=0;i<lim;i++)
		F[1][i]=(1ll*F[0][i]*G[1][i]+1ll*F[1][i]*G[0][i])%mod,
		F[0][i]=1ll*F[0][i]*G[0][i]%mod;
	for(int i=0;i<2;i++)
		NTT(F[i],-1),F[i].resize(p+q-1);
	return F;
}
#define ls (rt<<1)
#define rs (rt<<1|1)
int n,m;
Poly A,B,C,D,G[N];
inline void solve1(int rt,int l,int r){
	if(l==r){
		G[rt]={B[l],1};
		return ;
	}
	int mid=l+r>>1;
	solve1(ls,l,mid),solve1(rs,mid+1,r);
	G[rt]=G[ls]*G[rs];
}
inline PolyY solve2(int l,int r){
	if(l==r) return {{1,dec(0,A[l])},{C[l],0}};
	int mid=l+r>>1;
	return solve2(l,mid)*solve2(mid+1,r);
}
inline void solve3(int rt,int l,int r,Poly F){
	if(l==r){
		D[l]=add(F[1],1ll*F[0]*B[l]%mod);
		return ;
	}
	int mid=l+r>>1;
	Poly L=F;
	L.resize(mid-l+2);
	solve3(ls,l,mid,L);
	Poly R=MulT(F,G[ls]);
	solve3(rs,mid+1,r,R);
}
int main(){
	n=read();
	Prefix(n*2);
	for(int i=0;i<n;i++)
		A.push_back(read());
	for(int i=0;i<n;i++)
		B.push_back(read());
	for(int i=0;i<n;i++)
		C.push_back(read());
	D.resize(n);
	solve1(1,0,n-1);
	PolyY T=solve2(0,n-1);
	Poly F=T[1]*Inv(T[0]);
	F.resize(n+1);
	solve3(1,0,n-1,F);
	for(auto tmp:D)
		write(tmp),putc(' ');
	flush();
}

相關文章