多項式與點值的雙射 與 Reed–Solomon 編碼糾錯.

QedDust發表於2024-06-14

其實早就知道啊,不過apio t3之後還是在皮皮橙大神的指導下認真看了看.

放一個 $O(n^2)$ 的實現
#include <bits/stdc++.h>
using u32=unsigned;
using i64=long long;
using u64=unsigned long long;
using idt=std::size_t;
constexpr u32 mod=998244353;
constexpr u32 mul(u32 x,u32 y){return u64(x)*y%mod;}
constexpr u32 shrk(u32 x){return std::min(x,x-mod);}
constexpr u32 dilt(u32 x){return std::min(x,x+mod);}
constexpr u32 qpw(u32 a,u32 b,u32 r=1){
    for(;b;b>>=1,a=mul(a,a)){
        if(b&1){
            r=mul(r,a);
        }
    }
    return r;
}
using poly=std::vector<u32>;
std::tuple<poly,poly> intpol(const poly&x,const poly&y){
    idt n=x.size();
    poly base(n+1),res(n);
    base[0]=1;
    for(idt i=0;i<n;++i){
        u32 _xi=mod-x[i];
		for(idt j=i+1;j;--j){
			base[j]=(base[j-1]+u64(base[j])*_xi)%mod;
		}
		base[0]=mul(base[0],_xi);
    }
    for(idt i=0;i<n;++i){
		u32 fx=1,hi=base.back();
		for(idt j=0;j<n;++j){
			if(i!=j){
				fx=mul(fx,x[i]-x[j]+mod);
			}
		}
		fx=qpw(fx,mod-2,y[i]);
		for(idt j=n-1;~j;--j){
			res[j]=(u64(hi)*fx+res[j])%mod;
			hi=(base[j]+u64(hi)*x[i])%mod;
		}
	}
    while(!res.empty()&&res.back()==0){res.pop_back();}
	return {res,base};
}
poly mul(const poly&f,const poly&g){
	idt n=f.size(),m=g.size();
	poly r(n+m-1);
	for(idt i=0;i<n;++i){
		for(idt j=0;j<m;++j){
			r[i+j]=(r[i+j]+u64(f[i])*g[j])%mod;
		}
	}
	return r;
}
inline idt deg(const poly&f){
    return f.size()-1;
}
struct matp{
	poly a00,a01,a10,a11;
    void lmul_reg(const poly&p,idt n){
		a00.swap(a10),a01.swap(a11);
		a10.resize(deg(a00)+n),a11.resize(deg(a01)+n);
		for(idt i=0;i<n;++i){
			for(idt j=0;j<a00.size();++j){
				a10[i+j]=(a10[i+j]+u64(p[i])*a00[j])%mod;
			}
			for(idt j=0;j<a01.size();++j){
				a11[i+j]=(a11[i+j]+u64(p[i])*a01[j])%mod;
			}
		}   
	}
};
//r,x,y
std::tuple<poly,poly,poly> extgcd(const poly&a,const poly&b,idt k){
    idt n=a.size(),d=n-1,m=b.size(),thr=d-k;
    matp res={{1},{},{},{1}};
    poly nq(n),P=a,Q=b;
    while(m>thr){
		idt u=m-1;
		for(idt i=n-m;~i;--i){
			nq[i]=qpw(mod-Q[u],mod-2,P[i+u]),P[i+u]=0;
			for(idt j=u-1;~j;--j){
				P[i+j]=(P[i+j]+u64(nq[i])*Q[j])%mod;
			}
		}
		res.lmul_reg(nq,n-m+1),n=u;
        while(n>0&&P[n-1]==0){--n;}
		std::swap(n,m),std::swap(P,Q);
	}
    Q.resize(m);
    return {Q,res.a10,res.a11};
}
poly pdiv(poly f,poly g){
    idt n=f.size(),m=g.size(),u=m-1;
    poly q(n-m+1);
    for(idt i=n-m;~i;--i){
        q[i]=qpw(g[u],mod-2,f[i+u]),f[i+u]=0;
        for(idt j=u-1;~j;--j){
            f[i+j]=dilt(f[i+j]-mul(q[i],g[j]));		
        }
    }
    return q;
}
poly RS(const poly&x,const poly&y,idt k){
    auto [g,t]=intpol(x,y);
    auto [fr,z,r]=extgcd(t,g,k);
    //fr = zt+gr
    return pdiv(fr,r);
}

using std::cin;
using std::cout;
void solve(){
    idt n,k;
    cin>>n>>k;
    poly x(n),y(n);
    for(auto&z:x){cin>>z;}
    for(auto&z:y){cin>>z;}
    auto f=RS(x,y,k);
    for(auto z:f){
        cout<<z<<" ";
    }
}
/*
8 0
4 15 5 20 2 6 12 16
25 256 36 441 9 49 169 289 

8 1
4 15 5 20 2 6 12 16
25 256 36 441 9 49 169 288

8 2
4 15 5 20 2 6 12 16
25 256 36 441 90 49 169 200
*/
int main(){
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve();
    return 0;
}

相關文章