列舉子串的中心,往兩側擴充套件,將兩側對應位置的字元交替寫下來,得到一個字串$S$。
若前後長度為$L$的子串迴圈同構,則在$S$中它們對應長度為$2L$的字首,需要滿足它可以由不超過$2$個偶迴文串拼接而成。
有一個結論是,若$S=uv$,其中$uv$都是偶迴文串,那麼要麼$u$是$S$的最長偶迴文字首,要麼$v$是$S$的最長偶迴文字尾。
證明:
設$S=x_1y_1=x_2y_2=x_3y_3$。
假設結論不成立,那麼顯然雙迴文劃分數$\geq 3$,設$x_1$為$S$的最長迴文字首,$y_3$是$S$的最長迴文字尾,$x_2$和$y_2$都是迴文串,則$y_1$和$x_3$都不是迴文串。
因為$x_1$是最長迴文字首,所以$|x_1|>|x_2|$,同理$|y_2|<|y_3|$,則$|x_1|>|x_2|>|x_3|$。
................[v.....]
[x1...................][y1....]
[x2..........][y2.............]
[x3.....][y3..................]
設$x_1=x_2v$,那麼因為$v$是迴文串$y_2$的字首,所以$v^R$是$y_2$的字尾,也是$y_3$的字尾,因為$y_3$是迴文串,所以$v$是$y_3$的字首,得出$x_3v$是$x_1$的字首。
因為$x_1$是迴文串,$x_2$是$x_1$的字首,所以$x_2^R$是$x_1$的字尾,又因為$x_2$是迴文串,所以$x_2$也是$x_1$的字尾,所以長度$|x_1|-|x_2|=|v|$是$x_1$的一個週期,也是$x_1$的字首$x_3v$的一個週期。這說明$|v|$也是$v^Rx_3^R$的一個週期,即$x_3^R$是$v^Rv^R...v^Rv^R$的字首。
因為$v$是迴文串$x_1$的字尾,所以$v^R$是$x_1$的字首,而$|v|$是$x_1$的週期,所以$x_1$是$v^Rv^R...v^Rv^R$的字首,那麼$x_1$的字首$x_3$也是$v^Rv^R...v^Rv^R$的字首。
因為$x_3^R$和$x_3$都是$v^Rv^R...v^Rv^R$的字首,所以$x_3=x_3^R$,即$x_3$是迴文串,和假設矛盾。所以結論成立。
通過Manacher預處理出每個位置的最長迴文半徑$f$,即可求出每個字首的最長偶迴文字首和最長偶迴文字尾,剩下部分可以根據$f$陣列$O(1)$判斷一個子串是否是迴文串。
時間複雜度$O(n^2)$。
#include<cstdio> const int N=10010; int n,i,a[N],c[N],s[N],f[N],pre[N],suf[N],ans; inline int min(int a,int b){return a<b?a:b;} inline void umin(int&a,int b){a>b?(a=b):0;} inline void umax(int&a,int b){a<b?(a=b):0;} inline bool check(int l,int r){ if(l>r)return 1; return l+r+f[l+r]>r+r; } inline void solve(int x){ int i,j,r,p,m=0,len; for(i=x,j=x+1;i&&j<=n;i--,j++)c[++m]=a[i],c[++m]=a[j]; for(i=1;i<=m;i++)s[i<<1]=c[i],s[i<<1|1]=-1; s[0]=-2,s[1]=-1,s[len=(m+1)<<1]=-3; for(r=p=0,f[1]=1,i=2;i<len;i++){ for(f[i]=r>i?min(r-i,f[p*2-i]):1;s[i-f[i]]==s[i+f[i]];f[i]++); if(i+f[i]>r)r=i+f[i],p=i; } for(i=0;i<=m+1;i++)pre[i]=0,suf[i]=len; for(i=3;i<len;i+=2){ if(f[i]==i)pre[f[i]-1]=f[i]-1; umin(suf[(i+f[i]-1)>>1],i>>1); } for(i=1;i<=m;i++)umax(pre[i],pre[i-1]); for(i=m;i;i--)umin(suf[i],suf[i+1]); for(i=0;i<=m;i++)if(suf[i]>=i)suf[i]=0;else suf[i]=(i-suf[i])<<1; for(i=2;i<=m;i+=2)if(check(pre[i]+1,i)||check(1,i-suf[i]))ans++; } int main(){ scanf("%d",&n); for(i=1;i<=n;i++)scanf("%d",&a[i]); for(i=1;i<n;i++)solve(i); return printf("%d",ans),0; }