P8906 [USACO22DEC] Breakdown P [最短路]

chenwenmo發表於2024-11-08

P8906 [USACO22DEC] Breakdown P

Solution

  • 經典 trick,刪邊比較難處理,轉換成加邊,倒著處理。

  • 那我們接下來要考慮,怎麼記錄狀態,以及,每加一次邊要如何更新狀態。

  • 還是比較套路地,我們可以求出 \(1\) 到某個點 \(i\) 經過 \(k/2\) 條邊的最短路,再求出 \(i\)\(n\) 經過 \(k-k/2\) 條邊的最短路。根據這種思想,我們可能會思考折半搜尋,分治之類的演算法來解決這道題。

  • 如果我們要求 \(k/2\),就要求 \(k/4\),然後又要求 \(k/8\)。注意資料範圍 \(k\le 8\),這種資料範圍可能會成為解題的關鍵。那麼,\(k/2\le 4\)\(k/4\le 2\)\(k/8\le 1\)\(1\) 顯然就是邊界,就等於兩點間的邊權。如何求 \(2\) 呢?是不是隻需要合併兩條邊。求 \(4\) 只需要把 \(2\) 的答案合併。這樣一來,我們要記錄的資訊就非常少,不需要去寫分治之類的遞迴程式。

  • 至於怎麼合併。我們先考慮 \(1\) 到某一點 \(i\) 的情況,\(i\)\(n\) 同理。我們設 \(f_i\) 表示 \(1\)\(i\) 經過 \(k/2\) 條邊的最短路,設 \(g_{i,j}\) 表示 \(i\)\(j\) 只經過兩條邊的最短路。為啥 \(f\) 只記錄終點,\(g\) 要記錄起點和終點,因為 \(f\) 顯然只有終點才是有用的,\(g\) 的話要記錄起點和終點才能合併給 \(f\),不然咋合併?

  • 先大致口胡一遍演算法。

    • 如果 \(k/2=4\),那 \(g+g\to f\)
    • 如果 \(k/2=3\),那麼 \(g+w\to f\)\(w\) 是邊權。
    • 如果 \(k/2=2\),那麼 \(g\to f\)
    • 如果 \(k/2=1\),那麼 \(w\to f\)
  • 那我們接下來考慮如何更新狀態,且是在正確的複雜度內。

  • 先考慮如何更新 \(g\)。假設當前加的是邊 \((u\to v)\),那麼 \(i\to u\to v\) 這條路徑可能被更新 (\(i\) 是任意點),\(O(n)\) 列舉一個 \(i\),路徑 \(u\to v\to j\) 也可能被更新,再 \(O(n)\) 列舉一個 \(j\)。更新 \(g\) 的總複雜度 \(O(n)\)

  • 考慮如何更新 \(f\),這裡以 \(k/2=4\) 為例。

    • 假設 \(f_i\) 經過的是這樣一條邊 \(1\to u\to v\to j\to i\)

    • 如果當前加的邊是 \(j\to i\),那麼只需要 \(O(n)\) 列舉一個 \(v\),然後 \(g_{1,v}+g_{v,i}\to f_i\) 即可。

    • 如果當前加的邊是 \(v\to j\),只需要 \(O(n)\) 列舉 \(i\),然後 \(g_{1,v}+g_{v,i}\to f_i\)

    • 如果當前加的邊是 \(u\to v\)\(O(n)\) 列舉 \(i\),然後也是 \(g_{1,v}+g_{v,i}\to f_i\)。(欸,怎麼寫起來都是一樣的,不管了qwq)

    • 以上情況的總複雜度都是 \(O(n)\)

    • 那如果當前加入的是 \(1\to u\) 呢?我們肯定要 \(O(n^2)\) 列舉 \(v\)\(i\),但是注意到,起點為 \(1\) 的邊只有 \(O(n)\) 條,於是均攤複雜度只有 \(O(n^3)\),還是能過的。

  • 綜上,總複雜度 \(O(n^3)\),於是就可以愉快地 coding 了。

Code

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ull = unsigned long long;

const int N = 3e2 + 5, M = N * N, inf = 0x3f3f3f3f;

int n, m, k, w[N][N], ans[M];

void Min(int &a, int b) {return a = min(a, b), void();}

struct Edge {int u, v;} e[M];

struct Worker {
    int f[N]; // st(起點) 到 i 的經過 k 條邊的最短路
    int g[N][N]; // i 到 j 經過 2 條邊的最短路
    int w[N][N]; // 邊權
    int st, k;
    void init(int _st, int _k) {
        st = _st, k = _k;
        memset(f, 0x3f, sizeof(f));
        memset(g, 0x3f, sizeof(g));
        memset(w, 0x3f, sizeof(w));
    }
    void add(int u, int v, int _w) {
        w[u][v] = _w;
        if (k == 1) {
            if (u == st) f[v] = w[u][v];
            return;
        }
        for (int i = 1; i <= n; i++) {
            Min(g[i][v], w[i][u] + w[u][v]);
            Min(g[u][i], w[u][v] + w[v][i]);
        }
        if (k == 2) {
            for (int i = 1; i <= n; i++) f[i] = g[st][i];
            return;
        }
        if (k == 3) {
            if (u == st) {
                for (int i = 1; i <= n; i++) {
                    for (int j = 1; j <= n; j++) {
                        Min(f[i], g[u][j] + w[j][i]);
                    }
                }
            } else {
                for (int i = 1; i <= n; i++) {
                    Min(f[i], w[st][u] + g[u][i]);
                    Min(f[i], w[st][v] + g[v][i]);
                    // Min(f[u], w[st][i] + g[i][u]); 該情況不會被更新
                    Min(f[v], w[st][i] + g[i][v]);
                }
            }
        }
        if (k == 4) {
            if (u == st) {
                for (int i = 1; i <= n; i++) {
                    for (int j = 1; j <= n; j++) {
                        Min(f[i], g[u][j] + g[j][i]);
                    }
                }
            } else {
                for (int i = 1; i <= n; i++) {
                    Min(f[i], g[st][u] + g[u][i]);
                    Min(f[i], g[st][v] + g[v][i]);
                    // Min(f[u], g[st][i] + g[i][u]); 該情況不會被更新
                    Min(f[v], g[st][i] + g[i][v]);
                }
            }
        }
    }
}work1, workn;

int main() {
    cin >> n >> k;
    m = n * n;
    for (int i = 1; i <= n; i++) for (int j = 1; j <= n; j++) cin >> w[i][j];
    for (int i = 1; i <= m; i++) cin >> e[i].u >> e[i].v;
    int k1 = k / 2, kn = k - k1;
    work1.init(1, k1), workn.init(n, kn);
    for (int i = m; i >= 1; i--) {
        ans[i] = inf;
        for (int j = 1; j <= n; j++) ans[i] = min(ans[i], work1.f[j] + workn.f[j]);
        work1.add(e[i].u, e[i].v, w[e[i].u][e[i].v]);
        workn.add(e[i].v, e[i].u, w[e[i].u][e[i].v]);
    }
    for (int i = 1; i <= m; i++) cout << (ans[i] == inf ? -1 : ans[i]) << "\n";
    return 0;
}```

相關文章