SGU 505

SFlyer發表於2024-07-10

link

這裡講兩種做法,一個線上,一個離線。

線上

我們分別考慮字首和字尾。有一個比較重要的結論,就是把 \(s\) 按照字典序排序以後,相同字首的出現位置(其實就是 rank)是連續的。\(s\) 翻過來,相同字尾的也是連續的。

這樣我們就可以求出每一個詢問字首和字尾對應的區間是什麼,然後就要求區間重合的數有多少個就可以了。一個做法是,設兩個排完序的 \(id\)\(p,q\),我們把 \(p\) 弄成在 \(q\) 中出現的位置,這樣就要求 \(p\) 區間內在一個範圍內數的個數。

現在的問題是給定一個陣列 \(a\),詢問 \(a_{l\sim r}\)\(c\le a_i\le d\) 的個數,也就是小於 \(d\) 的個數減去小於 \(c-1\) 的個數。可以用主席樹維護。(也可以分塊,但是這樣就是 \(\mathcal{O}(n\sqrt{n}\log n)\) 的了,不知道能不能過)

還有一個可能的問題是查詢字首/字尾對應區間。這個用二分有可能特判比較多(端點的特判是不是這個字首/字尾),所以我們可以直接用一個 trie,每一個節點記錄對應到這個點的區間。

可能要卡一點空間。程式碼不算難寫。

#include <bits/stdc++.h>

using namespace std;

using ll = long long;

const int N = 1e5+5;

int n,val[N],idx[N];

struct node {
	string s;
	int id;
	bool operator < (const node &a) const{
		return s<a.s;
	}
} a[N],ra[N];

string s[N],rs[N];

int rt[N],cnt;

struct tnode {
	int l,r,sum;
} t[N*20];

void upd(int l,int r,int &x,int y,int pos){
	t[++cnt]=t[y];
	t[cnt].sum++;
	x=cnt;
	if (l==r){
		return;
	}
	int mid=l+r>>1;
	if (pos<=mid){
		upd(l,mid,t[x].l,t[y].l,pos);
	}
	else{
		upd(mid+1,r,t[x].r,t[y].r,pos);
	}
}

int qy(int l,int r,int x,int y,int k){
	if (k==0){
		return 0;
	}
	if (r<=k){
		return t[y].sum-t[x].sum;
	}
	int mid=l+r>>1;
	if (k<=mid){
		return qy(l,mid,t[x].l,t[y].l,k);
	}
	else{
		return t[t[y].l].sum-t[t[x].l].sum+qy(mid+1,r,t[x].r,t[y].r,k);
	}
}

struct trie {
	int ch[N*10][26],mn[N*10],mx[N*10],tot=1;
	void init(){
		for (int i=0; i<N*10; i++){
			mn[i]=1e9;
			mx[i]=-1e9;
		}
	}
	void ins(string s,int id){
		int cur=1;
		for (auto c : s){
			if (!ch[cur][c-'a']){
				ch[cur][c-'a']=++tot;
			}
			mn[cur]=min(mn[cur],id);
			mx[cur]=max(mx[cur],id);
			cur=ch[cur][c-'a'];
		}
		mn[cur]=min(mn[cur],id);
		mx[cur]=max(mx[cur],id);
	}
	pair<int,int> seg(string s){
		int cur=1;
		for (auto c : s){
			if (!ch[cur][c-'a']){
				return {-1,-1};
			}
			cur=ch[cur][c-'a'];
		}
		return {mn[cur],mx[cur]};
	}
} ta,tr;

int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);

	cin>>n;
	for (int i=1; i<=n; i++){
		cin>>a[i].s;
		a[i].id=i;
		ra[i].s=a[i].s;
		reverse(ra[i].s.begin(),ra[i].s.end());
		ra[i].id=i;
	}
	sort(a+1,a+1+n);
	sort(ra+1,ra+1+n);
	for (int i=1; i<=n; i++){
		s[i]=a[i].s;
		rs[i]=ra[i].s;
		idx[ra[i].id]=i;
	}
	for (int i=1; i<=n; i++){
		val[i]=idx[a[i].id];
	}
	for (int i=1; i<=n; i++){
		upd(1,n,rt[i],rt[i-1],val[i]);
	}
	ta.init();
	tr.init();
	for (int i=1; i<=n; i++){
		ta.ins(s[i],i);
		tr.ins(rs[i],i);
	}
	int q;
	cin>>q;
	while (q--){
		string x,y;
		cin>>x>>y;
		reverse(y.begin(),y.end());
		auto u=ta.seg(x);
		auto v=tr.seg(y);
		if (u.first==-1 || v.first==-1){
			cout<<"0\n";
			continue;
		}
		int ans=qy(1,n,rt[u.first-1],rt[u.second],v.second);
		ans-=qy(1,n,rt[u.first-1],rt[u.second],v.first-1);
		cout<<ans<<"\n";
	}
	return 0;
}

離線

這個是用 ACAM 的做法。首先複習一下 ACAM 能做什麼:求一個文字串裡面出現其他串的個數,多次詢問。

如果我們把一個串寫成 \(s_i|s_i\) 的形式(如 ab 寫成 ab|ab),然後查詢相當於問有沒有一個 \(suf|pre\) 的子串。我們把所有的文字串用 ,(或其他字元)連成一個長串,就可以直接查詢了!

這個除了模板的程式會更好寫,但是唯一劣勢是離線。

Credit: Codeforces @cayaxi09.

#include <bits/stdc++.h>

using namespace std;

using ll = long long;

const int N = 5e5+5;

int n;
string T,s[N];
int ch[N][30],tag[N],fa[N],ans[N];
int cnt,vis[N],in[N],mp[N];

void init(){
	cnt=1;
	for (int i=0; i<N; i++){
		memset(ch[i],0,sizeof ch[i]);
		tag[i]=fa[i]=0;
	}
	for (int i=0; i<N; i++){
		vis[i]=0;
	}
}

int get(char c){
	if ('a'<=c && c<='z'){
		return c-'a';
	} 
	if (c=='|'){
		return 26;
	}
	return 27;
}

void ins(string s,int id){
	int cur=1;
	for (int i=0; i<s.size(); i++){
		int c=get(s[i]);
		if (!ch[cur][c]){
			ch[cur][c]=++cnt;
		}
		cur=ch[cur][c];
	}
	if (!tag[cur]){
		tag[cur]=id;
	}
	mp[id]=tag[cur];
}

void get_fail(){
	queue<int> q;
	q.push(1);
	for (int i=0; i<28; i++){
		ch[0][i]=1;
	}
	fa[1]=0;
	while (!q.empty()){
		int u=q.front();
		q.pop();
		for (int i=0; i<28; i++){
			if (!ch[u][i]){
				ch[u][i]=ch[fa[u]][i];
			}
			else{
				fa[ch[u][i]]=ch[fa[u]][i];
				in[fa[ch[u][i]]]++;
				q.push(ch[u][i]);
			}
		}
	}
}

void qy(string s){
	int cur=1;
	for (int i=0; i<s.size(); i++){
		int c=get(s[i]);
		int z=ch[cur][c];
		ans[z]++;
		cur=z;
	}
}

void sol(){
	queue<int> q;
	for (int i=1; i<=cnt; i++){
		if (!in[i]){
			q.push(i);
		}
	}
	while (!q.empty()){
		int u=q.front();
		q.pop();
		vis[tag[u]]=ans[u];
		in[fa[u]]--;
		ans[fa[u]]+=ans[u];
		if (!in[fa[u]]){
			q.push(fa[u]);
		}
	}
}

int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);

	cin>>n;
	init();
	for (int i=1; i<=n; i++){
		string t;
		cin>>t;
		T+=t;
		T+="|";
		T+=t;
		T+=",";
	}
	int q;
	cin>>q;
	for (int i=1; i<=q; i++){
		string x,y;
		cin>>x>>y;
		s[i]=y+"|"+x;
		ins(s[i],i);
	}
	get_fail();
	qy(T);
	sol();
	for (int i=1; i<=q; i++){
		cout<<vis[mp[i]]<<"\n";
	}
	return 0;
}