[POJ 3415] Common Substrings (字尾陣列+單調棧優化)

zichenzhiguang發表於2016-09-08

連結

POJ 3415


題意

給出兩個字串s1、s2和一個整數K,求有多少個長度大於K的公共子串。


思路

字尾陣列的一個經典問題,對height陣列分組後使用單調棧將複雜度優化至O(n),單調棧還真是強大啊。。。

整個過程是先將s1、s2串起來,用比較小的數分隔,求出字尾陣列,按照K的公共字首長度去分組,對組內每個s2字尾,前方出現的每個同組s1字尾都與其存在長度不小於K的公共字首。之後再計算一遍組內s1字尾前方出現的各個同組s2字尾與其公共字首,累加即可得到結果。
對height陣列建st表,可以查詢任意兩個字尾的最長公共字首,然而即使這麼做複雜度也是O(n^2),不能滿足時間要求。
在我們掃描分組的時候,可以對height維護一個“單調棧”,由於字尾i、j(rank[i] < rank[j])之間的最長公共字首是min(height[rank[i] + 1], …, height[rank[i] + k], …, height[rank[j]],因此LCP(sa[rank[i]], sa[rank[j]])一定小於等於LCP(sa[rank[i] + 1], sa[rank[j]]),所以後入棧的字尾,需要對在它前方入棧且LCP大於它的字尾的貢獻做消減,使整個棧滿足“順序性”和“單調性”。
PS:其實我自己寫著寫著也寫懵逼了,看程式碼吧。。。這道題我幾乎是看別人部落格看到把程式碼背下來了才懂這個單調棧是怎麼優化的。


程式碼

#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
typedef long long lint;
#define maxn 200200
bool cmp(int *r, int a, int b, int l)
{ return r[a] == r[b] && r[a + l] == r[b + l]; }
int ta[maxn], tb[maxn], bk[maxn];
void da(int *r, int *sa, int n, int m)
{
    int i, j, p, *x = ta, *y = tb, *t;
    for(i = 0; i < m; i++) bk[i] = 0;
    for(i = 0; i < n; i++) bk[x[i] = r[i]]++;
    for(i = 1; i < m; i++) bk[i] += bk[i-1];
    for(i = 0; i < n; i++) sa[--bk[x[i]]] = i;
    for(j = 1, p = 1; p < n; j *= 2, m = p)
    {
        for(p = 0, i = n - j; i < n; i++) y[p++] = i;
        for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
        for(i = 0; i < m; i++) bk[i] = 0;
        for(i = 0; i < n; i++) bk[x[i]]++;
        for(i = 1; i < m; i++) bk[i] += bk[i-1];
        for(i = n-1; i >= 0; i--) sa[--bk[x[y[i]]]] = y[i];
        for(t = x, x = y, y = t, x[sa[0]] = 0, p = 1, i = 1; i < n; i++)
            x[sa[i]] = cmp(y, sa[i-1], sa[i], j) ? p - 1 : p++;
    }
}
int Rank[maxn], SA[maxn], Height[maxn];
void calheight(int *r, int n)
{
    for(int i = 0; i < n; i++) Rank[SA[i]] = i;
    for(int i = 0, k = 0; i < n; i++)
    {
        k ? k-- : 0;
        if(Rank[i])
            while(r[i + k] == r[SA[Rank[i] - 1] + k])
                k++;
        Height[Rank[i]] = k;
    }
}

struct _node
{
    int h, t;
} _stack[maxn];

char s1[maxn >> 1], s2[maxn >> 1];
int r[maxn], n, m;
int main()
{
    int K;
    while((cin >> K) && K)
    {
        scanf("%s%s", s1, s2);
        n = 0;
        for(int i = 0; s1[i]; i++)
            r[n++] = s1[i];
        r[m = n++] = 1;
        for(int i = 0; s2[i]; i++)
            r[n++] = s2[i];
        r[n++] = 0;

        da(r, SA, n, 1<<8);
        calheight(r, n);

        lint o = 0, tot = 0;
        for(int i = 3, top = 0, cnt; i < n; i++)
        {
            if(Height[i] < K) { tot = top = 0; continue; }
            cnt = 0;
            if(SA[i - 1] < m) { tot += Height[i] - K + 1; cnt++; }
            while(top > 0 && Height[i] < _stack[top - 1].h)
            {
                tot -= _stack[top - 1].t * (_stack[top - 1].h - Height[i]);
                cnt += _stack[top - 1].t;
                top--;
            }
            _stack[top].h = Height[i];
            _stack[top++].t = cnt;
            if(SA[i] > m) o += tot;
        }
        tot = 0;
        for(int i = 3, top = 0, cnt; i < n; i++)
        {
            if(Height[i] < K) { tot = top = 0; continue; }
            cnt = 0;
            if(SA[i - 1] > m) { tot += Height[i] - K + 1; cnt++; }
            while(top > 0 && Height[i] < _stack[top - 1].h)
            {
                tot -= _stack[top - 1].t * (_stack[top - 1].h - Height[i]);
                cnt += _stack[top - 1].t;
                top--;
            }
            _stack[top].h = Height[i];
            _stack[top++].t = cnt;
            if(SA[i] < m) o += tot;
        }

        cout << o << endl;
    }
    return 0;
}

相關文章