手玩了一個小時終於做出來了,這不得寫一篇題解記錄一下??
下面設 \(s\) 的長度為 \(n\),\(t\) 的長度為 \(m\)。
考慮分類討論:
如果 \(s\) 中有一個子串 \(s'\) 與 \(t\) 完全相同(可以用雜湊進行比較),設 \(s'\) 是 \(s\) 的第 \(l\) 到第 \(r\) 個字元組成的字串,則我們可以刪除 \([1,l-1]\) 或者 \([r+1,n]\) 的某一個子區間,計算出它們的總個數就是 \(\dfrac{l\times (l-1)}{2}\) 和 \(\dfrac{(n-r)\times (n-r+1)}{2}\),兩者之和就是這一種情況的方案數。
我們也可以選定一個子串 \(s'\),滿足 \(s'\) 的某一段字首 \(pre\) 和某一段字尾 \(suc\) 不相交且兩者連線到一起就是 \(t\)。這樣的話我們刪掉 \(pre\) 和 \(suc\) 中間的字元就符合條件了。
例如 ciaohallo
要變成 ciallo
就可以選取字首 cia
和字尾 llo
,它們拼在一起就是我們要得到的字串。值得注意的是,為了避免掉與上一種情況出現重複計算,我們需要滿足 \(pre\) 和 \(suc\) 中間必須留有至少一個字元。在上面的例子中,\(pre\) 和 \(suc\) 之間就留有 oha
這一段字元。
但是怎麼計算滿足上述條件的 \(s'\) 的個數呢?我們可以設 \(p_i\) 表示 \(s\) 的第 \(i\) 個到第 \(n\) 個字元與 \(t\) 的最大公共字首的長度,\(q_i\) 表示 \(s\) 的第 \(1\) 個到第 \(i\) 個字元與 \(t\) 的最大公共字尾的長度。對於 \(\forall i\in [1,n],j\in[i+m,n]\),如果有 \(p_i+q_j≥m\),則區間 \([i,j]\) 組成的子串 \(s'\) 就是滿足條件的。當然 \(s'\) 變成 \(t\) 的刪除方法也會有不同的代價。但是我們其實可以推出這個代價就是 \(p_i+q_j-m+1\)。
為什麼是對的?我們來看一下這張圖:
圖中第一行是 \([i,i+p_i-1]\) 的字串 \(A\),第二行是 \([j-q_j+1,j]\) 的字串 \(B\),我們有 \(p_i=6,q_j=4,m=7\),則 \(p_i+q_j≥m\),這個時候我們將它們進行對齊就得到了這張圖。由於 \(p_i+q_j≥m\),所以 \(A\) 和 \(B\) 一定有重合部分。由圖可知重合部分的長度為 \(p_i+q_j-m=3\)。
我們可以考慮 \(pre\) 的選擇方法,進行得到唯一確定的 \(suc\) 的選擇方法。我們的 \(pre\) 可以選擇 ABA
,ABAA
,ABAAB
,ABAABA
,即重合部分的長度加一 \(p_i+q_j-m+1\)。不難發現 \(pre\) 最長的時候就是取 \(A\),最短的時候就是取 \(A\) 中不與 \(B\) 重合的部分。這個結合圖片應該是非常好理解的,大家可以自己思考一下。
綜上所述,對於每一個 \(i\),我們找到所有的 \(j\in[i+m,n]\) 滿足 \(p_i+q_j≥m\),給答案增加 \(p_i+q_j-m+1\)。當然我們還需要事先讓 \(p_i,q_j\) 與 \(m-1\) 取一個最小值,因為當 \(p_i\) 或者 \(q_j\) 等於 \(m\) 時,就會和第一種情況算重。
可是暴力找是 \(O(n^2)\) 的,我們需要最佳化。條件 \(p_i+q_j≥m\iff q_j≥m-p_i\),所以每一個 \(q_j\) 都是在 \([m-p_i,m-1]\) 內的,這樣我們就不難想到權值線段樹,儲存區間範圍內的 \(sum=\sum q_j\) 以及 \(q_j\) 的個數 \(num\)。則對於每一個 \(i\),它所帶來的代價就是 \(sum+num\times (p_i-m+1)\)。
然後就可以愉快地打出程式碼了:
#include<bits/stdc++.h>
#define Q(id,x,y,flag) query(1,1,id,x,y,flag)
using namespace std;
const int MAXN=4e5+5;
const unsigned long long base=179;
int n,m;
char a[MAXN],b[MAXN];
unsigned long long mul[MAXN],Hash1[MAXN],Hash2[MAXN];
int p[MAXN],q[MAXN];
struct node
{
long long sum;
int num;//sum表示和,num表示個數
}T[MAXN<<2];
void pushup(int x){ T[x].num=T[x<<1].num+T[x<<1|1].num,T[x].sum=T[x<<1].sum+T[x<<1|1].sum; }
void change_tree(int x,int l,int r,int k)
{
if(l>r) return;
if(l==r)
{
T[x].sum+=l,T[x].num++;
return;
}
int mid=(l+r)/2;
if(k<=mid) change_tree(x<<1,l,mid,k);
else change_tree(x<<1|1,mid+1,r,k);
pushup(x);
}
long long query(int x,int l,int r,int L,int R,int flag)
{
if(L>R) return 0;
if(L<=l&&r<=R)
{
if(!flag) return T[x].sum;
return T[x].num;
}
int mid=(l+r)/2;
long long res=0;
if(L<=mid) res+=query(x<<1,l,mid,L,R,flag);
if(R>mid) res+=query(x<<1|1,mid+1,r,L,R,flag);
return res;
}
long long solve(int x){ return 1ll*x*(x+1)/2; }
int main()
{
cin>>(a+1)>>(b+1);
n=strlen(a+1),m=strlen(b+1);
mul[0]=1;
for(int i=1;i<=n;i++) Hash1[i]=Hash1[i-1]*base+a[i],mul[i]=mul[i-1]*base;//雜湊預處理
for(int i=1;i<=m;i++) Hash2[i]=Hash2[i-1]*base+b[i];
for(int i=1;i<=n;i++)
{
if(a[i]!=b[1]) continue;
int l=1,r=min(n-i+1,m);
while(l<=r)//二分找最長公共字首
{
int mid=(l+r)/2;
if(Hash1[i+mid-1]-Hash1[i-1]*mul[mid]==Hash2[mid]) p[i]=mid,l=mid+1;
else r=mid-1;
}
p[i]=min(p[i],m-1);
}
for(int i=n;i>=1;i--)
{
if(a[i]!=b[m]) continue;
int l=1,r=min(i,m);
while(l<=r)//同理找字尾
{
int mid=(l+r)/2;
if(Hash1[i]-Hash1[i-mid]*mul[mid]==Hash2[m]-Hash2[m-mid]*mul[mid]) q[i]=mid,l=mid+1;
else r=mid-1;
}
q[i]=min(q[i],m-1);
}
long long res=0;
for(int i=1;i<=n-m+1;i++)
{
if(Hash1[i+m-1]-Hash1[i-1]*mul[m]==Hash2[m]) res+=solve(i-1)+solve(n-i-m+1);//第一種情況
}
if(q[n]) change_tree(1,1,m-1,q[n]);
for(int i=n-m;i>=1;i--)//注意是逆序!!!!
{
if(p[i]) res+=Q(m-1,m-p[i],m-1,0)+Q(m-1,m-p[i],m-1,1)*(p[i]-m+1);
if(q[i+m-1]) change_tree(1,1,m-1,q[i+m-1]);//每次要記得更新線段樹
}
cout<<res;
return 0;
}