【小 w 的代數】(提供一種 n^2 log 的解法)

Aqr_Rn發表於2024-10-21

前言:

updated on 10.22 發現 \(n\log n\) 很好實現,做法看懂部落格後看評論區 14 樓

不卡常,目前 accoders 和 h.hszxoj 的最優解

賣點

記錄 CTH 的發言

CTH:你這真是 n^3 的
CTH:我也不知道你線段樹最佳化個啥,\(n^3 \log n\)
CTH:你最佳化到哪了啊
CTH:······你從賽時打這個題到現在 11 個小時了,你從 \(n^3\) 打到 \(n^3\log n\)
CTH:······再怎麼著,我也不會一道題調三天
CTH:我一直都說這麼打這麼打,你打的是啥呀
CTH:你連題面都沒弄懂唄

@CTHoi 黑子!叫!!!

不過還是要感謝 CTH 對這道題實現的大力支援與幫助,不然可能現在還調不出來呢

解法:

一句話以代之:微改線段樹合併最佳化圓方樹上 DP

考慮 \(O(n^3)\) 的“暴力”做法:

對於每一個節點為根跑一遍 dfs,求出以該點為終點的路徑有多少方案數,每次 dfs 進行以下樹上 dp。

我們在回溯的時候用子節點更新父親,那麼就相當於我們求從葉子到根的遞減路徑,這樣才可以最佳化。

\(f_{i,j}\) 表示以 \(i\) 號節點為根的子樹中以 \(j\) 為終點的答案,有轉移如下:

\[f_{u,i}=\sum_{u->v} \sum_{j=i+1}^n f_{v,j} \]

這樣樹的部分就做完了,但有環的地方怎麼辦?如下圖:

1 號點是 dfs 過程中第一次到達的點,我們把環中的 dfs 第一次到達的點稱為入點,此時我們直接讓 1 號點所在環內其他的點跑 dfs,也就是 2、3 點,假設現在它們都跑出來了以自己根的子樹內的 \(f\) 陣列。

實際情況其實環內除入點之外的每個點為根的子樹都分別可能順時針和逆時針各走一遍走到入點。

所以我們列舉每個點作為起點分別向右和向左各走一次,每一次維護出一個陣列 \(A_i\) 表示當前情況下以 \(i\) 為結尾的路徑方案數。

每次走到一個新點 \(x\),轉移如下:\(A' = A+f_x+\sum_{i=x+1}^n A_i\)

這樣時間複雜度不對。如上圖,最後得到的 \(A\) 陣列如下:

\[\begin{aligned} A =& f_2+\sum_{i=3}^n f_{2,i} (以 2 為起點向右走) \\ &+ f_2 (以 2 為起點向左走) \\ &+ f_3+\sum_{i=2}^n f_{2,j} (以 3 為起點向左走) \\ &+ f_3 (以 3 為起點向右走) \\ =& f_2+f_3+\sum_{i=3}^n f_{2,i} (把向左走的合到一起) \\ &+ f_2+f_3+\sum_{i=2}^n f_{2,j} (把向右走的合到一起) \end{aligned}\]

容易發現對入點的貢獻就是順時針走一遍逆時針走一遍,減去(重複的)每顆子樹原本的貢獻

那麼我們看做把 1、2 的連邊斷開,從 2 -> 3 -> 1 的路徑進行一次上述轉移;

反過來,從 3 -> 2 -> 1 的路徑再進行一次轉移。

所以總體思路就是,每遇到一個環把環上其他點的子樹先跑出來,然後順時針逆時針各跑一遍維護環中的子樹對入點的貢獻。

發現很簡單,\(n^3\) 就做完了,但 lxyt 說的好:“\(500^3\) 很難過啊”,Ratio 也說得好:“\(500^3\) 除非你常數極小”。

考慮最佳化,我們稱以 \(i\) 為終點的遞減路徑的方案數為 \(i\) 的方案數。

每次更新 \(u\) 點的方案時,會計算 \(u\) 點所有的子節點 \(v\) 的子樹內 所有大於 \(u\) 的點的方案數的總和,發現其實這就是簡單區間求和,可以線段樹維護。

  • 而是先整體考慮樹的轉移:

    對於每一個葉子點開一棵線段樹(動態開點),那麼我們每次從一個 \(v\) 點回溯到父節點 \(u\) 並更新它時,直接區間求和求出 \(v\) 子樹內對 \(u\) 的貢獻,將這一貢獻單點更新到 \(u\) 的線段樹上,併合並 \(v\) 的線段樹到 \(u\) 上。

  • 環的轉移:

    按暴力思路順時針走一遍,把走過的點的線段樹以及新產生的貢獻合併到一個新線段樹上,最後把對入點的貢獻加到入點的線段樹上並把整顆線段樹合併上去;

    逆時針的時候,為了避免重複,我們照樣開上述一顆線段樹 \(x\),再額外開一顆線段樹只把 \(x\) 這顆線段樹上在轉移過程中新產生的貢獻加上,而不加每顆子樹原本就有的線段樹的部分。

單點更新和區間查詢單次 \(log\),每次 dfs 共 \(O(n\log n)\),整體複雜度為 \(O(n^2 \log n)\)

code:

看懂思路的話,程式碼也會很好寫啦。只有短短 6 k 啦,啦啦啦 6 k 啊 6 k,它比 5 k 多 1 k

#include<bits/stdc++.h>
#define int long long
#define lson ls[rt]
#define rson rs[rt]
#define Aqrfre(x, y) freopen(#x ".in", "r", stdin),freopen(#y ".out", "w", stdout)
#define mp make_pair
#define Type int
#define qr(x) x=read()
typedef __int128 INT;
typedef long long ll;
using namespace std;

inline Type read(){
    char c=getchar(); Type x=0, f=1;
    while(!isdigit(c)) (c=='-'?f=-1:f=1), c=getchar();
    while(isdigit(c)) x=(x<<1)+(x<<3)+(c^48), c=getchar();
    return x*f;
}

const int N = 505; 
const int M = 5e6;
const int mod = 998244353;

int n, m, tot, dis[N][N]; ll res;
vector<int>to[N], belong[N]; bool beA[N][N];

int top, th, s[N], bcc, low[N], dfn[N];
vector<int>BCC[N]; int sz[N];
inline void tarjan(int x, int p){
    s[++top] = x;
    low[x] = dfn[x] = ++th;
    for(int y : to[x]){
        if(!dfn[y]){
            tarjan(y, x);
            low[x] = min(low[x], low[y]);
            if(low[y] == dfn[x]){
                ++bcc;
                do{ BCC[bcc].emplace_back(s[top]); sz[bcc]++;
                    belong[s[top]].emplace_back(bcc);
                    beA[s[top]][bcc] = true;
                }while(s[top--] != y);

                BCC[bcc].emplace_back(x); sz[bcc]++;
                belong[x].emplace_back(bcc);
                beA[x][bcc] = true;
            }
        }
        else low[x] = min(low[x], dfn[y]);
    }
}

vector<int>tw[N][N];
ll v[M]; int rtot, ls[M], rs[M], root[M];
inline void pushup(int rt){
    v[rt] = (ll)(v[lson] + v[rson]) % mod;
} 

inline void update(int &rt, int l, int r, int pos, int val){
    if(!rt) rt = ++rtot;
    if(l == r){
        (v[rt] += val) %= mod;
        return;
    }

    int mid = (l + r) >> 1;
    if(pos <= mid) update(lson, l, mid, pos, val);
    else update(rson, mid+1, r, pos, val);

    pushup(rt);
}

inline int merge(int x, int y, int l, int r){
    if(!x or !y) return x + y;
    if(l == r){
        (v[x] += v[y]) %= mod;
        return x;
    }
    int mid = (l + r) >> 1;
    ls[x] = merge(ls[x], ls[y], l, mid);
    rs[x] = merge(rs[x], rs[y], mid+1, r);

    pushup(x); return x;
}

inline void mcpy(int &x, int y, int l, int r){
    if(!y) return;
    x = ++rtot;
    v[x] = v[y];
    if(l == r) return;
    
    int mid = (l + r) >> 1;
    mcpy(ls[x], ls[y], l, mid);
    mcpy(rs[x], rs[y], mid+1, r);
}

inline void mergeAdd(int &x, int y, int l, int r){
    if(!y) return;
    if(!x) x = (++rtot);
    if(l == r){
        (v[x] += v[y]) %= mod;
        return;
    }
    int mid = (l + r) >> 1;
    mergeAdd(ls[x], ls[y], l, mid);
    mergeAdd(rs[x], rs[y], mid+1, r);

    pushup(x); return;
}

inline int query(int rt, int l, int r, int pos){
    if(!rt) return 0;
    if(l >= pos){
        return v[rt] % mod;
    }
    int mid = (l + r) >> 1, res = 0;
    if(mid >= pos) res = query(lson, l, mid, pos);
    (res += query(rson, mid+1, r, pos)) %= mod;
    
    return res;
}

inline void watch(int rt, int l, int r){       
    int mid = l + r >> 1;
    if(ls[rt]) watch(lson, l, mid);
    if(rs[rt]) watch(rson, mid+1, r);
    if(rt) cout<<l<<" "<<r<<' '<<v[rt]<<'\n';
}

inline int qpos(int rt, int l, int r, int pos){
    if(l == r) return v[rt] % mod;
    int mid = (l + r) >> 1;
    if(mid >= pos) return qpos(lson, l, mid, pos);
    else return qpos(rson, mid+1, r, pos);
}

int ned, tem;
inline void dp(int x, int p, int goal, int whi, int op){
    if(x == goal) return;
    int num = 0;
    for(int y : tw[whi][x]){
        if(y == p) continue;

        num = query(root[ned], 1, n, y);

        mergeAdd(root[ned], root[y], 1, n);
        update(root[ned], 1, n, y, num);

        if(op == 1){
            update(root[tem], 1, n, y, num);
        }

        if(y == goal) break;
        dp(y, x, goal, whi, op);
    }
}

bool vis[N], flag[N];
inline void dfs(int x, int p, int bel){
    update(root[x], 1, n, x, 1);
    for(int whi : belong[x]){
        if(whi == bel) continue;
        if(flag[whi]) continue;
        flag[whi] = true;

        if(sz[whi] == 2){
            int num = 0;
            for(int y : tw[whi][x]){
                if(y == p or vis[y]) continue;
                vis[y] = true;
                dfs(y, x, whi);
                root[x] = merge(root[x], root[y], 1, n);
                num += query(root[y], 1, n, x);
            }
            update(root[x], 1, n, x, num);

            continue;
        }

        for(int i : BCC[whi]){
            if(x == i) continue;
            vis[i] = true;
            dfs(i, 0, whi);
        }
        
        int a = 0, b = 0;
        for(int i : tw[whi][x]){
            if(a) b = i;
            else a = i;
        }

        ned++; mcpy(root[ned], root[a], 1, n);
        dp(a, x, b, whi, 0);
        mergeAdd(root[x], root[ned], 1, n);
        update(root[x], 1, n, x, query(root[ned], 1, n, x));

        ned++; mcpy(root[ned], root[b], 1, n); tem = ned + 1;
        dp(b, x, a, whi, 1);
        mergeAdd(root[x], root[tem], 1, n);
        update(root[x], 1, n, x, query(root[tem], 1, n, x));
    }
}

inline void clean(){
    ned = max(ned, tem);
    fill(root+1, root+1+ned, 0);
    fill(flag, flag+1+n, 0);
    fill(vis, vis+1+n, 0);
    fill(ls, ls+1+rtot, 0);
    fill(rs, rs+1+rtot, 0);
    fill(v, v+1+rtot, 0);
    rtot = 0; ned = n;
}

signed main(){ //algebra
    Aqrfre(algebra, algebra);

    qr(n), qr(m); ned = n;
    for(int i=1; i<=m; i++){
        int qr(x), qr(y);
        dis[x][y] = dis[y][x] = 1;
        to[x].emplace_back(y);
        to[y].emplace_back(x);
    }
    for(int i=1; i<=n; i++)
        if(!dfn[i]) tarjan(i, 0);

    for(int i=1; i<=bcc; i++)
        for(int x : BCC[i])
            for(int y : BCC[i]){
                if(x == y or !dis[x][y]) continue;
                tw[i][x].emplace_back(y);
            }


    int la = 0;
    for(int i=1; i<=n; i++){
        clean(); dfs(i, 0, 0);
        (res += qpos(root[i], 1, n, i)) %= mod;
    }

    cout<<res<<"\n";


    return 0;
}

相關文章