藍橋杯第 3 場 演算法季度賽第八題 升級電纜題解

Athanasy發表於2024-06-30

題目連結:升級電纜

貌似大部分人一開始想偏了,想些多 \(\log\) 的做法。大思路很簡單,常見的最大化最小值,那麼就是考慮二分最小值,然後透過限制進行 \(check\)

顯然 \(<mid\) 的所有速度需要增大,增大會使用開銷 \(c\),考慮 \(c\) 之和不超過 \(limit\),除此之外注意到可能新的 \(v\) 還是 \(< mid\),所以需要注意上界 \(r\) 其實應該是最小的 \(s\),變化後的速度,這樣保證每個比它小的 \(v\) 變化以後一定比它大。

先講講比較無腦的一些做法,考慮鏈資訊使用樹剖維護,那麼我們可以任意查詢一條鏈上的資訊,這個查詢的複雜度是雙 \(\log\) 的,然後再支援二分雜七雜八的,那麼這題可能有一些比較假的 \(3\log\) 做法,可能有些人可以卡過去吧。當然可以用 \(GBT\) 最佳化一個 \(\log\)

觀察到答案顯然是跟二分性有關,那麼常見的這種樹上的思路,並且是多次查詢的有兩個方向:

  1. 二分 + \(\text{樹類 ds 查詢資訊 進行 check}\)

  2. 在樹上進行二分答案。

後者比較苛刻,對於樹類問題,常常需要可能某個軸需要維護的是可差分的資訊。

說說前者,比較暴力一點,觀察到這個題涉及到的一些資訊需要維護:

  1. 關於 \(v,s\) 所在的值域限制軸。

  2. 關於 \(c\) 開銷總和的資訊統計。

  3. 關於樹上鍊 \((u,v)\) 的限制。

那麼我們發現,有兩對偏序限制 \((1,3)\)。那麼我們至少需要維護一個類似樹套樹的結構才能維護關於第二點 \(c\) 的開銷總和資訊。

考慮特殊的樹套樹,單 \(\log\) 的主席樹。

基於第一種做法,顯然沒啥好說的,既然是使用主席樹維護資訊,那麼觀察到 \(1\) 的偏序關係即為:\(\le v/s\),即為字首偏序,而 \(3\) 的偏序則是樹上的範圍性查詢。顯然 \(1\) 這點可以直接使用主席樹本身就是維護了一個外層的字首偏序資訊,即 \(root[i]\) 表示 \(val \le i\) 的主席樹,內層基於邊化點以後再維護一個基於 剖分序 的偏序資訊即可透過樹上字首和進行查詢。

碼量上應該還是不算小的,還需要寫一個邊化點的樹鏈剖分,不會的可以做做 \(QTree\),還需要離散化下。這個做法顯然是雙 \(\log\) 的。當然了,考慮下離線做法。對於離線帶多次二分而言的題,顯然整體二分最為合適,思考下二分答案以後,將所有小於它的邊進行啟用,然後進行判斷該二分到哪一側。重點是維護邊啟用操作,其實邊化點以後就是若干個單修操作,用樹狀陣列或者線段樹都可以。不過這玩意如果是剖分序顯然三支 \(\log\) 了,所以最好用 \(GBT\) 最佳化到兩支 \(\log\),主要還是鏈路徑資訊查詢部分,單修復雜度很好控制。

考慮下第二種做法:

需要支援樹上二分答案,那麼資訊要求是比較苛刻的,外軸顯然是需要一個用於可差分的資訊偏序限制,而維護的資訊則為可差分資訊,內軸即為真正的二分答案軸。

這其實是一個基礎板子,不會的可以去學學樹上主席樹:Count on a tree

那麼問題就很簡單了,對外軸維護關於 \((u,v)\) 的偏序,即 \(root[u]\) 即為 \(u \rightarrow 1\) 這條路徑上的所有點的累計資訊,放在了對應的主席樹上。

而二分的軸即為內軸,設定資訊的軸,即為 \(v/s\) 的值域軸。

而維護的資訊必須為可差分資訊,那麼顯然為前兩者作用下的 \(sum_c\),關於這個限制條件下的操作累計和。

這樣一來就可以進行正確的主席樹上的二分了,然後有些人想要找二分的上界 \(s_{min}\) 還寫了一些樹剖、倍增之類的查詢鏈上最小值,這是完全沒必要的。

考慮下:當存在 \(j<i,s_j <s_i\),那麼此時此刻顯然 \(s_i\) 不合法,那麼只需要讓 \(s_i\) 左側不合法就行:

當左側出現 \(s\) 時,顯然一定 check 失敗。這太容易了,設定 \(s_j\) 處為無窮大,這樣一來開銷一定超過 \(limit\),一定不成立,一定往左找,這樣一來如果存在最小的 \(s\) 是二分答案,那麼我們也可以恰好二分到它了。注意 \(limit \le 1e18\),那麼這個無窮大最好使用 \(int128\) 來進行儲存操作和,至於 \(lca\) 隨便倍增求求就行了。單 \(\log\) 常數很小,如果是巢狀 \(\log\),最好離散化下,否則常數是很大的。

參照程式碼
#include <bits/stdc++.h>

// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")

#define isPbdsFile

#ifdef isPbdsFile

#include <bits/extc++.h>

#else

#include <ext/pb_ds/priority_queue.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
#include <ext/pb_ds/tag_and_trait.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/list_update_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/exception.hpp>
#include <ext/rope>

#endif

using namespace std;
using namespace __gnu_cxx;
using namespace __gnu_pbds;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tii;
typedef tuple<ll, ll, ll> tll;
typedef unsigned int ui;
typedef unsigned long long ull;
#define hash1 unordered_map
#define hash2 gp_hash_table
#define hash3 cc_hash_table
#define stdHeap std::priority_queue
#define pbdsHeap __gnu_pbds::priority_queue
#define sortArr(a, n) sort(a+1,a+n+1)
#define all(v) v.begin(),v.end()
#define yes cout<<"YES"
#define no cout<<"NO"
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("..\\input.txt", "r", stdin),freopen("..\\output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)
#define forv(i, a, b) for(int i=a;i>=b;i--)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define endl '\n'
//用於Miller-Rabin
[[maybe_unused]] static int Prime_Number[13] = {0, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

template <typename T>
int disc(T* a, int n)
{
    return unique(a + 1, a + n + 1) - (a + 1);
}

template <typename T>
T lowBit(T x)
{
    return x & -x;
}

template <typename T>
T Rand(T l, T r)
{
    static mt19937 Rand(time(nullptr));
    uniform_int_distribution<T> dis(l, r);
    return dis(Rand);
}

template <typename T1, typename T2>
T1 modt(T1 a, T2 b)
{
    return (a % b + b) % b;
}

template <typename T1, typename T2, typename T3>
T1 qPow(T1 a, T2 b, T3 c)
{
    a %= c;
    T1 ans = 1;
    for (; b; b >>= 1, (a *= a) %= c) if (b & 1) (ans *= a) %= c;
    return modt(ans, c);
}

template <typename T>
void read(T& x)
{
    x = 0;
    T sign = 1;
    char ch = getchar();
    while (!isdigit(ch))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (isdigit(ch))
    {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    x *= sign;
}

template <typename T, typename... U>
void read(T& x, U&... y)
{
    read(x);
    read(y...);
}

template <typename T>
void write(T x)
{
    if (typeid(x) == typeid(char)) return;
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 ^ 48);
}

template <typename C, typename T, typename... U>
void write(C c, T x, U... y)
{
    write(x), putchar(c);
    write(c, y...);
}


template <typename T11, typename T22, typename T33>
struct T3
{
    T11 one;
    T22 tow;
    T33 three;

    bool operator<(const T3 other) const
    {
        if (one == other.one)
        {
            if (tow == other.tow) return three < other.three;
            return tow < other.tow;
        }
        return one < other.one;
    }

    T3()
    {
        one = tow = three = 0;
    }

    T3(T11 one, T22 tow, T33 three) : one(one), tow(tow), three(three)
    {
    }
};

template <typename T1, typename T2>
void uMax(T1& x, T2 y)
{
    if (x < y) x = y;
}

template <typename T1, typename T2>
void uMin(T1& x, T2 y)
{
    if (x > y) x = y;
}

constexpr int N = 1e5 + 10;
constexpr int MX = 1e9;
constexpr ll INF = 1e18;
constexpr int T = log2(N) + 1;
typedef __int128 i128;

struct Node
{
    int left, right;
    i128 sum;
} node[N << 6];

#define left(x) node[x].left
#define right(x) node[x].right
#define sum(x) node[x].sum
int cnt;
int fa[N][T + 1], deep[N];

inline void add(const int pre, int& curr, const int pos, const ll val, const int l = 1, const int r = MX)
{
    node[curr = ++cnt] = node[pre];
    sum(curr) += val;
    if (l == r) return;
    const int mid = l + r >> 1;
    if (pos <= mid) add(left(pre),left(curr), pos, val, l, mid);
    else add(right(pre),right(curr), pos, val, mid + 1, r);
}

inline int query(const int rtU, const int rtV, const int rtLCA, const ll sumV, const int l = 1, const int r = MX)
{
    if (l == r) return l;
    const int mid = l + r >> 1;
    const i128 leftSum = sum(left(rtU)) + sum(left(rtV)) - 2 * sum(left(rtLCA));
    if (leftSum > sumV) return query(left(rtU),left(rtV),left(rtLCA), sumV, l, mid);
    return query(right(rtU),right(rtV),right(rtLCA), sumV - leftSum, mid + 1, r);
}

typedef tuple<int, int, int, int> t4;
vector<t4> child[N];
int n, q;
int root[N];

inline void dfs(const int curr, const int pa)
{
    deep[curr] = deep[fa[curr][0] = pa] + 1;
    forn(i, 1, T) fa[curr][i] = fa[fa[curr][i - 1]][i - 1];
    for (const auto [nxt,v,c,s] : child[curr])
    {
        if (nxt == pa) continue;
        root[nxt] = root[curr];
        add(root[nxt], root[nxt], v, c);
        add(root[nxt], root[nxt], s, INF);
        dfs(nxt, curr);
    }
}

inline int LCA(int x, int y)
{
    if (deep[x] < deep[y]) swap(x, y);
    forv(i, T, 0) if (deep[fa[x][i]] >= deep[y]) x = fa[x][i];
    if (x == y) return x;
    forv(i, T, 0) if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

inline void solve()
{
    cin >> n;
    forn(i, 1, n-1)
    {
        int x, y, v, c, s;
        cin >> x >> y >> v >> c >> s;
        child[x].emplace_back(y, v, c, s);
        child[y].emplace_back(x, v, c, s);
    }
    dfs(1, 0);
    cin >> q;
    while (q--)
    {
        int u, v;
        ll sum;
        cin >> u >> v >> sum;
        const int lca = LCA(u, v);
        cout << query(root[u], root[v], root[lca], sum) << endl;
    }
}

signed int main()
{
    // MyFile
    Spider
    //------------------------------------------------------
    // clock_t start = clock();
    int test = 1;
    //    read(test);
    // cin >> test;
    forn(i, 1, test) solve();
    //    while (cin >> n, n)solve();
    //    while (cin >> test)solve();
    // clock_t end = clock();
    // cerr << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl;
}

\[時間複雜度為:預處理\ O(n(\log{n}+\log{V})),查詢\ O(q\log{V}) \]

相關文章