「TAOI-2」Ciallo~(∠・ω< )⌒★ 題解

Supor__Shoop發表於2024-09-27

手玩了一個小時終於做出來了,這不得寫一篇題解記錄一下??

下面設 \(s\) 的長度為 \(n\)\(t\) 的長度為 \(m\)

考慮分類討論:

如果 \(s\) 中有一個子串 \(s'\)\(t\) 完全相同(可以用雜湊進行比較),設 \(s'\)\(s\) 的第 \(l\) 到第 \(r\) 個字元組成的字串,則我們可以刪除 \([1,l-1]\) 或者 \([r+1,n]\) 的某一個子區間,計算出它們的總個數就是 \(\dfrac{l\times (l-1)}{2}\)\(\dfrac{(n-r)\times (n-r+1)}{2}\),兩者之和就是這一種情況的方案數。

我們也可以選定一個子串 \(s'\),滿足 \(s'\) 的某一段字首 \(pre\) 和某一段字尾 \(suc\) 不相交且兩者連線到一起就是 \(t\)。這樣的話我們刪掉 \(pre\)\(suc\) 中間的字元就符合條件了。

例如 ciaohallo 要變成 ciallo 就可以選取字首 cia 和字尾 llo,它們拼在一起就是我們要得到的字串。值得注意的是,為了避免掉與上一種情況出現重複計算,我們需要滿足 \(pre\)\(suc\) 中間必須留有至少一個字元。在上面的例子中,\(pre\)\(suc\) 之間就留有 oha 這一段字元。

但是怎麼計算滿足上述條件的 \(s'\) 的個數呢?我們可以設 \(p_i\) 表示 \(s\) 的第 \(i\) 個到第 \(n\) 個字元與 \(t\)最大公共字首的長度,\(q_i\) 表示 \(s\) 的第 \(1\) 個到第 \(i\) 個字元與 \(t\)最大公共字尾的長度。對於 \(\forall i\in [1,n],j\in[i+m,n]\),如果有 \(p_i+q_j≥m\),則區間 \([i,j]\) 組成的子串 \(s'\) 就是滿足條件的。當然 \(s'\) 變成 \(t\) 的刪除方法也會有不同的代價。但是我們其實可以推出這個代價就是 \(p_i+q_j-m+1\)

為什麼是對的?我們來看一下這張圖:

圖中第一行是 \([i,i+p_i-1]\) 的字串 \(A\),第二行是 \([j-q_j+1,j]\) 的字串 \(B\),我們有 \(p_i=6,q_j=4,m=7\),則 \(p_i+q_j≥m\),這個時候我們將它們進行對齊就得到了這張圖。由於 \(p_i+q_j≥m\),所以 \(A\)\(B\) 一定有重合部分。由圖可知重合部分的長度為 \(p_i+q_j-m=3\)

我們可以考慮 \(pre\) 的選擇方法,進行得到唯一確定的 \(suc\) 的選擇方法。我們的 \(pre\) 可以選擇 ABAABAAABAABABAABA,即重合部分的長度加一 \(p_i+q_j-m+1\)。不難發現 \(pre\) 最長的時候就是取 \(A\),最短的時候就是取 \(A\) 中不與 \(B\) 重合的部分。這個結合圖片應該是非常好理解的,大家可以自己思考一下。

綜上所述,對於每一個 \(i\),我們找到所有的 \(j\in[i+m,n]\) 滿足 \(p_i+q_j≥m\),給答案增加 \(p_i+q_j-m+1\)。當然我們還需要事先讓 \(p_i,q_j\)\(m-1\) 取一個最小值,因為當 \(p_i\) 或者 \(q_j\) 等於 \(m\) 時,就會和第一種情況算重。

可是暴力找是 \(O(n^2)\) 的,我們需要最佳化。條件 \(p_i+q_j≥m\iff q_j≥m-p_i\),所以每一個 \(q_j\) 都是在 \([m-p_i,m-1]\) 內的,這樣我們就不難想到權值線段樹,儲存區間範圍內的 \(sum=\sum q_j\) 以及 \(q_j\) 的個數 \(num\)。則對於每一個 \(i\),它所帶來的代價就是 \(sum+num\times (p_i-m+1)\)

然後就可以愉快地打出程式碼了:

#include<bits/stdc++.h>
#define Q(id,x,y,flag) query(1,1,id,x,y,flag)
using namespace std;
const int MAXN=4e5+5;
const unsigned long long base=179;
int n,m;
char a[MAXN],b[MAXN];
unsigned long long mul[MAXN],Hash1[MAXN],Hash2[MAXN];
int p[MAXN],q[MAXN];
struct node
{
	long long sum;
	int num;//sum表示和,num表示個數
}T[MAXN<<2];
void pushup(int x){ T[x].num=T[x<<1].num+T[x<<1|1].num,T[x].sum=T[x<<1].sum+T[x<<1|1].sum; }
void change_tree(int x,int l,int r,int k)
{
	if(l>r) return;
	if(l==r)
	{
		T[x].sum+=l,T[x].num++;
		return;
	}
	int mid=(l+r)/2;
	if(k<=mid)  change_tree(x<<1,l,mid,k);
	else    change_tree(x<<1|1,mid+1,r,k);
	pushup(x);
}
long long query(int x,int l,int r,int L,int R,int flag)
{
	if(L>R) return 0;
	if(L<=l&&r<=R)
	{
		if(!flag)   return T[x].sum;
		return T[x].num;
	}
	int mid=(l+r)/2;
	long long res=0;
	if(L<=mid)  res+=query(x<<1,l,mid,L,R,flag);
	if(R>mid)   res+=query(x<<1|1,mid+1,r,L,R,flag);
	return res;
}
long long solve(int x){ return 1ll*x*(x+1)/2; }
int main()
{
	cin>>(a+1)>>(b+1);
	n=strlen(a+1),m=strlen(b+1);
	mul[0]=1;
	for(int i=1;i<=n;i++)   Hash1[i]=Hash1[i-1]*base+a[i],mul[i]=mul[i-1]*base;//雜湊預處理
	for(int i=1;i<=m;i++)   Hash2[i]=Hash2[i-1]*base+b[i];
	for(int i=1;i<=n;i++)
	{
		if(a[i]!=b[1])  continue;
		int l=1,r=min(n-i+1,m);
		while(l<=r)//二分找最長公共字首
		{
			int mid=(l+r)/2;
			if(Hash1[i+mid-1]-Hash1[i-1]*mul[mid]==Hash2[mid])  p[i]=mid,l=mid+1;
			else    r=mid-1;
		}
		p[i]=min(p[i],m-1);
	}
	for(int i=n;i>=1;i--)
	{
		if(a[i]!=b[m])  continue;
		int l=1,r=min(i,m);
		while(l<=r)//同理找字尾
		{
			int mid=(l+r)/2;
			if(Hash1[i]-Hash1[i-mid]*mul[mid]==Hash2[m]-Hash2[m-mid]*mul[mid])  q[i]=mid,l=mid+1;
			else    r=mid-1;
		}
		q[i]=min(q[i],m-1);
	}
	long long res=0;
	for(int i=1;i<=n-m+1;i++)
	{
		if(Hash1[i+m-1]-Hash1[i-1]*mul[m]==Hash2[m])    res+=solve(i-1)+solve(n-i-m+1);//第一種情況
	}
	if(q[n])	change_tree(1,1,m-1,q[n]);
	for(int i=n-m;i>=1;i--)//注意是逆序!!!!
	{
		if(p[i])	res+=Q(m-1,m-p[i],m-1,0)+Q(m-1,m-p[i],m-1,1)*(p[i]-m+1);
		if(q[i+m-1])	change_tree(1,1,m-1,q[i+m-1]);//每次要記得更新線段樹
	}
	cout<<res;
	return 0;
}

相關文章