link
這裡講兩種做法,一個線上,一個離線。
線上
我們分別考慮字首和字尾。有一個比較重要的結論,就是把 \(s\) 按照字典序排序以後,相同字首的出現位置(其實就是 rank)是連續的。\(s\) 翻過來,相同字尾的也是連續的。
這樣我們就可以求出每一個詢問字首和字尾對應的區間是什麼,然後就要求區間重合的數有多少個就可以了。一個做法是,設兩個排完序的 \(id\) 為 \(p,q\),我們把 \(p\) 弄成在 \(q\) 中出現的位置,這樣就要求 \(p\) 區間內在一個範圍內數的個數。
現在的問題是給定一個陣列 \(a\),詢問 \(a_{l\sim r}\) 中 \(c\le a_i\le d\) 的個數,也就是小於 \(d\) 的個數減去小於 \(c-1\) 的個數。可以用主席樹維護。(也可以分塊,但是這樣就是 \(\mathcal{O}(n\sqrt{n}\log n)\) 的了,不知道能不能過)
還有一個可能的問題是查詢字首/字尾對應區間。這個用二分有可能特判比較多(端點的特判是不是這個字首/字尾),所以我們可以直接用一個 trie,每一個節點記錄對應到這個點的區間。
可能要卡一點空間。程式碼不算難寫。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 1e5+5;
int n,val[N],idx[N];
struct node {
string s;
int id;
bool operator < (const node &a) const{
return s<a.s;
}
} a[N],ra[N];
string s[N],rs[N];
int rt[N],cnt;
struct tnode {
int l,r,sum;
} t[N*20];
void upd(int l,int r,int &x,int y,int pos){
t[++cnt]=t[y];
t[cnt].sum++;
x=cnt;
if (l==r){
return;
}
int mid=l+r>>1;
if (pos<=mid){
upd(l,mid,t[x].l,t[y].l,pos);
}
else{
upd(mid+1,r,t[x].r,t[y].r,pos);
}
}
int qy(int l,int r,int x,int y,int k){
if (k==0){
return 0;
}
if (r<=k){
return t[y].sum-t[x].sum;
}
int mid=l+r>>1;
if (k<=mid){
return qy(l,mid,t[x].l,t[y].l,k);
}
else{
return t[t[y].l].sum-t[t[x].l].sum+qy(mid+1,r,t[x].r,t[y].r,k);
}
}
struct trie {
int ch[N*10][26],mn[N*10],mx[N*10],tot=1;
void init(){
for (int i=0; i<N*10; i++){
mn[i]=1e9;
mx[i]=-1e9;
}
}
void ins(string s,int id){
int cur=1;
for (auto c : s){
if (!ch[cur][c-'a']){
ch[cur][c-'a']=++tot;
}
mn[cur]=min(mn[cur],id);
mx[cur]=max(mx[cur],id);
cur=ch[cur][c-'a'];
}
mn[cur]=min(mn[cur],id);
mx[cur]=max(mx[cur],id);
}
pair<int,int> seg(string s){
int cur=1;
for (auto c : s){
if (!ch[cur][c-'a']){
return {-1,-1};
}
cur=ch[cur][c-'a'];
}
return {mn[cur],mx[cur]};
}
} ta,tr;
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n;
for (int i=1; i<=n; i++){
cin>>a[i].s;
a[i].id=i;
ra[i].s=a[i].s;
reverse(ra[i].s.begin(),ra[i].s.end());
ra[i].id=i;
}
sort(a+1,a+1+n);
sort(ra+1,ra+1+n);
for (int i=1; i<=n; i++){
s[i]=a[i].s;
rs[i]=ra[i].s;
idx[ra[i].id]=i;
}
for (int i=1; i<=n; i++){
val[i]=idx[a[i].id];
}
for (int i=1; i<=n; i++){
upd(1,n,rt[i],rt[i-1],val[i]);
}
ta.init();
tr.init();
for (int i=1; i<=n; i++){
ta.ins(s[i],i);
tr.ins(rs[i],i);
}
int q;
cin>>q;
while (q--){
string x,y;
cin>>x>>y;
reverse(y.begin(),y.end());
auto u=ta.seg(x);
auto v=tr.seg(y);
if (u.first==-1 || v.first==-1){
cout<<"0\n";
continue;
}
int ans=qy(1,n,rt[u.first-1],rt[u.second],v.second);
ans-=qy(1,n,rt[u.first-1],rt[u.second],v.first-1);
cout<<ans<<"\n";
}
return 0;
}
離線
這個是用 ACAM 的做法。首先複習一下 ACAM 能做什麼:求一個文字串裡面出現其他串的個數,多次詢問。
如果我們把一個串寫成 \(s_i|s_i\) 的形式(如 ab
寫成 ab|ab
),然後查詢相當於問有沒有一個 \(suf|pre\) 的子串。我們把所有的文字串用 ,
(或其他字元)連成一個長串,就可以直接查詢了!
這個除了模板的程式會更好寫,但是唯一劣勢是離線。
Credit: Codeforces @cayaxi09.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 5e5+5;
int n;
string T,s[N];
int ch[N][30],tag[N],fa[N],ans[N];
int cnt,vis[N],in[N],mp[N];
void init(){
cnt=1;
for (int i=0; i<N; i++){
memset(ch[i],0,sizeof ch[i]);
tag[i]=fa[i]=0;
}
for (int i=0; i<N; i++){
vis[i]=0;
}
}
int get(char c){
if ('a'<=c && c<='z'){
return c-'a';
}
if (c=='|'){
return 26;
}
return 27;
}
void ins(string s,int id){
int cur=1;
for (int i=0; i<s.size(); i++){
int c=get(s[i]);
if (!ch[cur][c]){
ch[cur][c]=++cnt;
}
cur=ch[cur][c];
}
if (!tag[cur]){
tag[cur]=id;
}
mp[id]=tag[cur];
}
void get_fail(){
queue<int> q;
q.push(1);
for (int i=0; i<28; i++){
ch[0][i]=1;
}
fa[1]=0;
while (!q.empty()){
int u=q.front();
q.pop();
for (int i=0; i<28; i++){
if (!ch[u][i]){
ch[u][i]=ch[fa[u]][i];
}
else{
fa[ch[u][i]]=ch[fa[u]][i];
in[fa[ch[u][i]]]++;
q.push(ch[u][i]);
}
}
}
}
void qy(string s){
int cur=1;
for (int i=0; i<s.size(); i++){
int c=get(s[i]);
int z=ch[cur][c];
ans[z]++;
cur=z;
}
}
void sol(){
queue<int> q;
for (int i=1; i<=cnt; i++){
if (!in[i]){
q.push(i);
}
}
while (!q.empty()){
int u=q.front();
q.pop();
vis[tag[u]]=ans[u];
in[fa[u]]--;
ans[fa[u]]+=ans[u];
if (!in[fa[u]]){
q.push(fa[u]);
}
}
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n;
init();
for (int i=1; i<=n; i++){
string t;
cin>>t;
T+=t;
T+="|";
T+=t;
T+=",";
}
int q;
cin>>q;
for (int i=1; i<=q; i++){
string x,y;
cin>>x>>y;
s[i]=y+"|"+x;
ins(s[i],i);
}
get_fail();
qy(T);
sol();
for (int i=1; i<=q; i++){
cout<<vis[mp[i]]<<"\n";
}
return 0;
}