題目連結:http://www.lydsy.com/JudgeOnline/problem.php?id=4650
題意:
給你一個字串s,問你s及其子串中,將它們拆分成"AABB"的方式共有多少種。
題解:
先只考慮"AA"的形式。
設pre[i]表示以s[i]結尾的"AA"串共有多少個,nex[i]表示以s[i]開頭的"AA"串共有多少個。
那麼拆分成"AABB"的總方案數 = ∑ pre[i]*nex[i+1]。
然後考慮如何求pre和nex陣列。
對於"AA"形式的串,有一個性質:
假設一個"AA"串的長度為len。
如果在s串上每隔len的距離設定一個關鍵點,那麼這個"AA"串一定經過相鄰的兩個關鍵點。
所以先列舉"AA"串的長度len,然後列舉是哪兩個相鄰的關鍵點。
設當前的兩個相鄰關鍵點分別為:x = i*len, y = (i+1)*len
設字首s[1 to x]和字首s[1 to y]的LCS為a,字尾s[x to n]和字尾s[y to n]的LCP為b
令tt = a+b-1
如果有tt >= len,則就找到了a+b-len個長度為len的"AA"串。
又因為這些"AA"串是必須經過x,y這兩個關鍵點的(否則會重複統計)
所以上面的a = min(a,len), b = min(b,len)。
找到的這些"AA"串會對pre陣列和nex陣列產生貢獻。
所有這些"AA"串的開頭為區間[x-a+1, x+b-len],結尾為區間[y-a+len, y+b-1]。
所以要給nex[x-a+1 to x+b-len]加1,給pre[y-a+len to y+b-1]加1
差分一下,最後求一遍字首個就好。
這樣pre和nex就求好了,然後統計"AABB"的總數即可。
O(nlogn)求字尾陣列
O(nlogn)求pre和nex(調和級數複雜度)
O(n)統計"AABB"
總複雜度O(nlogn)
AC Code:
1 #include <iostream> 2 #include <stdio.h> 3 #include <string.h> 4 #define MAX_N 30005 5 #define MAX_K 20 6 7 using namespace std; 8 9 struct SA 10 { 11 int n; 12 int a[MAX_N]; 13 int sa[MAX_N]; 14 int rk[MAX_N]; 15 int t1[MAX_N]; 16 int t2[MAX_N]; 17 int cnt[MAX_N]; 18 int tsa[MAX_N]; 19 int height[MAX_N]; 20 char s[MAX_N]; 21 22 int lg[MAX_N]; 23 int dat[MAX_N][MAX_K]; 24 25 void rsort() 26 { 27 memset(cnt,0,sizeof(cnt)); 28 for(int i=1;i<=n;i++) cnt[t2[i]]++; 29 for(int i=1;i<=n;i++) cnt[i]+=cnt[i-1]; 30 for(int i=n;i>=1;i--) tsa[cnt[t2[i]]--]=i; 31 memset(cnt,0,sizeof(cnt)); 32 for(int i=1;i<=n;i++) cnt[t1[i]]++; 33 for(int i=1;i<=n;i++) cnt[i]+=cnt[i-1]; 34 for(int i=n;i>=1;i--) sa[cnt[t1[tsa[i]]]--]=tsa[i]; 35 } 36 37 void suffix() 38 { 39 memset(cnt,0,sizeof(cnt)); 40 for(int i=1;i<=n;i++) a[i]=s[i],cnt[a[i]]++; 41 for(int i='a';i<='z';i++) cnt[i]+=cnt[i-1]; 42 for(int i=1;i<=n;i++) rk[i]=cnt[a[i]]; 43 int len=1; 44 while(len<n) 45 { 46 for(int i=1;i<=n;i++) 47 { 48 t1[i]=rk[i]; 49 t2[i]=i+len<=n ? rk[i+len] : 0; 50 } 51 rsort(); 52 for(int i=1;i<=n;i++) 53 { 54 rk[sa[i]]=rk[sa[i-1]]+(t1[sa[i]]!=t1[sa[i-1]] || t2[sa[i]]!=t2[sa[i-1]]); 55 } 56 len<<=1; 57 } 58 int k=0; 59 for(int i=1;i<=n;i++) 60 { 61 k=k?k-1:k; 62 int j=sa[rk[i]-1]; 63 while(a[i+k]==a[j+k]) k++; 64 height[rk[i]]=k; 65 } 66 } 67 68 void init_st() 69 { 70 lg[0]=-1; 71 for(int i=1;i<=n;i++) 72 { 73 lg[i]=lg[i>>1]+1; 74 dat[i][0]=height[i]; 75 } 76 for(int k=1;(1<<k)<=n;k++) 77 { 78 for(int i=1;i+(1<<k)-1<=n;i++) 79 { 80 dat[i][k]=min(dat[i][k-1],dat[i+(1<<(k-1))][k-1]); 81 } 82 } 83 } 84 85 int lcp(int l,int r) 86 { 87 l=rk[l]; r=rk[r]; 88 if(l>r) swap(l,r); l++; 89 int k=lg[r-l+1]; 90 return min(dat[l][k],dat[r-(1<<k)+1][k]); 91 } 92 93 void init(char *_s,int _n) 94 { 95 memset(a,0,sizeof(a)); 96 memset(sa,0,sizeof(sa)); 97 memset(rk,0,sizeof(rk)); 98 memset(tsa,0,sizeof(tsa)); 99 memset(height,0,sizeof(height)); 100 n=_n; 101 for(int i=1;i<=n;i++) s[i]=_s[i]; 102 suffix(); 103 init_st(); 104 } 105 }; 106 107 int n,t; 108 char str[MAX_N]; 109 char rev[MAX_N]; 110 long long ans; 111 long long pre[MAX_N]; 112 long long nex[MAX_N]; 113 SA sa1,sa2; 114 115 inline void add(long long *a,int l,int r) 116 { 117 a[l]++; a[r+1]--; 118 } 119 120 void cal_aa() 121 { 122 memset(pre,0,sizeof(pre)); 123 memset(nex,0,sizeof(nex)); 124 for(int len=1;len<=n/2;len++) 125 { 126 for(int i=1;(i+1)*len<=n;i++) 127 { 128 int x=i*len,y=(i+1)*len; 129 int a=sa2.lcp(n-x+1,n-y+1),b=sa1.lcp(x,y); 130 a=min(a,len); b=min(b,len); 131 int tt=a+b-1; 132 if(tt>=len) 133 { 134 add(nex,x-a+1,x+b-len); 135 add(pre,y-a+len,y+b-1); 136 } 137 } 138 } 139 for(int i=1;i<=n;i++) 140 { 141 pre[i]+=pre[i-1]; 142 nex[i]+=nex[i-1]; 143 } 144 } 145 146 void cal_aabb() 147 { 148 ans=0; 149 for(int i=1;i<n;i++) 150 { 151 ans+=pre[i]*nex[i+1]; 152 } 153 } 154 155 int main() 156 { 157 scanf("%d",&t); 158 while(t--) 159 { 160 scanf("%s",str+1); 161 n=strlen(str+1); 162 for(int i=1,j=n;i<=n;i++,j--) rev[j]=str[i]; 163 sa1.init(str,n); sa2.init(rev,n); 164 cal_aa(); 165 cal_aabb(); 166 printf("%lld\n",ans); 167 } 168 }