字尾陣列複習

Arashimu發表於2022-02-20

字尾陣列

陣列的定義

一下排名均是在字典序下的排名

\(sa[i]\):排名為\(i\)的字尾的編號

\(rank[i]\):第\(i\)個字尾串的排名

\(rank[sa[i]]=i\)\(sa[rank[i]]=i\)

\(height[i]\):排名為\(i\)的字尾和排名為\(i-1\)的字尾的最長公共字首

模板:

#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
const int N = 1e6 + 5;
int n, m;
char s[N];
int sa[N], x[N], y[N], rk[N], height[N], c[N];
void get_sa()
{
    //基數排序o(n)
    //先對第一關鍵字(模式串出現的字母)排序
    for (int i = 1; i <= n; i++) c[x[i] = s[i]]++;
    for (int i = 2; i <= m; i++) c[i] += c[i - 1];
    for (int i = n; i > 0; i--)  sa[c[x[i]]--] = i;

    for (int k = 1; k <= n; k <<= 1)
    {
        int num = 0;
        //以長度為k,i為第一關鍵字,i+k為第二關鍵字
        for (int i = n - k + 1; i <= n; i++) y[++num] = i; //從n-k+1開始第二關鍵字為空字元,最小,所以排最前面
        for (int i = 1; i <= n; i++) //實際上只有n-k個數
            if (sa[i] > k) y[++num] = sa[i] - k;

        //對第二關鍵字排序
        for (int i = 1; i <= m; i++) c[i] = 0;
        for (int i = 1; i <= n; i++) c[x[i]]++;
        for (int i = 2; i <= m; i++) c[i] += c[i - 1];
        for (int i = n; i; i--)  sa[c[x[y[i]]]--] = y[i], y[i] = 0;
        swap(x, y);
        x[sa[1]] = 1, num = 1;
        for (int i = 2; i <= n; i++)
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++num;
        if (num == n)break;
        m = num;
    }
}


void get_height()
{
    for (int i = 1; i <= n; i++) rk[sa[i]] = i;
    for (int i = 1, k = 0; i <= n; i++)
    {
        if (rk[i] == 1) continue;
        if (k) k--;
        int j = sa[rk[i] - 1];
        while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) k++;
     
        height[rk[i]] = k;
    }
}
int main()
{
    scanf("%s", s + 1);
    n = strlen(s + 1), m = 122;//字母z的ASCLL值為122;
    get_sa();
    get_height();

    return 0;
}

應用

\(1.\)求本質不同的子串的個數:

普通做法是列舉左右端點然後雜湊判重。考慮我們當前列舉到的左端點是\(l\),那麼我們右端點就要依次列舉\(l+1,l+1,...,n\),發現,其實就是列舉第\(l\)個字尾的所有字首。然後考慮如何判重,將第\(l\)個字尾的所有字首按字典序排序記為,通過\(height\)陣列我們知道排名相鄰的兩個串的最長公共字首,先記他們的長度依次是\(len_1,len_2,...,len_{n-l+1}\),那麼第一個穿可以產生\(len_1\)個前面沒有出現過的串(因為是第一個,所以不需要判重),再考慮第二個串,很容易知道第二個串貢獻的不同的串個數為\(len_2-height[2]\)\(height[2]\)記錄的是第二個串和第一個串的最長共字首,這部分在第一個串中被統計過了,所以不需要再統計。綜上,所有的答案就是\(\sum_{i=1}^nlen_i-height[i]\)就是答案。

例題:\(生成魔咒\)

題意:每次加入一個字元,問此時本質不同的子串的數量

Sol:正向加不好做,考慮先把所有字元加入後,翻轉(翻轉和正向的答案其實是一樣的)後從前往後依次刪掉字元然後求答案。假設翻轉後的串為\(S\),將\(S\)的所有字尾排序後,每次從前刪掉一個字元是,相當於刪去一個字尾,\(sa\)陣列利用連結串列維護,假設刪掉的字尾排名為\(i\),那麼第\(i-1\)和第\(i+1\)的串的最長公共字首即\(height[i+1]=min(height[i],height[i+1])\)的。然後按照上面的方法求解即可。

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e6 + 5;
int n, m;
int s[N],l[N],r[N];
int sa[N], x[N], y[N], rk[N], height[N], c[N];
void get_sa()
{
    //基數排序o(n)
    //先對第一關鍵字(模式串出現的字母)排序
    for (int i = 1; i <= n; i++) c[x[i] = s[i]]++;
    for (int i = 2; i <= m; i++) c[i] += c[i - 1];
    for (int i = n; i > 0; i--)  sa[c[x[i]]--] = i;

    for (int k = 1; k <= n; k <<= 1)
    {
        int num = 0;
        for (int i = n - k + 1; i <= n; i++) y[++num] = i; //以長度為k,i為第一關鍵字,i+k為第二關鍵字
        for (int i = 1; i <= n; i++)
            if (sa[i] > k) y[++num] = sa[i] - k;

        //對第二關鍵字排序
        for (int i = 1; i <= m; i++) c[i] = 0;
        for (int i = 1; i <= n; i++) c[x[i]]++;
        for (int i = 2; i <= m; i++) c[i] += c[i - 1];
        for (int i = n; i; i--)  sa[c[x[y[i]]]--] = y[i], y[i] = 0;
        swap(x, y);
        x[sa[1]] = 1, num = 1;
        for (int i = 2; i <= n; i++)
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++num;
        if (num == n)break;
        m = num;
    }
}


void get_height()
{
    for (int i = 1; i <= n; i++) rk[sa[i]] = i;
    for (int i = 1, k = 0; i <= n; i++)       //利用 height[rk[i]]>= height[rk[i-1]]-1
    {
        if (rk[i] == 1) continue;
        if (k) k--;
        int j = sa[rk[i] - 1];
        while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) k++;
     
        height[rk[i]] = k;
    }
}
unordered_map<int,int>mp;

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin>>n;
    vector<ll>ans(n+1);
    for(int i=n;i>=1;i--)
    {
        cin>>s[i];
        if(mp[s[i]]==0) mp[s[i]]=++m;
        s[i]=mp[s[i]];
    }
    get_sa();
    get_height();

    ll res=0;
    for(int i=1;i<=n;i++)
    {
        res+=n-sa[i]+1-height[i];
        l[i]=i-1,r[i]=i+1;
    }

    l[n+1]=n,r[0]=1;
    for(int i=1;i<=n;i++) //從前往後刪
    {
        ans[i]=res;
        int k=rk[i],nk=r[k];
        res-=n-sa[k]+1-height[k];
        res-=n-sa[nk]+1-height[nk];
        height[nk]=min(height[nk],height[k]);
        res+=n-sa[nk]+1-height[nk];
        r[l[k]]=r[k],l[r[k]]=l[k];
    }
    for(int i=n;i>=1;i--)cout<<ans[i]<<'\n';
    return 0;
}


\(H - Can You Solve the Harder Problem?\)

題意:給定一個陣列,求所有本質不同的子段的最大值的和

Sol:(字尾陣列+單調棧+RMQ)

如果不考慮本質不同的限制,那麼就是直接列舉每個數作為左端點,然後利用單調棧找到右邊第一個大於這個數的下標,計算貢獻即可(可以參考\(Atcoder Minimum Sum\))。考慮限制要怎麼操作呢?

定義陣列\(nxt[i]\)為第\(i\)個數右邊第一個比它大的數的下標。\(suf[i]\)表示從\(i\)開始的依次遞增的貢獻的字尾和,轉移是\(suf[i]=a[i]\times(nxt[i]-i)+suf[nxt[i]]\),什麼意思呢,如下圖!
image

按字典序排好序後,假設當前計算到第\(i\)個串,如果\(height[i]\)不等於,說明第\(i\)個串和第\(i-1\)個串有公共字首,如果直接加到答案裡會算重。設\(l=sa[i],r=sa[i]+height[i]-1\),即區間\([l,r]\)是重複的部分,設\(p\)\([l,r]\)中值最大的數的下標,則\(nxt[p]\)一定大於\(r\),因為\(a[p]\)\([l,r]\)中最大的了,而\(a[nxt[p]]>a[p]\),因此\(nxt[p]\)只能存在於比\(r\)大的地方。所以答案貢獻就是\(suf[nxt[p]]+a[p]\times (nxt[p]-r-1)\),為什麼是\(nxt[p]-r-1\)是因為第\(i-1\)個串已經計算了\([p+1,r]\)對數\(a[p]\)的貢獻,可能有人會覺得這樣這樣子會不會遺漏統計\([r+1,nxt[p-1]]\)\([p+1,r]\)中的數的貢獻,其實不會,為什麼,因為我們計算的重複只是和第\(i-1\)個串比較,第\(i-1\)個串是不會統計\([p+1,r]\)中的數對答案的貢獻,所以我們在第\(i\)個串不需要考慮這部分。
image

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N =2e5+10;
int n, m;
int s[N];
int sa[N], x[N], y[N], rk[N], height[N], c[N];
void get_sa()
{
    //基數排序o(n)
    //先對第一關鍵字(模式串出現的字母)排序
    for(int i=1;i<=n;i++) c[i]=sa[i]=0;
    for (int i = 1; i <= n; i++) c[x[i] = s[i]]++;
    for (int i = 2; i <= m; i++) c[i] += c[i - 1];
    for (int i = n; i > 0; i--)  sa[c[x[i]]--] = i;

    for (int k = 1; k <= n; k <<= 1)
    {
        int num = 0;
        for (int i = n - k + 1; i <= n; i++) y[++num] = i; //以長度為k,i為第一關鍵字,i+k為第二關鍵字
        for (int i = 1; i <= n; i++)
            if (sa[i] > k) y[++num] = sa[i] - k;

        //對第二關鍵字排序
        for (int i = 1; i <= m; i++) c[i] = 0;
        for (int i = 1; i <= n; i++) c[x[i]]++;
        for (int i = 2; i <= m; i++) c[i] += c[i - 1];
        for (int i = n; i; i--)  sa[c[x[y[i]]]--] = y[i], y[i] = 0;
        swap(x, y);
        x[sa[1]] = 1, num = 1;
        for (int i = 2; i <= n; i++)
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++num;
        if (num == n)break;
        m = num;
    }
}


void get_height()
{
    for (int i = 1; i <= n; i++) rk[sa[i]] = i;
    for (int i = 1, k = 0; i <= n; i++)       //利用 height[rk[i]]>= height[rk[i-1]]-1
    {
        if (rk[i] == 1) continue;
        if (k) k--;
        int j = sa[rk[i] - 1];
        while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) k++;
     
        height[rk[i]] = k;
    }
}
int f[N][20];
void RMQ_init()
{
    for (int i = 1; i <= n+1; i++) f[i][0] = i;
        for (int j = 1; j <=18 ; j++)
            for (int i = 1; i + (1 << j) - 1 <= n; i++) {
                int x = f[i][j - 1], y = f[i + (1 << (j - 1))][j - 1];
                f[i][j] = s[x] > s[y] ?  x: y;  //下標和最大值看情況轉換
            }
}
int RMQ_query(int l, int r) {
        int k = log2(r - l + 1);
        int x = f[l][k], y = f[r - (1 << k) + 1][k];
        return s[x] > s[y] ? x : y;
}
int nxt[N],stk[N];
ll suf[N];
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T;
    cin>>T;
    while(T--)
    {

        cin>>n;
        for(int i=1;i<=n;i++) nxt[i]=suf[i]=stk[i]=0;
        stk[n+1]=0;
        vector<int>v;
        for(int i=1;i<=n;i++)
        {
            cin>>s[i];
            v.push_back(s[i]);
        }
        sort(v.begin(),v.end());
        v.erase(unique(v.begin(),v.end()),v.end());
        m=v.size();
        for(int i=1;i<=n;i++) s[i]=lower_bound(v.begin(),v.end(),s[i])-v.begin()+1;
        
        RMQ_init();
        get_sa();
        get_height();

  
        int top=1;
        s[n+1]=1e6+10,stk[top]=n+1;
        for(int i=n;i>=1;i--)
        {
            while(top&&s[stk[top]]<=s[i]) top--;
            stk[++top]=i;
            nxt[i]=stk[top-1];
        }
   
        nxt[n+1]=n+1;
        suf[n+1]=0;
        for(int i=n;i>=1;i--) suf[i]=1ll*v[s[i]-1]*(nxt[i]-i);
        for(int i=n;i>=1;i--) suf[i]+=suf[nxt[i]];

        ll ans=0;
    
        for(int i=1;i<=n;i++)
        {
            int lcp=height[i];
            if(lcp==0) ans+=suf[sa[i]];
            else
            {
                int l=sa[i],r=sa[i]+lcp-1;
                int p=RMQ_query(l,r);
                ans+=suf[nxt[p]]+1ll*v[s[p]-1]*(nxt[p]-r-1);
            }
        }
        cout<<ans<<'\n';
    }

   
    return 0;
}


\(Atcoder Minimum Sum\)

題意

給定一個長度為\(N\)\(1\)\(N\)的一個排列,讓你計算\(\sum_{l=1}^n\sum_{r=l}^nmin(a_l,a_{l+1},...,a_r)\)

\(1\le N\le 200000\)

Sol

考慮統計每一個\(a_i\)作為貢獻的區間數量

很容易想到先找到\(a_i\)右邊第一個比\(a_i\)小的數的下標\(r[i]\)和左邊第一個比\(a_i\)小的數的下標\(l[i]\)。(經典單調棧操作)

比如這樣一個序列\(1,4,3,6,9,2,7,8\)

考慮數字\(3\)下標為\(3\)的貢獻,則左邊第一個比\(3\)小的數的下標為\(1\),則右邊第一個比\(3\)小的數的下標為\(6\),那麼顯然有貢獻的區間為\([2,3],[3,3],[3,4],[3,5],[2,4],[2,5]\),由乘法原理知道有\((3-1)\times (6-3)\)個,即\((i-l[i])\times (r[i]-i)\)個。

相關文章