hdu7462-字串【SAM,二分】

QuantAsk發表於2024-08-13

正題

題目連結:https://acm.hdu.edu.cn/showproblem.php?pid=7462


題目大意

你有一個由 \(a,b\) 組成的字串 \(s\)
\(m\) 個操作:

  • 詢問有多少個本質不同的串 \(t\) 使得 \(s[l,r]\)\(t\) 的子串且兩個串在 \(s\) 中的出現次數相同。
  • 詢問有多少個本質不同的串 \(t\) 使得 \(t\)\(s[l,r]\) 的子串且兩個串在 \(s\) 中的出現次數相同。

強制線上

\(\sum |s|\leq 5\times 10^5,\sum q\leq 5\times 10^5\)


解題思路

首先 \(t\) 肯定也是 \(s\) 的子串,所以我們考慮在SAM上解決這個問題。

先建一個SAM,我們先考慮詢問1,設詢問區間為 \([l,r]\) ,我們先在SAM上找到對應節點(經典方法建立SAM時記錄每個字尾所在節點,然後從這個節點在parent樹往上倍增跳到 \(len\) 區間為 \(r-l+1\) 的位置)。

那麼此時所有出現次數相同的串 \(t=s[x,r]\ (x\leq l)\) 也在同一個節點處,而且我們可以用 \(len_x-(r-l)\) 得到個數。

此時考慮右端點往右擴充套件的情況,相當於往當前節點SAM字元 \(s_{r+1}\) 的方向走一步,暴力跳肯定是不對的,我們分析一下性質。

假設 \(s_{r+1}=a\) ,假設當前節點 \(x\) 既有字元 \(a\) 的邊也有字元 \(b\) 的邊,那麼說明往下走了之後出現次數肯定會比當前串少,所以就沒有往下走的必要了。也就是我們往下走的路徑肯定是SAM上的一條出現次數相同的鏈。

我們可以先處理出每個節點在原串的出現次數,然後把所有這樣的鏈拉出來,詢問時直接二分我們能夠往後走到哪個位置,順便使用字首和記錄一下 \(len_x\) 的和即可。

對於詢問2操作類似的變為往前調,記錄 \(len_{fa_x}\) 的和,但是需要考慮的一點是因為當前節點的長度是一個區間 \(len_{fa_x}\leq r-l+1\leq len_x\),我們往前走到 \(x'\) 時可能存在 \(r-l\leq len_{fa_x'}\) 的情況,也就是對應節點需要往上跳,但是我們考慮如果記 \(c_x\)\(x\) 節點對應串的出現次數,那麼有 \(c_{fa_{x'}}>c_{x'}\geq c_{x}\) ,所有如果往上跳了 \(s[l,r-1]\) 出現次數肯定就比 \(s[l,r]\) 多了,就沒有繼續的必要了,所以我們二分時還需要維護一下中間是否需要往上跳。

時間複雜度:\(O(n\log n)\)

因為是賽時程式碼很多東西沒想明白所以程式會寫的比較臃腫。


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cctype>
#define ull unsigned long long
#define ll long long
using namespace std;
const int N=1e6+10;
const ull g=131;
struct node{
	int c,fa;
	ll l,r;
	ull h;
}zero;
int T,n,q,las,cnt,tot;
int fa[N],len[N],ch[N][2],ct[N],f[N][20];
int ed[N],pos[N],seg[N],p[N];
vector<node> v[N];
vector<int> G[N];
ull pw[N],has[N];
char s[N];
int read(){
	int x=0,f=1;char c=getchar();
	while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}
	while(isdigit(c)){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
	return x*f;
}
void ins(int c){
	int p=las;int np=las=++cnt;
	len[np]=len[p]+1;
	for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
	if(!p)fa[np]=1;
	else{
		int q=ch[p][c];
		if(len[p]+1==len[q])fa[np]=q;
		else{
			int nq=++cnt;len[nq]=len[p]+1;
			ch[nq][0]=ch[q][0];ch[nq][1]=ch[q][1];
//			memcpy(ch[nq],ch[q],sizeof(ch[nq]));
			fa[nq]=fa[q];fa[np]=fa[q]=nq;
			for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;
		}
	}
	return;
}
void dfs(int x){
	for(int i=0;i<G[x].size();i++){
		f[G[x][i]][0]=x,dfs(G[x][i]);
		ct[x]+=ct[G[x][i]];
	}
	return;
}
bool cmp(int x,int y)
{return (len[x]==len[y])?(x>y):(len[x]<len[y]);}
int main()
{
	pw[0]=1;
	for(int i=1;i<N;i++)pw[i]=pw[i-1]*g;
	scanf("%d",&T);
	while(T--){
		scanf("%s",s+1);n=strlen(s+1);
		las=cnt=1;
		for(int i=1;i<=n;i++){
			ins(s[i]-'a');
			ed[i]=las;ct[las]++;
			has[i]=has[i-1]*g+s[i]-'a';
		}
		for(int i=2;i<=cnt;i++)
			G[fa[i]].push_back(i);
		dfs(1);
		for(int j=1;j<20;j++)
			for(int i=1;i<=cnt;i++)
				f[i][j]=f[f[i][j-1]][j-1];
		for(int i=1;i<=cnt;i++)p[i]=i;
		sort(p+1,p+1+cnt,cmp);
		for(int i=1;i<=cnt;i++){
			int x=p[i],c=0;
			if(!pos[x])pos[x]=++tot,v[pos[x]].push_back(zero);
			if(ch[x][0]&&ct[x]==ct[ch[x][0]])pos[ch[x][0]]=pos[x],c=0;
			if(ch[x][1]&&ct[x]==ct[ch[x][1]])pos[ch[x][1]]=pos[x],c=1;
			node w;
			w.c=c;w.r=len[x];w.l=len[fa[x]];
			w.fa=len[fa[x]];
			v[pos[x]].push_back(w);
			seg[x]=v[pos[x]].size()-1;
		}
		for(int i=1;i<=tot;i++){
			for(int j=1;j<v[i].size();j++){
				v[i][j].h=v[i][j-1].h*g+v[i][j].c;
				v[i][j].l+=v[i][j-1].l;
				v[i][j].r+=v[i][j-1].r;
			}
		}
		scanf("%d",&q);long long lasans=0;
		while(q--){
			int op=read(),l=read(),r=read();
			l=(l+lasans-1)%n+1;r=(r+lasans-1)%n+1;
			if(op==1){
				int x=ed[r];
				for(int j=19;j>=0;j--)
					if(len[f[x][j]]>=r-l+1)x=f[x][j];
				int id=pos[x],now=seg[x];
				int L=now,R=v[id].size()-2;
				while(L<=R){
					int mid=(L+R)>>1,dl=mid-now+1;
					if(v[id][mid].h-v[id][now-1].h*pw[dl]==has[r+dl]-has[r]*pw[dl])L=mid+1;
					else R=mid-1;
				}
				int dl=L-now;
				long long ans=(v[id][L].r-v[id][now-1].r)-1ll*((r-l)+(r-l+dl))*(dl+1)/2ll;
				printf("%lld\n",ans);lasans=ans%n;
			}else{
				int x=ed[r];
				for(int j=19;j>=0;j--)
					if(len[f[x][j]]>=r-l+1)x=f[x][j];
				int id=pos[x],now=seg[x];
				int L=max(1,now-(r-l+1)),R=now-1;
				while(L<=R){
					int mid=(L+R)>>1,dl=now-mid;
					if(v[id][now-1].h-v[id][mid-1].h*pw[dl]==has[r]-has[r-dl]*pw[dl]
					&&v[id][mid].fa<(r-l+1)-dl)R=mid-1;
					else L=mid+1;
				}
				int dl=now-L;
				long long ans=1ll*((r-l+1)+(r-l+1-dl))*(dl+1)/2ll-(v[id][now].l-v[id][L-1].l);
				printf("%lld\n",ans);lasans=ans%n;
			}
		}
		
		for(int i=1;i<=cnt;i++){
			fa[i]=ch[i][0]=ch[i][1]=len[i]=ct[i]=0;
			pos[i]=seg[i]=ed[i]=p[i]=has[i]=0;
			G[i].clear();
		}
		for(int i=1;i<=tot;i++)v[i].clear();
		tot=cnt=0;
	}
	return 0;
}

相關文章