P4556 [Vani有約會] 雨天的尾巴 /【模板】線段樹合併
在這題裡面講一下線段樹合併。顧名思義就是把多個線段樹合併成一個。
顯然完全二叉線段樹(也就是普通線段樹)是無法更高效的合併的,只能把所有節點加起來建個新樹。但是在動態開點線段樹中,有時候一個樹只有幾條鏈,這時候我們就是可以使用線段樹合併了。
其核心就是隻遍歷兩棵樹重疊的部分實現合併。具體的,對於兩棵線段樹 \(p1\),\(p2\),我們要把 \(p2\) 合併到 \(p1\) 上。不妨先考慮左兒子,假如兩個線段樹都有左兒子,那麼就往下遍歷;假如只有 \(p1\) 有左兒子,顯然不需要修改;假如只有 \(p2\) 有左兒子,\(p1\) 沒有記錄資訊,那麼只需要把 \(p1\) 左兒子的指標改成 \(p2\) 左兒子的指標即可。複雜度不會嚴格證明,一般所有線段樹的總節點個數為 \(O(n\log n)\),合併時遍歷到的只有重疊部分,非重疊部分沒有遍歷,所以每個線段樹的節點至多被遍歷一次,所以時間複雜度是 \(O(n\log n)\)。
void mg(int p1, int p2, int l, int r) {
if(l == r) {
.....
return;
}
int mid = (l + r) >> 1;
if(t[p1].ls && t[p2].ls) mg(t[p1].ls, t[p2].ls, l, mid);
else if(t[p2].ls) t[p1].ls = t[p2].ls;
if(t[p1].rs && t[p2].rs) mg(t[p1].rs, t[p2].rs, mid + 1, r);
else if(t[p2].rs) t[p1].rs = t[p2].rs;
pushup(p1, t[p1].ls, t[p1].rs);
}
回到這題,容易想到樹上差分,然後把操作拆成 \(4\) 個單點修改,離線在樹上線段樹合併即可。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define pb push_back
typedef long long i64;
const int N = 1e5 + 10;
int n, m, cnt;
int dep[100010], anc[100010][21], h[100010];
struct node{
int to, nxt;
} e[200010];
void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = h[u];
h[u] = cnt;
}
void dfs(int u, int fa) {
for(int i = h[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
anc[v][0] = u;
dep[v] = dep[u] + 1;
dfs(v, u);
}
}
void init() {
for(int j = 1; j <= 19; j++) {
for(int i = 1; i <= n; i++) {
anc[i][j] = anc[anc[i][j - 1]][j - 1];
}
}
}
int lca(int u, int v) {
if(dep[u] < dep[v]) std::swap(u, v);
for(int i = 19; i >= 0; i--) if(dep[anc[u][i]] >= dep[v]) u = anc[u][i];
if(u == v) return u;
for(int i = 19; i >= 0; i--) if(anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
return anc[u][0];
}
int tot;
std::vector<pii> ve[N];
struct seg {
int ls, rs, mx, cnt;
} t[40 * N];
void pushup(int u, int ls, int rs) {
if(t[ls].cnt >= t[rs].cnt && t[ls].cnt) {
t[u].cnt = t[ls].cnt;
t[u].mx = t[ls].mx;
} else if(t[ls].cnt <= t[rs].cnt && t[rs].cnt) {
t[u].cnt = t[rs].cnt;
t[u].mx = t[rs].mx;
} else {
t[u].cnt = t[u].mx = 0;
}
return;
}
void ins(int &u, int l, int r, int x, int y) {
if(!u) u = ++tot;
if(l == r) {
t[u].cnt += y;
t[u].mx = t[u].cnt ? l : 0;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) ins(t[u].ls, l, mid, x, y);
else ins(t[u].rs, mid + 1, r, x, y);
pushup(u, t[u].ls, t[u].rs);
}
void mg(int p1, int p2, int l, int r) {
if(l == r) {
t[p1].cnt += t[p2].cnt;
t[p1].mx = t[p1].cnt ? l : 0;
return;
}
int mid = (l + r) >> 1;
if(t[p1].ls && t[p2].ls) mg(t[p1].ls, t[p2].ls, l, mid);
else if(t[p2].ls) t[p1].ls = t[p2].ls;
if(t[p1].rs && t[p2].rs) mg(t[p1].rs, t[p2].rs, mid + 1, r);
else if(t[p2].rs) t[p1].rs = t[p2].rs;
pushup(p1, t[p1].ls, t[p1].rs);
}
int ans[N];
void dfs2(int u, int fa) {
for(int i = h[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs2(v, u);
mg(u, v, 1, 1e5);
}
for(auto x : ve[u]) ins(u, 1, 1e5, x.fi, x.se);
ans[u] = t[u].mx;
}
void Solve() {
std::cin >> n >> m;
tot = n;
for(int i = 1; i < n; i++) {
int u, v;
std::cin >> u >> v;
add(u, v), add(v, u);
}
dep[1] = 1;
dfs(1, 0);
init();
while(m--) {
int u, v, w;
std::cin >> u >> v >> w;
int rt = lca(u, v);
ve[u].push_back({w, 1}), ve[v].push_back({w, 1});
ve[rt].push_back({w, -1}), ve[anc[rt][0]].push_back({w, -1});
}
dfs2(1, 0);
for(int i = 1; i <= n; i++) std::cout << ans[i] << "\n";
}
int main() {
Solve();
return 0;
}