Manacher 演算法

zuoqingyuan111發表於2024-08-19

演算法介紹

\(\text{Manacher}\) 演算法(又名馬拉車),是一種常用於處理迴文字串的演算法。其程式碼量很小,卻可以在 \(O(n)\) 的時間複雜度內處理問題

演算法思想

和其他大多數演算法一樣,\(\text{Manacher}\) 演算法利用現有的資訊獲得下一部分的資訊。

經典例題:給定一個字串 \(s\)。求出其最長迴文子串

\(|s|\le 1.1\times 10^7\)

該題暴力演算法法是,列舉 \(s\) 的任意一個子串 \(s_{l,r}\),判斷其是否是迴文串,如果是,就用其更新答案。在不加任何最佳化的前提下演算法時間複雜度 \(O(n^3)\)

我們思考一個迴文串的性質。對於一個迴文串,其必然有一個迴文中心,滿足這個迴文串關於迴文中心對稱。迴文中心到字串兩端的距離成為迴文半徑。迴文中心分為兩類,分別是在一個字元上:例如 \(\texttt{ababa}\) 的迴文中心在中間的 \(\texttt{a}\) 上;在者是兩個字元的間隙之間,例如 \(\texttt{abba}\) 的迴文中心在兩個 \(\texttt{b}\) 中間。為了將問題普遍化,我們在一個字串的任意一個字元旁插入一個不存在於這個字串的字元。例如 \(\texttt{ababa}\) 轉化為 \(\texttt{\#b\#a\#b\#a\#}\)\(\texttt{abba}\) 轉化為 \(\texttt{\#a\#b\#b\#a\#}\) 這樣,我們把兩類問題都轉換為一類問題。下面討論的均是改變後的字串

對於任何一個迴文串,迴文中心都只有一個。我們列舉迴文中心,定義 \(P_i\) 表示以 \(i\) 為中心的最長迴文半徑,定義 \(l,r,mid\) 分別為當前匹配出最靠右的迴文串 \(p\) 的左右邊界和迴文中心。

演算法的執行,我們分兩種情況

  1. \(i>r\):因為 \(i\) 位於我們最靠右迴文串邊界之外。無法利用已知資訊求解。暴力列舉其迴文半徑,找出 \(P_i\)

  2. \(i\le r\):因為 \(i\) 位於一個迴文串內,迴文中心為 \(mid\)。我們再分兩種情況討論

  • \(i\) 為迴文中心的最長迴文子串 \(t\) 完全包含在 \(p\)。因為迴文串的性質。\(mid\) 左邊必然有一個迴文串關於 \(t\) 對應,記為 \(t'\)。因為迴文串只有一個迴文中心。所以 \(t'\) 的迴文中心與 \(t\) 的迴文中心關於 \(mid\) 對稱。我們用 \(j\) 表示 \(t'\) 的迴文中心。則 \(j=mid\times 2-i\)。也可以得到 \(P_i=P_{mid\times 2-i}\)

  • \(i\) 為迴文中心的最長迴文子串 \(t\) 不完全包含在 \(p\)。但 \(t\)\(p\) 中迴文的部分一定和 \(t'\) 對應。且 \(P_j\ge j-l+1\)。根據迴文的對稱性,雖然無法保證 \(P_i=P_j\)。但可以保證 \(P_i\ge j-l+1=r-i+1\)。找到 \(P_i\) 的最小值後無法保證迴文半徑最大,繼續暴力列舉

最後,在計算完 \(P_i\) 的值後,記得更新 \(l,r,mid\)

時間複雜度分析:

雖然有很多暴力的擴充,但每次擴充後,\(r\) 都會更新,且每次更新都是往後移。移動的次數都是暴力列舉的次數。所以 \(r\) 至多被列舉到右邊界,時間複雜度 \(O(n)\)

答案計算

最後還有答案計算,我們令操作中,插入了輔助字元的字串為 \(s\)。原串為 \(s'\)結論:則原串中最長迴文串的長度為 \(\max P_i-1\)。在 \(s'\) 中的最長迴文子串長度為 \(2P_i-1\)(減去中間迴文中心的重疊部分),其中每個字元的前面都對應一個輔助字元(即字元 \(\texttt{\#}\) 之類),最後還多出一個輔助字元。所以原串中最長迴文子串的長度為 \(\dfrac{2P_i-2}{2}=P_i-1\)

例題:

\(1\)【模板】manacher

模板題

點選檢視程式碼
#include <iostream>
#include <cstdio>
using namespace std;
const int N=3e7+10;
char s[N];
int p[N],mid,r=-1,n,ans=0;
void read(){
    char ch=getchar();
    s[n]='~',s[++n]='#';//首段加上一個‘~’是為了處理越界情況
    while(ch>='a'&&ch<='z')s[++n]=ch,s[++n]='#',ch=getchar();
    s[++n]='#';
    return;
}
int main(){
    read();
    for(int i=1;i<=n;i++){
        if(i>r)p[i]=1;
        else p[i]=min(p[mid*2-i],r-i+1);
        while(s[i-p[i]]==s[i+p[i]])p[i]++;
        if(i+p[i]-1>r)r=p[i]+i-1,mid=i;
        if(p[i]>ans)ans=p[i];
    }
    printf("%d\n",ans-1);
    return 0;
}

\(2\)[國家集訓隊] 最長雙迴文串

對於雙迴文串中的每個“斷點”(字元 # 的位置)。找到以其為左右端點的最長迴文串 \(lef_i,righ_i\)。答案即為 \(\max\limits_{lef_i,righ_i\ne 0}lef_i+righ_i\)。由於在演算法匹配中,只求出了一些點的 \(lef_i,righ_i\)。所以還需要遞推來補全所有的點,然後再求解。

點選檢視程式碼
#include <iostream>
#include <cstdio>
#define int long long
using namespace std;
const int N=3e5+10;
char s[N];
int n,mid,r=-1,p[N],L[N],R[N],ans;
void read(){
    char ch=getchar();
    s[++n]='~';s[++n]='#';
    while(ch>='a'&&ch<='z')s[++n]=ch,s[++n]='#',ch=getchar();
    return;
}
signed main(){
    read();
    for(int i=1;i<=n;i++){
        if(i>r)p[i]=1;
        else p[i]=min(p[mid*2-i],r-i+1);
        while(s[i-p[i]]==s[i+p[i]])p[i]++;
        if(i+p[i]-1>r)r=p[i]+i-1,mid=i;
        L[i-p[i]+1]=max(L[i-p[i]+1],p[i]-1);
        R[i+p[i]-1]=max(R[i+p[i]-1],p[i]-1);
    }
    for(int i=n-2;i>=2;i-=2)R[i]=max(R[i+2]-2,R[i]);
    for(int i=4;i<=n;i+=2)L[i]=max(L[i-2]-2,L[i]);
    for(int i=2;i<=n;i+=2)if(L[i]&&R[i])ans=max(ans,L[i]+R[i]);
    printf("%lld\n",ans);
    return 0;
}

\(3\)[國家集訓隊] 拉拉隊排練

\(P_i\) 並不僅僅代表以 \(i\) 為中心的最長迴文半徑加一。還代表原串中以 \(i\) 為迴文中心的迴文串的個數,因為 \([i+P_i-1,i-P_i+1],[i+P_i-2,i-P_i+2]\dots[i+1,i-1],[i,i]\) 都在最大的迴文串中,迴文的性質使得他們都是迴文串。

如果我們在演算法的過程中,用 \(sum_k\) 表示長度為 \(k\) 的迴文串的差分陣列。如果得到一個 \(i\)\(P_i=x\),那麼只需將 \(sum_x\) 加上 \(1\)。在一起計算時先計算長的迴文串,在將 \(sum_x\) 加到前面去。就可以完成計算。

因為題目要求一個“小團體”長度必須為單數。所以要對一些情況特判。\(sum_i=sum_i+sum_{i+2}\)。因為 \(k\) 很大,也會用到快速冪的技巧。同時記得特判 \(-1\) 的情況。

點選檢視程式碼
#include <iostream>
#include <cstdio>
using namespace std;
const int N=2e6+10,mod=19930726;
typedef long long ll;
ll len,k,n,mid,r=-1,p[N],maxn=0,sum[N],ans=1,z;
char s[N];
void read(){
    char ch=getchar();
    s[n]='~',s[++n]='#';
    while(ch<'a'||ch>'z')ch=getchar();
    while(ch>='a'&&ch<='z')s[++n]=ch,s[++n]='#',ch=getchar();
    return;
}
ll ksm(ll a,ll b){
    ll ans=1;
    while(b){
        if(b&1)ans=(a*ans)%mod;
        a=(a*a)%mod;
        b>>=1;
    }
    return ans;
}
int main(){
    scanf("%lld %lld",&len,&k);
    read();
    for(int i=1;i<=n;i++){
        if(i>r)p[i]=1;
        else p[i]=min(p[2*mid-i],r-i+1);
        while(s[i-p[i]]==s[i+p[i]])p[i]++;
        if(i+p[i]-1>r)r=i+p[i]-1,mid=i;
        if(s[i]!='#')sum[p[i]-1]++,maxn=max(maxn,p[i]-1);
    }
    for(int i=maxn;i>=1&&k>0;i--){
        sum[i]+=sum[i+2],z=min(k,sum[i]);
        ans=(ans*ksm(i,z))%mod;
        k-=z;
    }
    if(k>0)printf("-1\n");
    else printf("%lld\n",ans);
    return 0;
}

The End

感覺還比較有用

相關文章