當學 Min 25 的一個前置知識。
演算法內容。
定義 \(S(n)=\sum_{i=1}^nf(i)\)。對於一個函式 \(g\),有:
所以如果存在函式 \(g\),滿足:
- \(g(1)\neq 0\)
- \(\sum_{i=1}^n(f\times g)(i)\) 可以快速計算
- \(g(i)\) 可以快速計算
可以透過記憶化搜尋+數論分塊快速計算 \(S(n)\)。可以用 unordered_map
儲存結果。
直接計算複雜度為 \(O(n^{\frac{3}{4}})\)。更好的計算是預處理前 \(O(n^{\frac{2}{3}})\) 的 \(S\) 值,可以做到 \(O(n^{\frac{2}{3}})\) 的複雜度。
具體證明可以參考 link。證明本質是積分,簡單的。
常用構建函式 \(g\) 技巧:
- \(\sum_{d|n}^n\mu(d)=[n=1]\)
- \(\sum_{d|n}^nu(d)\frac{n}{d}=\varphi(d)\)
- \(\sum_{d|n}^n\varphi(d)=n\)
- \(i^k·(\frac{n}{i})^k=n^k\)
例子:
\(S(n)=\sum_{i=1}^n\mu(i)\)。
令 \(g(n)=1\),也即常函式。
則 \(\sum_{i=1}^n(g\times \mu)(i)=[n\ge 1]\)
則
\(S(n)=\sum_{i=1}^n\varphi(i)\)
令 \(g(n)=1\)
則 \(\sum_{i=1}^n(g\times \varphi)(i)=\frac{n(n+1)}{2}\)。
則
實操
計算 \(\sum_{i=1}^n\sum_{j=1}^nij\gcd(i,j)\)
先化簡:
令 \(h(n)=\frac{n^2(n+1)^2}{4}\),\(f(n)=n^2\varphi(n)\)。
對於上式,後者可以數論分塊,問題化為求解 \(f\) 的字首和。
令 \(g(n)=n^2\)
則 \(\sum_{i=1}^n(f\times g)(i=\sum_{i=1}^n\sum_{d|i}d^2\varphi(d)\times \frac{i^2}{d^2}=\sum_{i=1}^ni^3=\frac{n^2(n+1)^2}{4}\)
則 \(g(1)=1\)
則 \(g(i)\) 可以快速計算。
所以 \(S_f(n)=\frac{n^2(n+1)^2}{4}-\sum_{i=2}^ni^2S_f(\lfloor\frac{n}{i}\rfloor)\)
數論分塊即可。
模版程式碼一份。
#include<bits/stdc++.h>
using namespace std;
#define N 1050500
#define int long long
const int it=1e6+7;
int v[N],pri[N],tot,p,phi[N],n,m,s[N],inv2,inv6,inv4;
int power(int a,int b){
int ans=1;
while(b){
if(b&1)ans=ans*a%p;
a=a*a%p;b>>=1;
}
return ans;
}
unordered_map<int,int>sit;
void init(){
phi[1]=1;
for(int i=2;i<it;i++){
if(!v[i]){
pri[++tot]=i;phi[i]=i-1;
}
for(int j=1;j<=tot&&i*pri[j]<it;++j){
v[pri[j]*i]=1;
if(i%pri[j]==0){
phi[pri[j]*i]=pri[j]*phi[i];break;
}
phi[pri[j]*i]=phi[pri[j]]*phi[i];
}
}
for(int i=1;i<it;++i)s[i]=(s[i-1]+i*i%p*phi[i]%p)%p;
}
int h(int n){
n%=p;
return n*n%p*(n+1)%p*(n+1)%p*inv4%p;
}
int pfs(int n){
n%=p;
return n*(n+1)%p*(n+n+1)%p*inv6%p;
}
int calc(int n){
if(n<it)return s[n];
if(sit[n]!=0)return sit[n];
int res=h(n);int lst=1,cur=0;
for(int l=2,r;l<=n;l=r+1){
r=min(n,n/(n/l));cur=pfs(r);
res=(res+p-(cur-lst)%p*calc(n/l)%p)%p;lst=cur;
}res=(res%p+p)%p;
return sit[n]=res;
}
signed main(){
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin>>p;cin>>n;inv2=power(2,p-2);inv4=power(4,p-2);inv6=power(6,p-2);
init();
int ans=0,lst=0;
for(int l=1,r;l<=n;l=r+1){
r=min(n,n/(n/l));int cur=calc(r);
ans+=(cur-lst)*h(n/l)%p;ans%=p;lst=cur;
}
ans=(ans%p+p)%p;
cout<<ans<<"\n";
}