思路
實際上對於一個字串 \(S\) 進行一個 \(f(S,x)\) 的操作本質上就是在 \(S + S\) 中擷取一段長度為 \(n\) 的子串。
於是你不難想到把 \(A,B\) 拼起來,形成一個字串 \(S = A + A + B + B\),然後比較字尾。你發現這是對的,因為兩個串的字典序大小從前往後比較的,因此你儘管是比較的字尾,本質上還是比較的兩個子串。
比較字尾大小,直接 SA 得出 \(rk\) 陣列。再次回到題目中,要求 \(f(A,i) \leq f(B,j)\) 的數量,你發現這個條件等效於比較兩個字尾 \(rk\) 的大小。
但是直接列舉 \(i,j\) 是不現實的,但是你可以先將 \(f(A,i)\) 的 \(rk\) 值丟到一個桶裡面,然後對於所有的 \(f(B,j)\) 對答案的貢獻就是在桶中 \(rk\) 小於等於它本身 \(rk\) 的值的數量。因此對桶做一次字首和即可。
但是本題中有一個坑點在與字串應當在 \(A,B\) 字串交界處插入一個極小的字元,在最後插入一個極大的字元。
中間插入極小字元是因為當 \(A = B\) 時,\(B_1\) 就會影響兩個字尾的大小關係;最後插入極大字元的原因是,本題求的是小於等於的情況,因此當兩個子串相同時也應該記錄答案,但是在 SA 中,\(rk\) 的大小就會亂飄,為了使我們可以透過上述方法求答案我們應該保證當 \(f(A,i) = f(B,j)\) 時,後者排在後面。
Code
#include <bits/stdc++.h>
#define re register
#define int long long
using namespace std;
const int N = 1e6 + 10;
int len,n,m = 250,ans;
int sum[N];
int sa[N],rk[N],prk[N],tmp[N],cnt[N];
char a[N],b[N],s[N];
inline int read(){
int r = 0,w = 1;
char c = getchar();
while (c < '0' || c > '9'){
if (c == '-') w = -1;
c = getchar();
}
while (c >= '0' && c <= '9'){
r = (r << 3) + (r << 1) + (c ^ 48);
c = getchar();
}
return r * w;
}
signed main(){
len = read();
scanf("%s%s",a + 1,b + 1);
for (re int i = 1;i <= len;i++) s[++n] = a[i];
for (re int i = 1;i <= len;i++) s[++n] = a[i];
s[++n] = '&';
for (re int i = 1;i <= len;i++) s[++n] = b[i];
for (re int i = 1;i <= len;i++) s[++n] = b[i];
s[++n] = '|';
for (re int i = 1;i <= n;i++){
rk[i] = s[i];
cnt[rk[i]]++;
}
for (re int i = 1;i <= m;i++) cnt[i] += cnt[i - 1];
for (re int i = n;i;i--) sa[cnt[rk[i]]--] = i;
for (re int w = 1;w < n;w <<= 1){
int num = 0,p = 0;
for (re int i = n - w + 1;i <= n;i++) tmp[++num] = i;
for (re int i = 1;i <= n;i++){
if (sa[i] > w) tmp[++num] = sa[i] - w;
}
for (re int i = 1;i <= m;i++) cnt[i] = 0;
for (re int i = 1;i <= n;i++) cnt[rk[i]]++;
for (re int i = 1;i <= m;i++) cnt[i] += cnt[i - 1];
for (re int i = n;i;i--) sa[cnt[rk[tmp[i]]]--] = tmp[i];
for (re int i = 1;i <= n;i++) prk[i] = rk[i];
for (re int i = 1;i <= n;i++){
if (prk[sa[i]] == prk[sa[i - 1]] && prk[sa[i] + w] == prk[sa[i - 1] + w]) rk[sa[i]] = p;
else rk[sa[i]] = ++p;
}
if (p == n) break;
m = p;
}
for (re int i = 1;i <= len;i++) sum[rk[i]]++;
for (re int i = 1;i <= 1e6;i++) sum[i] += sum[i - 1];
for (re int i = 2 * len + 2;i <= 3 * len + 1;i++) ans += sum[rk[i]];
printf("%lld",ans);
return 0;
}