[題解]AT_abc272_f [ABC272F] Two Strings

WaterSunHB發表於2024-06-23

思路

實際上對於一個字串 \(S\) 進行一個 \(f(S,x)\) 的操作本質上就是在 \(S + S\) 中擷取一段長度為 \(n\) 的子串。

於是你不難想到把 \(A,B\) 拼起來,形成一個字串 \(S = A + A + B + B\),然後比較字尾。你發現這是對的,因為兩個串的字典序大小從前往後比較的,因此你儘管是比較的字尾,本質上還是比較的兩個子串。

比較字尾大小,直接 SA 得出 \(rk\) 陣列。再次回到題目中,要求 \(f(A,i) \leq f(B,j)\) 的數量,你發現這個條件等效於比較兩個字尾 \(rk\) 的大小。

但是直接列舉 \(i,j\) 是不現實的,但是你可以先將 \(f(A,i)\)\(rk\) 值丟到一個桶裡面,然後對於所有的 \(f(B,j)\) 對答案的貢獻就是在桶中 \(rk\) 小於等於它本身 \(rk\) 的值的數量。因此對桶做一次字首和即可。

但是本題中有一個坑點在與字串應當在 \(A,B\) 字串交界處插入一個極小的字元,在最後插入一個極大的字元。

中間插入極小字元是因為當 \(A = B\) 時,\(B_1\) 就會影響兩個字尾的大小關係;最後插入極大字元的原因是,本題求的是小於等於的情況,因此當兩個子串相同時也應該記錄答案,但是在 SA 中,\(rk\) 的大小就會亂飄,為了使我們可以透過上述方法求答案我們應該保證當 \(f(A,i) = f(B,j)\) 時,後者排在後面。

Code

#include <bits/stdc++.h>  
#define re register  
#define int long long  
  
using namespace std;  
  
const int N = 1e6 + 10;  
int len,n,m = 250,ans;  
int sum[N];  
int sa[N],rk[N],prk[N],tmp[N],cnt[N];  
char a[N],b[N],s[N];  
  
inline int read(){  
    int r = 0,w = 1;  
    char c = getchar();  
    while (c < '0' || c > '9'){  
        if (c == '-') w = -1;  
        c = getchar();  
    }  
    while (c >= '0' && c <= '9'){  
        r = (r << 3) + (r << 1) + (c ^ 48);  
        c = getchar();  
    }  
    return r * w;  
}  
  
signed main(){  
    len = read();  
    scanf("%s%s",a + 1,b + 1);  
    for (re int i = 1;i <= len;i++) s[++n] = a[i];  
    for (re int i = 1;i <= len;i++) s[++n] = a[i];  
    s[++n] = '&';  
    for (re int i = 1;i <= len;i++) s[++n] = b[i];  
    for (re int i = 1;i <= len;i++) s[++n] = b[i];  
    s[++n] = '|';  
    for (re int i = 1;i <= n;i++){  
        rk[i] = s[i];  
        cnt[rk[i]]++;  
    }  
    for (re int i = 1;i <= m;i++) cnt[i] += cnt[i - 1];  
    for (re int i = n;i;i--) sa[cnt[rk[i]]--] = i;  
    for (re int w = 1;w < n;w <<= 1){  
        int num = 0,p = 0;  
        for (re int i = n - w + 1;i <= n;i++) tmp[++num] = i;  
        for (re int i = 1;i <= n;i++){  
            if (sa[i] > w) tmp[++num] = sa[i] - w;  
        }  
        for (re int i = 1;i <= m;i++) cnt[i] = 0;  
        for (re int i = 1;i <= n;i++) cnt[rk[i]]++;  
        for (re int i = 1;i <= m;i++) cnt[i] += cnt[i - 1];  
        for (re int i = n;i;i--) sa[cnt[rk[tmp[i]]]--] = tmp[i];  
        for (re int i = 1;i <= n;i++) prk[i] = rk[i];  
        for (re int i = 1;i <= n;i++){  
            if (prk[sa[i]] == prk[sa[i - 1]] && prk[sa[i] + w] == prk[sa[i - 1] + w]) rk[sa[i]] = p;  
            else rk[sa[i]] = ++p;  
        }  
        if (p == n) break;  
        m = p;  
    }  
    for (re int i = 1;i <= len;i++) sum[rk[i]]++;  
    for (re int i = 1;i <= 1e6;i++) sum[i] += sum[i - 1];  
    for (re int i = 2 * len + 2;i <= 3 * len + 1;i++) ans += sum[rk[i]];  
    printf("%lld",ans);  
    return 0;  
}  

相關文章