NTT任意模數模板(+O(1)快速乘)

Self-Discipline發表於2018-09-03

 

NTT任意模數的方法其實有點取巧。

兩個數列每個有n個數,每個數的大小最多是10^9。

如果沒有模數,那麼卷積過後每個位置的答案一定小於10^9*10^9*n,差不多是10^24左右

那麼就有一個神奇的做法,選3個乘積大於10^24的NTT模數,分別做一次,得到每個位上模意義下的答案,

然後用中國剩餘定理得到模上三個質數乘積的答案。

因為答案顯然小於三個質數乘積,那麼模上三個質數乘積的答案就是這個數應該的值。

不過這個值可能會超long long(及時不超,對於乘積大於long long的三個質數做中國剩餘定理也不是一件小事)

考慮先將兩個模數意義下的答案合併,

現在我們還剩兩個模數,一個為long long,一個為int

不能中國剩餘定理硬上了。

設模數為P1(longlong) ,P2(int), 餘數為a1,a2

設答案ANS=P1*K+a1=P2*Q+a2

那麼K*P1=P2*Q+(a2-a1)

K*P1 % P2=a2-a1

a1-a2為常數

用同餘方程的解法即可解出K模P2(int)意義下的值

又有ANS<P1*P2(之前已證)

so K*P1+a1<P1*P2

顯然K<P2

所以原本答案K的值只能為模P2意義下的值

所以我們就求出K了,然後可以不用高精度就算出ANS%MOD(MOD為任意模數)

但是,

回顧整個過程,附加條件非常多。。。。。。

首先每個數<=10^9(再大或許可以通過增加模數的方法解決,但是CRT時可就不能迴避高精度取模了,常數捉急)

然後K的值必須為非負.(如果為負數那麼就有兩個可能的答案了,這是你用第一條性質怎樣都無法迴避的)

其次你需要解決兩個long long相乘mod long long

用二分乘法會T(常數啊),可以用接近作弊的O(1)long long乘法取模:

//O(1)快速乘
LL mul(LL a,LL b,LL P)
{
	a=(a%P+P)%P,b=(b%P+P)%P;
	return ((a*b-(LL)((long double)a/P*b+1e-6)*P)%P+P)%P;
}

主要原理是mod後的答案用公式 A % B=A-floor(A/B)*B來算。

注意其中A和floor(A/B)*B都是可能爆long long的。

但是因為減法,所以無論是兩個都不溢位還是兩個都溢位亦或是一個溢位另一個不溢位,都沒有關係。。。。。。。

(好像2009集訓隊論文中關於底層優化的那篇上有)

 

模板:

#include<cstdio>
#include<cstring>
#include<cctype>
//#include<ctime>
#include<algorithm>
#define maxn 300005
#define LL long long
#define RealMod 1000000007
using namespace std;
 
int n,m;
int P[3]={998244353,1004535809,469762049},G[3]={3,3,3},invG[3],wn[3][2][maxn];
 
inline int pow(int base,int k,int P)
{
	int ret=1;
	for(;k;k>>=1,base=1ll*base*base%P) if(k&1) ret=1ll*ret*base%P;
	return ret;
}
 
inline void Prework(int id)
{
	invG[id]=pow(G[id],P[id]-2,P[id]);
	for(int i=1;i<24;i++) wn[id][1][i]=pow(G[id],(P[id]-1)/(1<<i),P[id]),wn[id][0][i]=pow(invG[id],(P[id]-1)/(1<<i),P[id]);
}
 
inline void NTT(int *A,int n,int typ,int id)
{
	for(int i=0,j=0,k;i<n;i++)
	{
		if(i<j) swap(A[i],A[j]);
		for(k=n>>1;k;k>>=1) if((j^=k)>=k) break;
	}
 
	for(int i=1,j,k,len,w,x,y;1<<i<=n;i++)	
	{
		len=1<<(i-1);
		for(j=0,w=1;j<n;j+=1<<i,w=1)
			for(k=0;k<len;k++,w=1ll*w*wn[id][typ][i]%P[id])
			{
				x=A[j+k],y=1ll*A[j+k+len]*w%P[id];
				A[j+k]=(x+y)%P[id];
				A[j+k+len]=(x-y+P[id])%P[id];
			}
	}
	
	if(typ==0)
		for(int i=0,inv=pow(n,P[id]-2,P[id]);i<n;i++)
			A[i]=1ll*A[i]*inv%P[id];
}
 
void mul(int *ret,int *A,int lena,int *B,int lenb,int id)
{
	static int seq1[maxn],seq2[maxn];
	int n=1;for(;n<=lena+lenb;n<<=1);
	for(int i=0;i<n;i++) 
	{
		if(i<=lena) seq1[i]=A[i]; else seq1[i]=0;
		if(i<=lenb) seq2[i]=B[i]; else seq2[i]=0;
	}
	NTT(seq1,n,1,id);NTT(seq2,n,1,id);
	for(int i=0;i<n;i++) ret[i]=1ll*seq1[i]*seq2[i]%P[id];
	NTT(ret,n,0,id);
}
 
//O(1)快速乘
LL mul(LL a,LL b,LL P)
{
	a=(a%P+P)%P,b=(b%P+P)%P;
	return ((a*b-(LL)((long double)a/P*b+1e-6)*P)%P+P)%P;
}
/*
long long mul (long long a, long long b, long long mod) {
	a%=mod,b%=mod;
    if (b == 0)
        return 0;
    long long ans = mul (a, b>>1, mod);
    ans = ans*2%mod;
    if (b&1) ans += a, ans %= mod;
    return (ans+mod)%mod;
}*/
 
 
int a[maxn],b[maxn],c[3][maxn];
LL Mod=1ll*P[0]*P[1];
LL inv1=pow(P[0],P[1]-2,P[1]),inv2=pow(P[1],P[0]-2,P[0]),inv=pow(Mod%P[2],P[2]-2,P[2]);
inline void solve(int i)
{
	LL C=(mul(1ll*c[0][i]*P[1]%Mod,inv2,Mod)+mul(1ll*c[1][i]*P[0]%Mod,inv1,Mod))%Mod;
	LL K=1ll*((1ll*c[2][i]-(C%P[2]))%P[2])*(inv%P[2])%P[2];
	c[0][i]=(((K%RealMod+RealMod)*(Mod%RealMod)%RealMod+C)%RealMod);
}
 
int main()
{
	
	//freopen("1.in","r",stdin);
	
	//int t1=clock();
	
	scanf("%d%d",&n,&m);
	for(int i=0;i<=n;i++) scanf("%d",&a[i]);
	for(int j=0;j<=m;j++) scanf("%d",&b[j]);
	for(int i=0;i<3;i++) Prework(i),mul(c[i],a,n,b,m,i);
	for(int i=0;i<=n+m;i++) 
		solve(i); 
	for(int i=0;i<n+m;i++)
	{
		int tmp=(c[0][i]+RealMod)%RealMod;
		printf("%d ",tmp);
	}
	printf("%d\n",(c[0][n+m]+RealMod)%RealMod);
	
	//printf("%d\n",clock()-t1);
}

原文:https://blog.csdn.net/qq_35950004/article/details/79477797

相關文章