CF504E Misha and LCP on Tree

LawrenceD發表於2024-08-15

給定一棵 \(n\) 個節點的樹,每個節點上有一個小寫字母。
有 \(m\) 組詢問,每組詢問給出 \(a,b,c,d\),求樹上路徑 \(a\to b\) 與路徑 \(c\to d\) 上的字母依次連成的兩個字串的最長公共字首長度。
\(n \le 3\times 10^5,\ m \le 10^6\)

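求兩個字串任意子串的最長公共字首,可以用二分+雜湊。
樹上的一條路徑經過樹鏈剖分後可拆成 \(O(\log n)\) 個 dfn 連續的區間,也就是 \(O(\log n)\) 個字串片段。依次比較兩條路徑當前片段等長部分的雜湊值:若相同就整段跳過,若不同則在該段內二分出公共字首長度。
實現要注意很多細節。a→lca 與 lca→b 要分開處理:a→lca 是沿 dfn 遞減方向讀取,要用反向雜湊;lca→b 是沿 dfn 遞增方向讀取,用正向雜湊。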

#include<bits/stdc++.h>
using namespace std;
// Shorthand used throughout: a 64-bit integer, and a pair of residues that
// carries one hash value under two independent moduli.
using LL = long long;
using P = pair<LL, LL>;
#define mk make_pair
#define fs first
#define sc second

// Array capacity: n <= 3e5 plus a little slack.
constexpr int N = 300011;
// Candidate moduli for the double hash. Only the first 8 slots are filled
// (the rest are zero); get_mod() draws indices 0..7.
constexpr LL modd[11] = {19260817, 381335473,  252412907, 173275057, 415924151, 490280677, 215695621, 311094167};
bool vis[11];        // vis[k] is set once modd[k] has been handed out by get_mod()
LL mod1, mod2;       // the two moduli actually used in this run
int n, m;            // node count, query count

// Forward-star adjacency list filled by add(): head[u] is the newest edge
// out of u, nex[e] the next edge in the same list, to[e] the edge's target.
// Two directed slots per undirected edge.
// NOTE: with `using namespace std` under C++17 an unqualified reference to
// `size` is ambiguous with std::size; refer to this counter as `::size`.
int head[N];
int nex[N << 1];
int to[N << 1];
int size;

// Heavy-light decomposition data, indexed by node unless noted:
// dep = depth (root is 1), sz = subtree size, son = heavy child (0 if none),
// fa = parent, top = head of the node's heavy chain, dfn = visit order with
// the heavy child first; id[k] is the node whose dfn is k; cnt = dfn counter.
int dep[N], sz[N], son[N], fa[N];
int top[N], dfn[N], id[N], cnt;

char s[N];           // s[1..n]: letter written on each node

// B is the hash base (one component per modulus); h1 / h2 are the forward /
// backward rolling hashes of the letters in dfn order; p[i] holds B^i.
P B = mk(127, 23333);
P h1[N], h2[N], p[N];

// up / dw: scratch buffers for path decomposition; f / g: interval lists of
// the two paths in the current query.
vector<P> up, dw, f, g;
// Component-wise product of two double-hash values; each half is reduced by
// its own modulus. Operands are < 5e8, so the product fits in long long.
P operator * (P a, P b){
    LL lo = a.fs * b.fs % mod1;
    LL hi = a.sc * b.sc % mod2;
    return P(lo, hi);
}
// Component-wise sum of two double-hash values, reduced per modulus.
P operator + (P a, P b){
    P sum;
    sum.fs = (a.fs + b.fs) % mod1;
    sum.sc = (a.sc + b.sc) % mod2;
    return sum;
}
// Component-wise difference of two double-hash values. The modulus is added
// before subtracting so the result stays non-negative for reduced operands.
P operator - (P a, P b){
    const LL lo = (mod1 + a.fs - b.fs) % mod1;
    const LL hi = (mod2 + a.sc - b.sc) % mod2;
    return mk(lo, hi);
}
// Return a modulus from modd[0..7] that has not been returned before.
// Slots are drawn with rand() until an unused one appears, which is then
// marked in vis[]. Never returns if all 8 slots have already been taken.
int get_mod(){
    for(;;){
        int k = rand() % 8;
        if(vis[k]) continue;
        vis[k] = 1;
        return modd[k];
    }
}
// Read the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it (a '-' sign is skipped, not honoured).
// getchar()'s result is kept in an int so EOF stays distinguishable from
// every character; on EOF the function stops and returns what it has
// (0 if no digit was seen) instead of looping forever.
int read(){
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    int x = 0;
    while(ch >= '0' && ch <= '9'){
        x = 10 * x + (ch - '0');
        ch = getchar();
    }
    return x;
}
// Write x in decimal to stdout with no trailing newline. Intended for
// non-negative x; a negative x prints a single character, as before.
void print(int x){
    // Collect digits least-significant first, then emit them in reverse.
    char digits[12];
    int len = 0;
    do{
        digits[len++] = x % 10 + 48;
        x /= 10;
    }while(x > 0);
    while(len > 0) putchar(digits[--len]);
}
// First heavy-light pass over the subtree of u (fa[u] must already be set).
// Fills dep (root depth 1 since fa[root] = 0), fa for children, sz (subtree
// size) and son (the child with the largest subtree; ties keep the earlier
// one). Recursion depth equals the tree height, i.e. up to n on a path.
void dfs(int u){
    dep[u] = dep[fa[u]] + 1;
    sz[u] = 1;
    for(int e = head[u]; e != 0; e = nex[e]){
        const int child = to[e];
        if(child == fa[u]) continue;
        fa[child] = u;
        dfs(child);
        sz[u] += sz[child];
        if(sz[child] > sz[son[u]]) son[u] = child;
    }
}
// Second heavy-light pass: number nodes in DFS order (dfn / id) visiting the
// heavy child first, so every heavy chain gets a contiguous dfn range.
// tp is the head of the chain u belongs to; light children start new chains.
void dfs(int u, int tp){
    dfn[u] = ++cnt;
    id[cnt] = u;
    top[u] = tp;
    if(son[u] != 0) dfs(son[u], tp);
    for(int e = head[u]; e; e = nex[e]){
        const int v = to[e];
        if(v != fa[u] && v != son[u]) dfs(v, v);
    }
}
// Lowest common ancestor of x and y using the heavy-light chains. While the
// two nodes sit on different chains, jump the one whose chain head is deeper
// to the parent of that head. Once both are on the same chain, the
// shallower of the two is the answer.
int LCA(int x, int y){
    while(top[x] != top[y]){
        int &deeper = dep[top[x]] >= dep[top[y]] ? x : y;
        deeper = fa[top[deeper]];
    }
    return dep[x] <= dep[y] ? x : y;
}
// Decompose the path x -> y into dfn intervals, listed in reading order.
// Each entry is (start dfn, end dfn). Intervals on the x -> lca side have
// start >= end (read upwards, dfn decreasing) and include lca. Intervals on
// the lca -> y side have start <= end (read downwards) and exclude lca.
// Uses the global buffers up (result, also returned by copy) and dw (left
// empty on return).
vector<P> get(int x, int y){
    const int lca = LCA(x, y);
    up.clear();
    for(; top[x] != top[lca]; x = fa[top[x]]) up.push_back(mk(dfn[x], dfn[top[x]]));
    up.push_back(mk(dfn[x], dfn[lca]));
    // The y side is collected bottom-up, so it is appended in reverse.
    for(; top[y] != top[lca]; y = fa[top[y]]) dw.push_back(mk(dfn[top[y]], dfn[y]));
    if(y != lca) dw.push_back(mk(dfn[son[lca]], dfn[y]));
    up.insert(up.end(), dw.rbegin(), dw.rend());
    dw.clear();
    return up;
}
// Hash of the `len` letters starting at dfn position `st`, read forwards
// (fl == 0: st, st+1, ...) via h1, or backwards (fl != 0: st, st-1, ...) via
// h2. In both cases the first letter read carries the highest power of B,
// so hashes of the two directions are directly comparable. len == 0 gives 0.
P get(int fl, int st, int len){
    return fl ? h2[st-len+1] - h2[st+1] * p[len]
              : h1[st+len-1] - h1[st-1] * p[len];
}
void add(int x, int y){
    to[++size] = y; nex[size] = head[x]; head[x] = size;
}
// Entry point: reads n, the node letters and the n-1 edges, builds the
// heavy-light decomposition and the dfn-order hashes, then for each of the
// m queries prints the LCP length of the strings on paths a->b and c->d.
int main(){
    // Pick two distinct moduli at random (presumably so that a fixed
    // anti-hash input cannot target them).
    srand(time(0));
    mod1 = get_mod();
    mod2 = get_mod();
    cin>>n;
    scanf("%s", s + 1);
    // p[i] = B^i, filled alongside the edge loop so p[0..n] is available.
    p[0] = mk(1, 1); p[1] = B;
    int u, v;
    for(int i = 2;i <= n; i++){
        u = read(); v = read();
        add(u, v);add(v, u);
        p[i] = p[i-1] * B;
    }
    // Subtree sizes / heavy children first, then chain heads and dfn order.
    dfs(1);
    dfs(1, 1);
    // h1[i]: hash of the letters at dfn 1..i read forwards.
    for(int i = 1;i <= n; i++){
        int x = s[id[i]];
        h1[i] = h1[i-1] * B + mk(x, x);
    }
    // h2[i]: hash of the letters at dfn n down to i (h2[n+1] stays 0).
    for(int i = n;i >= 1; i--){
        int x = s[id[i]];
        h2[i] = h2[i+1] * B + mk(x, x);
    }
    m = read();
    while(m--){
        int a = read(), b = read(), c = read(), d = read();
        // Each path becomes a list of (start dfn, end dfn) intervals in
        // reading order; start > end means the interval is read backwards.
        f = get(a, b); g = get(c, d);
        int i = 0, j = 0, ans = 0;
        while(i < f.size() && j < g.size()){
            int df1 = f[i].fs, df2 = f[i].sc;
            int dg1 = g[j].fs, dg2 = g[j].sc;
            // 1 = read with decreasing dfn (h2), 0 = increasing dfn (h1).
            int flf = df2 < df1, flg = dg2 < dg1;
            int lf = abs(df1 - df2) + 1, lg = abs(dg2 - dg1) + 1;
            // Compare the longest chunk both current intervals still have.
            int len = min(lf, lg);
            P hf = get(flf, df1, len), hg = get(flg, dg1, len);
            if(hf == hg){
                // Chunk matches: consume it, either finishing an interval
                // or moving its start len steps in its reading direction.
                if(len == lf)i++;
                else f[i].fs = flf == 1 ? df1 - len : df1 + len;
                if(len == lg)j++;
                else g[j].fs = flg == 1 ? dg1 - len : dg1 + len;
                ans += len;
            }
            else{
                // Mismatch inside this chunk: binary search the longest
                // matching prefix (length 0 always matches), then stop.
                int l = 0, r = len, res = 0;
                while(l <= r){
                    int mid = l + r >> 1;
                    if(get(flf, df1, mid) == get(flg, dg1, mid))l = mid + 1, res = mid;
                    else r = mid - 1;
                }
                ans += res;
                break;
            }
        }
        print(ans);
        puts("");
    }
    return 0;
}