拉格朗日插值學習筆記
應用
眾所周知,在平面直角座標系中,對於任意的 \(n\) 個點,都一定有一個不超過 \(n-1\) 次的函式與之相對應。拉格朗日插值適用於求解這 \(n\) 個點對應的函式。
思路
考慮給定的 \(n\) 個點的座標表示為 \((x_i,y_i)\),不難構造出如下函式:
\[f(x)=\sum_{i=1}^{n}y_ig_i(x)
\]
那麼此時只需要構造出符合要求的 \(g_i\) 即可。不難發現,為了使 \(f\) 符合條件,\(g_i\) 應該滿足:
\[\forall n\in x,g_i(n)=\left\{\begin{matrix}1&n=x_i\\0&n\ne x_i\end{matrix}\right.
\]
其中,\(x\) 表示所有 \(x_i\) 構成的集合。不難發現,當 \(n\ne x_i\) 時函式值為 \(0\),說明該函式 \(g_i(x)\) 一定有因式 \((x-x_j)(i\ne j)\),因此不難看出 \(g_i(x)\) 應當有因式 \(\prod_{j\ne i}(x-x_j)\)。又因為當 \(n=x_i\) 時,函式值為 \(0\),不難構造出:
\[g_i(n)=\frac{\prod_{j\ne i}(n-x_j)}{\prod_{j\ne i}(x_i-x_j)}
\]
如此就構造出了符合條件的函式,此時可以 \(O(n^2)\) 求解。在單次求解的時候可以採用。
不妨進行進一步轉換,令 \(t_i=\frac{y_i}{\prod_{j\ne i}(x_i-x_j)}\),不難看出,\(f(x)=t_i\prod_{j\ne i}(x-x_j)\)。不妨計算出 \(\prod_{i=1}^{n}(x-x_i)\) 後將對應項除掉,接著利用多項式乘法算出每一次項前的係數後,就可以按照多項式加法的方式 \(O(n)\) 求解。這樣的方式適用於多組查詢或需要求出函式對應項係數的情況。
程式碼
\(O(n^2)\) 寫法
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=2e3+5,M=998244353;
int n,k;
int x[N],y[N];
ll QuickPow(ll a,int b){
ll res=1;
while(b>0){
if(b&1)res=res*a%M;
a=a*a%M;
b>>=1;
}
return res;
}
ll Lagrange(int x[],int y[],int X){
ll res=0;
for(int i=1;i<=n;i++){
ll l_i=1;
for(int j=1;j<=n;j++){
if(i!=j)l_i=l_i*(X-x[j])%M*QuickPow(x[i]-x[j],M-2)%M;
}
res=(res+l_i*y[i]%M)%M;
}
return res;
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)scanf("%d%d",&x[i],&y[i]);
printf("%lld",(Lagrange(x,y,k)+M)%M);
return 0;
}
\(O(n)\) 做法
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=2e3+5,M=998244353;
int n,k;
ll t[N],x[N],y[N],g[N],fs[N],f[N];
ll QuickPow(ll a,int b){
ll res=1;
while(b>0){
if(b&1)res=res*a%M;
a=a*a%M;
b>>=1;
}
return res;
}
void Init(){
for(int i=0;i<n;i++){
t[i]=1;
for(int j=0;j<n;j++){
if(i==j)continue;
t[i]=t[i]*(x[i]-x[j])%M;
}
t[i]=QuickPow(t[i],M-2)*y[i]%M;
}
fs[0]=1;
for(int i=0;i<n;i++){
for(int j=n-1;j>0;j--)fs[j]=(fs[j-1]+fs[j]*(M-x[i])%M)%M;
fs[0]=fs[0]*(M-x[i])%M;
}
for(int i=0;i<n;i++){
ll inv=QuickPow(M-x[i],M-2);
g[0]=fs[0]*inv%M;
for(int j=1;j<n;j++)g[j]=(fs[j]-g[j-1])*inv%M;
for(int j=0;j<n;j++)f[j]=(f[j]+t[i]*g[j]%M)%M;
}
return ;
}
ll Lagrange(ll X){
ll res=0,Pow=1;
for(int i=0;i<n;i++){
res=(res+Pow*f[i]%M)%M;
Pow=Pow*X%M;
}
return res;
}
int main(){
scanf("%d%d",&n,&k);
for(int i=0;i<n;i++)scanf("%lld%lld",&x[i],&y[i]);
Init();
printf("%lld",(Lagrange(k)+M)%M);
return 0;
}