Luogu P9606 CERC2019 ABB 題解 [ 綠 ] [ KMP ] [ 字串雜湊 ]

KS_Fszha發表於2024-12-10

ABB:KMP 的做法非常巧妙。

雜湊

思路

顯然正著做一遍雜湊,倒著做一遍雜湊,然後列舉迴文中心即可。

時間複雜度 \(O(n)\)

程式碼

#include <bits/stdc++.h>
#define fi first
#define se second
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pi;
const int N=400005;
const ull base=13331;
ull phash[N],shash[N],pw[N];
int n,ans;
char s[N];
void dohash()
{
    pw[0]=1;
    for(int i=1;i<N;i++)pw[i]=pw[i-1]*base;
    for(int i=1;i<=n;i++)phash[i]=phash[i-1]*base+s[i];
    for(int i=n;i>=1;i--)shash[i]=shash[i+1]*base+s[i];
}
ull gethash(int op,int l,int r)
{
    if(op==0)return (phash[r]-phash[l-1]*pw[r-l+1]);
    return (shash[l]-shash[r+1]*pw[r-l+1]);
}
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>s+1;
    ans=n-1;
    if(n==1)
    {
        cout<<0;
        return 0;
    }
    dohash();
    for(int i=1;i<=n;i++)
    {
        if(i-1>=n-(i+1)+1)
        {
            int len=n-(i+1)+1;
            if(gethash(0,i-1-len+1,i-1)==gethash(1,i+1,i+1+len-1))
            {
                ans=min(ans,i-1-len);
            }
        }
        if(i>=n-(i+1)+1)
        {
            int len=n-(i+1)+1;
            if(gethash(0,i-len+1,i)==gethash(1,i+1,i+1+len-1))
            {
                ans=min(ans,i-len);
            }
        }
    }
    cout<<ans;
    return 0;
}

雜湊的程式碼又臭又長,肯定不是這題的最優解。

KMP

思路

我們考慮這題要求的本質是什麼,顯然是求字串 \(s\) 最長的迴文字尾的長度 \(l\),那麼答案就是 \(n-l\)。因為剩下的那些必須對稱過去。

那麼如何求最長的迴文字尾呢?我們可以先把字串 \(s\) 反轉為 \(s'\),把字尾轉化為字首便於處理。

由於迴文串正著讀和反著讀都一樣,可以發現迴文字尾滿足是字串 \(s'\) 的一段字首,也是字串 \(s\) 的一段字尾

因此如果字串 \(s'\) 的一段字首與字串 \(s\) 的一段字尾相等,就說明這是一個合法的迴文字尾。

如果要讓這個字尾最長,那麼顯然是 KMP 的 \(next_n\),就求出來了。

實現上,我們可以將最後的主串設定為 \(s'\) 與任意一個分隔符與 \(s\) 拼接起來的字串,然後做一遍 KMP 即可。

時間複雜度 \(O(n)\)

程式碼

非常簡短,注意字串開兩倍的長度。

#include <bits/stdc++.h>
using namespace std;
const int N=800005;
int n,ne[N];
char s[N];
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>s+n+2;
    s[n+1]='%';
    for(int i=1,j=2*n+1;i<=n;i++,j--)s[i]=s[j];
    int now=0;
    for(int i=2;i<=2*n+1;i++)
    {
        while(now&&s[now+1]!=s[i])now=ne[now];
        if(s[now+1]==s[i])now++;
        ne[i]=now;
    }
    cout<<n-ne[2*n+1];
    return 0;
}

相關文章