BZOJ 4650 [Noi2016]優秀的拆分:字尾陣列

Leohh發表於2018-02-12

題目連結:http://www.lydsy.com/JudgeOnline/problem.php?id=4650

題意:

  給你一個字串s,問你s及其子串中,將它們拆分成"AABB"的方式共有多少種。

 

題解:

  先只考慮"AA"的形式。

  設pre[i]表示以s[i]結尾的"AA"串共有多少個,nex[i]表示以s[i]開頭的"AA"串共有多少個。

  那麼拆分成"AABB"的總方案數 = ∑ pre[i]*nex[i+1]。

 

  然後考慮如何求pre和nex陣列。

  對於"AA"形式的串,有一個性質:

    假設一個"AA"串的長度為len。

    如果在s串上每隔len的距離設定一個關鍵點,那麼這個"AA"串一定經過相鄰的兩個關鍵點。

  所以先列舉"AA"串的長度len,然後列舉是哪兩個相鄰的關鍵點。

 

  設當前的兩個相鄰關鍵點分別為:x = i*len, y = (i+1)*len

  設字首s[1 to x]和字首s[1 to y]的LCS為a,字尾s[x to n]和字尾s[y to n]的LCP為b

  令tt = a+b-1

  如果有tt >= len,則就找到了a+b-len個長度為len的"AA"串。

  又因為這些"AA"串是必須經過x,y這兩個關鍵點的(否則會重複統計)

  所以上面的a = min(a,len), b = min(b,len)。

 

  找到的這些"AA"串會對pre陣列和nex陣列產生貢獻。

  所有這些"AA"串的開頭為區間[x-a+1, x+b-len],結尾為區間[y-a+len, y+b-1]。

  所以要給nex[x-a+1 to x+b-len]加1,給pre[y-a+len to y+b-1]加1

  差分一下,最後求一遍字首個就好。

  這樣pre和nex就求好了,然後統計"AABB"的總數即可。

 

  O(nlogn)求字尾陣列

  O(nlogn)求pre和nex(調和級數複雜度)

  O(n)統計"AABB"

  總複雜度O(nlogn)

 

AC Code:

  1 #include <iostream>
  2 #include <stdio.h>
  3 #include <string.h>
  4 #define MAX_N 30005
  5 #define MAX_K 20
  6 
  7 using namespace std;
  8 
  9 struct SA
 10 {
 11     int n;
 12     int a[MAX_N];
 13     int sa[MAX_N];
 14     int rk[MAX_N];
 15     int t1[MAX_N];
 16     int t2[MAX_N];
 17     int cnt[MAX_N];
 18     int tsa[MAX_N];
 19     int height[MAX_N];
 20     char s[MAX_N];
 21     
 22     int lg[MAX_N];
 23     int dat[MAX_N][MAX_K];
 24     
 25     void rsort()
 26     {
 27         memset(cnt,0,sizeof(cnt));
 28         for(int i=1;i<=n;i++) cnt[t2[i]]++;
 29         for(int i=1;i<=n;i++) cnt[i]+=cnt[i-1];
 30         for(int i=n;i>=1;i--) tsa[cnt[t2[i]]--]=i;
 31         memset(cnt,0,sizeof(cnt));
 32         for(int i=1;i<=n;i++) cnt[t1[i]]++;
 33         for(int i=1;i<=n;i++) cnt[i]+=cnt[i-1];
 34         for(int i=n;i>=1;i--) sa[cnt[t1[tsa[i]]]--]=tsa[i];
 35     }
 36     
 37     void suffix()
 38     {
 39         memset(cnt,0,sizeof(cnt));
 40         for(int i=1;i<=n;i++) a[i]=s[i],cnt[a[i]]++;
 41         for(int i='a';i<='z';i++) cnt[i]+=cnt[i-1];
 42         for(int i=1;i<=n;i++) rk[i]=cnt[a[i]];
 43         int len=1;
 44         while(len<n)
 45         {
 46             for(int i=1;i<=n;i++)
 47             {
 48                 t1[i]=rk[i];
 49                 t2[i]=i+len<=n ? rk[i+len] : 0;
 50             }
 51             rsort();
 52             for(int i=1;i<=n;i++)
 53             {
 54                 rk[sa[i]]=rk[sa[i-1]]+(t1[sa[i]]!=t1[sa[i-1]] || t2[sa[i]]!=t2[sa[i-1]]);
 55             }
 56             len<<=1;
 57         }
 58         int k=0;
 59         for(int i=1;i<=n;i++)
 60         {
 61             k=k?k-1:k;
 62             int j=sa[rk[i]-1];
 63             while(a[i+k]==a[j+k]) k++;
 64             height[rk[i]]=k;
 65         }
 66     }
 67     
 68     void init_st()
 69     {
 70         lg[0]=-1;
 71         for(int i=1;i<=n;i++)
 72         {
 73             lg[i]=lg[i>>1]+1;
 74             dat[i][0]=height[i];
 75         }
 76         for(int k=1;(1<<k)<=n;k++)
 77         {
 78             for(int i=1;i+(1<<k)-1<=n;i++)
 79             {
 80                 dat[i][k]=min(dat[i][k-1],dat[i+(1<<(k-1))][k-1]);
 81             }
 82         }
 83     }
 84     
 85     int lcp(int l,int r)
 86     {
 87         l=rk[l]; r=rk[r];
 88         if(l>r) swap(l,r); l++;
 89         int k=lg[r-l+1];
 90         return min(dat[l][k],dat[r-(1<<k)+1][k]);
 91     }
 92     
 93     void init(char *_s,int _n)
 94     {
 95         memset(a,0,sizeof(a));
 96         memset(sa,0,sizeof(sa));
 97         memset(rk,0,sizeof(rk));
 98         memset(tsa,0,sizeof(tsa));
 99         memset(height,0,sizeof(height));
100         n=_n;
101         for(int i=1;i<=n;i++) s[i]=_s[i];
102         suffix();
103         init_st();
104     }
105 };
106 
107 int n,t;
108 char str[MAX_N];
109 char rev[MAX_N];
110 long long ans;
111 long long pre[MAX_N];
112 long long nex[MAX_N];
113 SA sa1,sa2;
114 
115 inline void add(long long *a,int l,int r)
116 {
117     a[l]++; a[r+1]--;
118 }
119 
120 void cal_aa()
121 {
122     memset(pre,0,sizeof(pre));
123     memset(nex,0,sizeof(nex));
124     for(int len=1;len<=n/2;len++)
125     {
126         for(int i=1;(i+1)*len<=n;i++)
127         {
128             int x=i*len,y=(i+1)*len;
129             int a=sa2.lcp(n-x+1,n-y+1),b=sa1.lcp(x,y);
130             a=min(a,len); b=min(b,len);
131             int tt=a+b-1;
132             if(tt>=len)
133             {
134                 add(nex,x-a+1,x+b-len);
135                 add(pre,y-a+len,y+b-1);
136             }
137         }
138     }
139     for(int i=1;i<=n;i++)
140     {
141         pre[i]+=pre[i-1];
142         nex[i]+=nex[i-1];
143     }
144 }
145 
146 void cal_aabb()
147 {
148     ans=0;
149     for(int i=1;i<n;i++)
150     {
151         ans+=pre[i]*nex[i+1];
152     }
153 }
154 
155 int main()
156 {
157     scanf("%d",&t);
158     while(t--)
159     {
160         scanf("%s",str+1);
161         n=strlen(str+1);
162         for(int i=1,j=n;i<=n;i++,j--) rev[j]=str[i];
163         sa1.init(str,n); sa2.init(rev,n);
164         cal_aa();
165         cal_aabb();
166         printf("%lld\n",ans);
167     }
168 }

 

相關文章