ABC324E:連接兩字串後包含給定子序列 t 的下標對數

chenfy27發表於2024-03-17

給定 n 個字串 s[1..n] 和字串 t。對任意有序下標對 (i,j)(1<=i,j<=n,允許 i=j,共 n^2 對),把 s[i] 與 s[j] 依序連接成新串 s[i]+s[j];若從新串刪除 0 個或多個字元可以得到 t(即 t 是新串的子序列),則該下標對滿足條件。問共有多少對滿足條件的下標對?
1<=n<=5e5;1<=len(s[i]),len(t)<=5e5,且所有 s[i] 的長度總和不超過 5e5。

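設 a 為 t 的最長前綴長度,使該前綴是字串 x 的子序列(從左到右貪心匹配即可求得);b 為 t 的最長後綴長度,使該後綴是字串 y 的子序列(從右到左貪心匹配)。則 x+y 包含子序列 t 若且唯若 a+b>=len(t):若 a+b>=len(t),取 x 中長度為 len(t)-b(<=a)的前綴匹配與 y 中長度為 b 的後綴匹配即可;反之,若 x+y 包含 t,t 落在 x 中的部分是長度不超過 a 的前綴,落在 y 中的部分是長度不超過 b 的後綴,故 a+b>=len(t)。因此答案等於對每個 j 統計滿足 a_i>=len(t)-b_j 的 i 的個數再求和。計數可以開 cnt 陣列記錄每個 a 值出現的次數,再求後綴和;這裡偷懶直接套平衡樹模板。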

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=b;i>=a;i--)

// Treap: a binary search tree on `data` whose nodes are additionally kept in
// max-heap order on random priorities (`rnd`) via rotations.
//
// Keys are compared only with operator<. Nodes live in the `node` pool and
// are referenced by index; index 0 is the null sentinel and is never handed
// out. With `multiple == true`, equal keys share one node whose `dup` counts
// the copies; with `multiple == false`, inserting an existing key changes
// nothing. All size/rank queries count duplicates. The pool grows on demand,
// so the constructor's `sz` is only the initial capacity. Slots of erased
// nodes are not reused.
//
// TYPE requirements: operator<, assignment, `TYPE * int` (key times
// multiplicity), operator+=, and construction from the literal 0.
template <typename TYPE>
struct Treap {
    struct Node {
        TYPE data, sum;             // key; sum of data * dup over this subtree
        int rnd, siz, dup, son[2];  // heap priority; subtree size (with dups);
                                    // copies of this key; children (0 = none)
        // Turns this pool slot into a fresh leaf holding one copy of d.
        void init(const TYPE & d) {
            data = sum = d;
            rnd = rand();
            siz = dup = 1;
            son[0] = son[1] = 0;
        }
    };
    // sz: initial pool capacity (grown on demand by insert);
    // multi: whether duplicate keys are kept and counted.
    Treap(size_t sz, bool multi):multiple(multi) {
        node.resize(max<size_t>(sz, 1));  // slot 0 (null sentinel) must always exist
        reset();
    }
    // Hands out the next pool slot, initialised with key d. Capacity is
    // ensured by the public insert() before recursion starts: growing the
    // vector here would invalidate the `int &r` references into `node` that
    // the recursive insert is holding.
    int newnode(const TYPE & d) {
        total += 1;
        assert(total < static_cast<int>(node.size()));
        node[total].init(d);
        return total;
    }
    // Empties the tree; pool slots are simply overwritten when reused.
    void reset() { root = total = 0; }
    // Recomputes siz and sum of node x from its own dup and its children.
    void maintain(int x) {
        node[x].siz = node[x].dup;
        node[x].sum = node[x].data * node[x].dup;
        if (node[x].son[0]) {
            node[x].siz += node[node[x].son[0]].siz;
            node[x].sum += node[node[x].son[0]].sum;
        }
        if (node[x].son[1]) {
            node[x].siz += node[node[x].son[1]].siz;
            node[x].sum += node[node[x].son[1]].sum;
        }
    }
    // Rotates the subtree rooted at r so that its child son[d^1] becomes the
    // new subtree root; r is updated to that child's index.
    void rotate(int d, int &r) {
        int k = node[r].son[d^1];
        node[r].son[d^1] = node[k].son[d];
        node[k].son[d] = r;
        maintain(r);
        maintain(k);
        r = k;
    }
    // Inserts one copy of data into the subtree r; sets ans to false when the
    // key was already present (its dup is bumped only if `multiple`).
    void insert(const TYPE &data, int &r, bool &ans) {
        if (r) {
            if (!(data < node[r].data) && !(node[r].data < data)) {
                ans = false;
                if (multiple) {
                    node[r].dup += 1;
                    maintain(r);
                }
            } else {
                int d = data < node[r].data ? 0 : 1;
                insert(data, node[r].son[d], ans);
                // Restore heap order: lift the child if its priority is higher.
                if (node[node[r].son[d]].rnd > node[r].rnd) {
                    rotate(d^1, r);
                } else {
                    maintain(r);
                }
            }
        } else {
            r = newnode(data);
        }
    }
    // Stores in data the k-th smallest key (1-based, counting duplicates) of
    // subtree r. Requires 1 <= k <= siz of that subtree.
    void getkth(int k, int r, TYPE& data) {
        int x = node[r].son[0] ? node[node[r].son[0]].siz : 0;
        int y = node[r].dup;
        if (k <= x) {
            getkth(k, node[r].son[0], data);
        } else if (k <= x + y) {
            data = node[r].data;
        } else {
            getkth(k-x-y, node[r].son[1], data);
        }
    }
    // Sum of the k smallest keys (counting duplicates) of subtree r.
    TYPE getksum(int k, int r) {
        if (k <= 0 || r == 0) return 0;
        int ls = node[r].son[0];
        int x = ls ? node[ls].siz : 0;
        // Explicit guard instead of relying on the sentinel's sum being zero.
        TYPE lsum = ls ? node[ls].sum : TYPE(0);
        int y = node[r].dup;
        if (k <= x) return getksum(k, ls);
        if (k <= x+y) return lsum + node[r].data * (k-x);
        return lsum + node[r].data * y + getksum(k-x-y, node[r].son[1]);
    }
    // Removes one copy of data from subtree r; does nothing if it is absent.
    void erase(const TYPE& data, int & r) {
        if (r == 0) return;
        if (data < node[r].data) {
            erase(data, node[r].son[0]);
        } else if (node[r].data < data) {
            erase(data, node[r].son[1]);
        } else if (node[r].dup > 1) {
            node[r].dup -= 1;           // other copies remain: keep the node
        } else if (node[r].son[0] == 0) {
            r = node[r].son[1];         // at most one child: splice the node out
        } else if (node[r].son[1] == 0) {
            r = node[r].son[0];
        } else {
            // Two children: lift the higher-priority child, which pushes this
            // node down into son[dd], then continue the removal there. dup is
            // left at 1 so sizes stay consistent during the rotations.
            int dd = node[node[r].son[0]].rnd > node[node[r].son[1]].rnd ? 1 : 0;
            rotate(dd, r);
            erase(data, node[r].son[dd]);
        }
        if (r) maintain(r);
    }
    // Number of keys (with duplicates) strictly less than data in subtree r.
    int ltcnt(const TYPE& data, int r) {
        if (r == 0) return 0;
        int x = node[r].son[0] ? node[node[r].son[0]].siz : 0;
        if (data < node[r].data) {
            return ltcnt(data, node[r].son[0]);
        }
        if (!(node[r].data < data)) {  // equal keys
            return x;
        }
        return x + node[r].dup + ltcnt(data, node[r].son[1]);
    }
    // Number of keys (with duplicates) strictly greater than data in subtree r.
    int gtcnt(const TYPE& data, int r) {
        if (r == 0) return 0;
        int x = node[r].son[1] ? node[node[r].son[1]].siz : 0;
        if (node[r].data < data) {     // only operator< is required of TYPE
            return gtcnt(data, node[r].son[1]);
        }
        if (!(data < node[r].data)) {  // equal keys
            return x;
        }
        return x + node[r].dup + gtcnt(data, node[r].son[0]);
    }
    // Multiplicity of data in subtree r (0 if absent).
    int count(const TYPE& data, int r) {
        if (r == 0) return 0;
        if (data < node[r].data) return count(data, node[r].son[0]);
        if (node[r].data < data) return count(data, node[r].son[1]);
        return node[r].dup;
    }
    // Largest key strictly less than data in subtree r; ret reports whether
    // result holds a valid answer.
    void prev(const TYPE& data, int r, TYPE& result, bool& ret) {
        if (r) {
            if (node[r].data < data) {
                if (ret) {
                    result = max(result, node[r].data);
                } else {
                    result = node[r].data;
                    ret = true;
                }
                prev(data, node[r].son[1], result, ret);
            } else {
                prev(data, node[r].son[0], result, ret);
            }
        }
    }
    // Smallest key strictly greater than data in subtree r; ret reports
    // whether result holds a valid answer.
    void next(const TYPE& data, int r, TYPE& result, bool& ret) {
        if (r) {
            if (data < node[r].data) {
                if (ret) {
                    result = min(result, node[r].data);
                } else {
                    result = node[r].data;
                    ret = true;
                }
                next(data, node[r].son[0], result, ret);
            } else {
                next(data, node[r].son[1], result, ret);
            }
        }
    }
    vector<Node> node;  // node pool; node[0] is the null sentinel
    int root, total;    // index of the root; number of slots handed out
    bool multiple;      // whether duplicate keys are counted
    // Inserts data; returns false if the key was already present.
    bool insert(const TYPE& data) {
        // Grow the pool *before* recursing: the recursive insert binds
        // references to node[...].son[...], which a reallocation invalidates.
        if (total + 1 >= static_cast<int>(node.size()))
            node.resize(max<size_t>(node.size() * 2, static_cast<size_t>(total) + 2));
        bool ret = true;
        insert(data, root, ret);
        return ret;
    }
    // k-th smallest key (1-based, with duplicates); false if k is out of range.
    bool kth(int k, TYPE &data) {
        if (!root || k <= 0 || k > node[root].siz)
            return false;
        getkth(k, root, data);
        return true;
    }
    // Sum of the k smallest keys; asserts 1 <= k <= size().
    TYPE ksum(int k) {
        assert(root && k>0 && k<=node[root].siz);
        return getksum(k, root);
    }
    int count(const TYPE &data) {
        return count(data, root);
    }
    // Total number of stored keys, counting duplicates.
    int size() const {
        return root ? node[root].siz : 0;
    }
    void erase(const TYPE& data) {
        return erase(data, root);
    }
    int ltcnt(const TYPE& data) {
        return ltcnt(data, root);
    }
    int gtcnt(const TYPE& data) {
        return gtcnt(data, root);
    }
    // Number of keys <= data.
    int lecnt(const TYPE& data) {
        return size() - gtcnt(data, root);
    }
    // Number of keys >= data.
    int gecnt(const TYPE& data) {
        return size() - ltcnt(data, root);
    }
    bool prev(const TYPE& data, TYPE& result) {
        bool ret = false;
        prev(data, root, result, ret);
        return ret;
    }
    bool next(const TYPE& data, TYPE& result) {
        bool ret = false;
        next(data, root, result, ret);
        return ret;
    }
};

// Length of the longest prefix of t that is a subsequence of s.
// Walks s left to right and matches t's characters in order, taking each
// character of t at its earliest possible position in s.
int getpre(const string &s, const string &t) {
    const int need = static_cast<int>(t.size());
    int matched = 0;
    for (char c : s) {
        if (matched == need) break;  // the whole of t is already matched
        if (c == t[matched]) ++matched;
    }
    return matched;
}
// Length of the longest suffix of t that is a subsequence of s.
// Walks s right to left and matches t from its last character backwards,
// taking each character at its latest possible position in s.
int getsuf(const string &s, const string &t) {
    const int need = static_cast<int>(t.size());
    int matched = 0;
    for (auto it = s.rbegin(); it != s.rend() && matched < need; ++it) {
        // t[need - 1 - matched] is the next character of t still unmatched from the back.
        if (*it == t[need - 1 - matched]) ++matched;
    }
    return matched;
}

// Capacity for the input: n is at most 5e5, plus a little slack.
constexpr int N = 500005;
int n;          // number of input strings
string T;       // the target subsequence
string S[N];    // buffer for the input strings, indexed from 1
void solve() {
    cin >> n >> T;
    Treap<int> tp(N, true);
    rep(i,1,n) {
        cin >> S[i];
        int pre = getpre(S[i], T);
        tp.insert(pre);
    }
    int ans = 0;
    rep(i,1,n) {
        int suf = getsuf(S[i], T);
        ans += tp.gecnt(T.size()-suf);
    }
    cout << ans << "\n";
}

signed main() {
    // Fast iostreams: stop syncing with C stdio and untie cin from cout.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    // The input holds exactly one test case.
    solve();
    return 0;
}

相關文章