點分治
點分治是一個求樹上路徑問題的演算法,演算法流程通常是:找到子樹中的重心,計算重心的子樹的每一個點與重心的路徑的資料,接著統計整體答案。
Close Vertices
思路
很明顯,這是一道點分治題目,但有兩個限制條件(路徑邊數的上限與權值和的上限,程式中分別記為 \(k\) 與 \(d\))。考慮將鏈按權值排序,用雙指標處理權值和的限制,用樹狀陣列(以邊數為下標)維護邊數的限制。但是同一個子樹內的兩條鏈並不能經過根拼成一條路徑,所以要將答案減去每個子樹內部統計出的答案。
程式碼
#include<iostream>
#include<algorithm>
#define int long long
using namespace std;
// Fast reader for a decimal integer on stdin. Non-digit characters before the
// number are skipped; any '-' seen while skipping makes the result negative.
// (`register` was removed: that storage class is ill-formed since C++17.)
inline int read(){
    int x = 0, f = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c ^ 48);   // x * 10 + digit
        c = getchar();
    }
    return x * f;
}
// Writes x in decimal to stdout, with no trailing separator.
inline void write(int x){
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
const int N = 4e5 + 10;
// n: vertex count; k: largest allowed edge count of a path; d: largest allowed
// total weight of a path (see how calc compares against them).
int n, k, d, maxn, rt, ans;
// Fenwick tree indexed by (edge count + 1), so edge count 0 maps to position 1.
// Only positions 1 .. MAXP are ever written.
const int MAXP = 1e5 + 1;
int t[N];
int lowbit(int x){return x & (-x);}
// Adds y at position x; positions below 1 are treated as 1.
void modify(int x, int y){
    for (int i = (x < 1 ? 1 : x); i <= MAXP; i += lowbit(i)) t[i] += y;
}
// Prefix sum over positions 1 .. x. A non-positive x is an empty prefix and
// yields 0. A larger x is capped at MAXP because nothing is stored beyond it,
// and reading past it would skip part of the prefix.
int query(int x){
    if (x <= 0) return 0;
    int res = 0;
    for (int i = (x > MAXP ? MAXP : x); i; i -= lowbit(i)) res += t[i];
    return res;
}
// Adjacency list: the edges leaving u are head[u], e[head[u]].nxt, ...
// and an index of 0 ends the list.
struct edge{
    int v, w, nxt;
}e[N << 1];
int head[N], cnt;
// Prepends the directed edge u -> v with weight w to u's list.
void add(int u, int v, int w){
    ++cnt;
    e[cnt].v = v;
    e[cnt].w = w;
    e[cnt].nxt = head[u];
    head[u] = cnt;
}
// A chain starting at the current centroid: total weight dis and edge count l.
// Ordered by weight, with ties broken by edge count.
struct node{
    int dis, l;
    bool operator < (const node &b) const{
        return dis == b.dis ? l < b.l : dis < b.dis;
    }
}stk1[N], stk2[N];
int mx[N], sz[N], top1, top2, dis[N], l[N];
bool vis[N];
// Centroid search over the current component (vertices with vis == 0).
// Fills sz[] (subtree sizes, rooted at the vertex the search starts from) and
// mx[u] = size of the largest piece left if u were removed. maxn is used as the
// component size (solve passes a size from an earlier search, so it may be
// approximate). rt is replaced only by a strictly smaller mx, so callers reset
// rt to 0 first; mx[0] is set to n in main and no vertex overwrites it.
void find(int u, int fa){
sz[u] = 1, mx[u] = 0;
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (v == fa || vis[v]) continue;
find(v, u);
sz[u] += sz[v];
mx[u] = max(mx[u], sz[v]);
}
// the rest of the component "above" u is also a piece once u is removed
mx[u] = max(mx[u], maxn - sz[u]);
if (mx[u] < mx[rt]) rt = u;
}
// Pushes (dis, l) of every vertex reachable from u inside the current
// component onto stk2. The caller sets dis[u] and l[u]; each child gets the
// accumulated path weight and edge count.
void dfs(int u, int fa){
stk2[++top2] = (node){dis[u], l[u]};
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (v == fa || vis[v]) continue;
dis[v] = dis[u] + e[i].w;
l[v] = l[u] + 1;
dfs(v, u);
}
}
void calc(int u){
top1 = 0;
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (vis[v]) continue;
top2 = 0, dis[v] = e[i].w, l[v] = 1;
dfs(v, u);
sort(stk2 + 1, stk2 + top2 + 1);
for (int j = 1; j <= top2; j++) stk1[++top1] = stk2[j], modify(stk2[j].l + 1, 1);
int l = 1, r = top2;
while (l <= top2){
modify(stk2[l].l + 1, -1);
while (l < r && stk2[l].dis + stk2[r].dis > d) modify(stk2[r--].l + 1, -1);
if (l >= r) break;
ans -= query(k - stk2[l].l + 1);
l++;
}
}
stk1[++top1] = (node){0, 0};
sort(stk1 + 1, stk1 + top1 + 1);
for (int i = 1; i <= top1; i++) modify(stk1[i].l + 1, 1);
int l = 1, r = top1;
while (l <= top1){
modify(stk1[l].l + 1, -1);
while (l < r && stk1[l].dis + stk1[r].dis > d) modify(stk1[r--].l + 1, -1);
if (l >= r) break;
ans += query(k - stk1[l].l + 1);
l++;
}
}
// Divide and conquer: mark centroid u as removed, count the pairs whose path
// passes through u, then recurse into the centroid of every remaining piece.
void solve(int u){
vis[u] = 1;
calc(u);
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (vis[v]) continue;
// sz[v] comes from an earlier centroid search and serves as the piece size
maxn = sz[v];
rt = 0;
find(v, 0);
solve(rt);
}
}
// Input: n, k, d, then for i = 1 .. n-1 a vertex v and a weight w forming the
// edge (i+1, v). Output: ans, the number of unordered vertex pairs whose path
// has at most k edges and total weight at most d.
signed main(){
n = read(), k = read(), d = read();
for (int i = 1; i < n; i++){
int v = read(), w = read();
add(i + 1, v, w), add(v, i + 1, w);
}
// rt is 0 here; mx[0] = n acts as the "no candidate yet" value in find
maxn = mx[rt] = n;
find(1, 0);
solve(rt);
cout << ans;
return 0;
}
Luogu 的 CF Remotejudge 沒修好,氣死我也。
P5351 Ruri Loves Maschera
思路
首先,每條從根出發的鏈上的邊權最大值很容易算出來。考慮如何統計答案:將鏈按最大值從小到大排序,對於一條最大值為 \(k\) 的鏈,用樹狀陣列(以邊數為下標)查詢排在它前面、且拼接後邊數落在 \([L,R]\) 內的鏈的個數,這些拼出來的路徑的權值都為 \(k\)。因為每一對鏈都會被其中最大值較大(排在後面)的那條統計,所以只需統計排在前面的鏈然後匹配即可。注意同一子樹內的鏈不能經過根拼接,要用容斥減去。
程式碼
#include<iostream>
#include<algorithm>
using namespace std;
// Fast reader for a decimal integer on stdin. Non-digit characters before the
// number are skipped; any '-' seen while skipping makes the result negative.
// (`register` was removed: that storage class is ill-formed since C++17.)
inline int read(){
    int x = 0, f = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c ^ 48);   // x * 10 + digit
        c = getchar();
    }
    return x * f;
}
// Writes x in decimal to stdout, with no trailing separator.
inline void write(int x){
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
const int N = 1e5 + 10;
// n: vertex count; [L, R]: allowed range for a path's edge count.
int n, L, R, rt;
long long ans;
// Fenwick tree indexed by (edge count + 1); modify() only writes 1 .. n + 1.
int t[N];
int lowbit(int x){return x & -x;}
void modify(int x, int y){
    for (int i = x; i <= n + 1; i += lowbit(i)) t[i] += y;
}
// Prefix sum over positions 1 .. x. x is capped at n + 1, because nothing is
// stored beyond it and reading past it would skip part of the prefix. A
// non-positive x yields 0 instead of walking negative indices.
int query(int x){
    if (x > n + 1) x = n + 1;
    int res = 0;
    for (int i = x; i > 0; i -= lowbit(i)) res += t[i];
    return res;
}
// Adjacency list: the edges leaving u are head[u], e[head[u]].nxt, ...
// and an index of 0 ends the list.
struct edge{
    int v, w, nxt;
}e[N << 1];
int head[N], cnt;
// Prepends the directed edge u -> v with weight w to u's list.
void add(int u, int v, int w){
    ++cnt;
    e[cnt].v = v;
    e[cnt].w = w;
    e[cnt].nxt = head[u];
    head[u] = cnt;
}
// A chain starting at the current centroid: dis is the largest edge weight on
// it and l its edge count. Chains are ordered by dis only.
struct node{
    long long dis;
    int l;
    bool operator < (const node &b) const{
        return b.dis > dis;
    }
}stk1[N], stk2[N];
int sz[N], mx[N], maxn, dis[N], l[N], top1, top2;
bool vis[N];
// Centroid search over the current component (vertices with vis == 0).
// Fills sz[] and mx[u] = size of the largest piece left if u were removed,
// using maxn as the component size. rt is replaced only by a strictly smaller
// mx, so callers reset rt to 0 first (mx[0] is set to n in main).
void find(int u, int fa){
sz[u] = 1, mx[u] = 0;
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (v == fa || vis[v]) continue;
find(v, u);
sz[u] += sz[v];
mx[u] = max(mx[u], sz[v]);
}
// the rest of the component "above" u is also a piece once u is removed
mx[u] = max(mx[u], maxn - sz[u]);
if (mx[u] < mx[rt]) rt = u;
}
// Pushes (dis, l) of every vertex reachable from u inside the current
// component onto stk2. Here dis is the largest edge weight on the path from
// the centroid, not a sum. The caller sets dis[u] and l[u].
void dfs(int u, int fa){
stk2[++top2] = (node){dis[u], l[u]};
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (v == fa || vis[v]) continue;
dis[v] = max(dis[u], e[i].w), l[v] = l[u] + 1;
dfs(v, u);
}
}
void calc(int u){
top1 = 0;
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (vis[v]) continue;
top2 = 0, dis[v] = e[i].w, l[v] = 1;
dfs(v, u);
sort(stk2 + 1, stk2 + top2 + 1);
for (int j = 1; j <= top2; j++){
ans -= ((R - stk2[j].l + 1 <= 0 ? 0 : query(R - stk2[j].l + 1)) - (L - stk2[j].l + 1 <= 0 ? 0 : query(L - stk2[j].l))) * stk2[j].dis;
modify(stk2[j].l + 1, 1);
}
for (int j = 1; j <= top2; j++){
modify(stk2[j].l + 1, -1);
stk1[++top1] = stk2[j];
}
}
stk1[++top1] = (node){0, 0};
sort(stk1 + 1, stk1 + top1 + 1);
for (int i = 1; i <= top1; i++){
ans += ((R - stk1[i].l + 1 <= 0 ? 0 : query(R - stk1[i].l + 1)) - (L - stk1[i].l + 1 <= 0 ? 0 : query(L - stk1[i].l))) * stk1[i].dis;
modify(stk1[i].l + 1, 1);
}
for (int i = 1; i <= top1; i++) modify(stk1[i].l + 1, -1);
}
// Divide and conquer: mark centroid u as removed, accumulate the pairs whose
// path passes through u, then recurse into the centroid of every remaining piece.
void solve(int u){
vis[u] = 1;
calc(u);
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (vis[v]) continue;
// sz[v] comes from an earlier centroid search and serves as the piece size
maxn = sz[v];
rt = 0;
find(v, 0);
solve(rt);
}
}
// Input: n, L, R, then n-1 edges (u, v, w).
// Output: twice the accumulated sum, since calc counts each unordered pair
// once and the answer is over ordered pairs.
int main(){
n = read(), L = read(), R = read();
for (int i = 1; i < n; i++){
int u = read(), v = read(), w = read();
add(u, v, w), add(v, u, w);
}
// rt is 0 here; mx[0] = n acts as the "no candidate yet" value in find
maxn = mx[rt] = n;
find(1, 0);
solve(rt);
cout << (ans << 1);
return 0;
}
P2634 [國家集訓隊] 聰聰可可
思路
這是一道點分治板題,但也能夠從中獲得啟發,做法:將和 \(mod\) \(3\) 的餘數用桶陣列記錄下來,再更新答案即可。
分數約分就用 \(\gcd\) 就行了。
程式碼
#include<iostream>
#define int long long
using namespace std;
// Fast reader for a decimal integer on stdin. Non-digit characters before the
// number are skipped; any '-' seen while skipping makes the result negative.
// (`register` was removed: that storage class is ill-formed since C++17.)
inline int read(){
    int x = 0, f = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c ^ 48);   // x * 10 + digit
        c = getchar();
    }
    return x * f;
}
// Writes x in decimal to stdout, with no trailing separator.
inline void write(int x){
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
const int N = 2e4 + 10;
// ans1: pair counter filled by calc; ans2: denominator computed in main.
int n, rt, ans1, ans2;
// Adjacency list: the edges leaving u are head[u], e[head[u]].nxt, ...
// and an index of 0 ends the list.
struct edge{
    int v, w, nxt;
}e[N << 1];
int head[N], cnt;
// Prepends the directed edge u -> v with weight w to u's list.
void add(int u, int v, int w){
    cnt++;
    e[cnt].v = v, e[cnt].w = w, e[cnt].nxt = head[u];
    head[u] = cnt;
}
// ton[r]: number of chain ends seen so far in calc whose length % 3 == r.
int sz[N], mx[N], maxn, dis[N], stk1[N], top1, stk2[N], top2, ton[4];
bool vis[N];
// Centroid search over the current component (vertices with vis == 0).
// Fills sz[] and mx[u] = size of the largest piece left if u were removed,
// using maxn as the component size. rt is replaced only by a strictly smaller
// mx, so callers reset rt to 0 first (mx[0] is set to n in main).
void find(int u, int fa){
sz[u] = 1, mx[u] = 0;
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (v == fa || vis[v]) continue;
find(v, u);
sz[u] += sz[v];
mx[u] = max(mx[u], sz[v]);
}
// the rest of the component "above" u is also a piece once u is removed
mx[u] = max(mx[u], maxn - sz[u]);
if (mx[u] < mx[rt]) rt = u;
}
// Pushes the path length dis[] of every vertex reachable from u inside the
// current component onto stk2. The caller sets dis[u].
void dfs(int u, int fa){
stk2[++top2] = dis[u];
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (v == fa || vis[v]) continue;
dis[v] = dis[u] + e[i].w;
dfs(v, u);
}
}
// Adds to ans1 the number of unordered pairs of distinct vertices whose path
// passes through centroid u and has length divisible by 3. Centroid u itself
// is a candidate end via ton[0] = 1. The counts added for the chains are
// removed from ton[] again before returning.
void calc(int u){
    ton[0] = 1;
    top1 = 0;
    for (int i = head[u]; i; i = e[i].nxt){
        const int v = e[i].v;
        if (vis[v]) continue;
        dis[v] = e[i].w;
        top2 = 0;
        dfs(v, u);
        // pair each new chain end only with ends from earlier subtrees (or u)
        for (int j = 1; j <= top2; j++){
            const int need = (3 - stk2[j] % 3) % 3;
            ans1 += ton[need];
        }
        // then make this subtree's ends available to later subtrees
        for (int j = 1; j <= top2; j++){
            ++ton[stk2[j] % 3];
            stk1[++top1] = stk2[j];
        }
    }
    for (int j = 1; j <= top1; j++) --ton[stk1[j] % 3];
}
// Divide and conquer: mark centroid u as removed, count the pairs whose path
// passes through u, then recurse into the centroid of every remaining piece.
void solve(int u){
vis[u] = 1;
calc(u);
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (vis[v]) continue;
rt = 0;
// sz[v] comes from an earlier centroid search and serves as the piece size
maxn = sz[v];
find(v, 0);
solve(rt);
}
}
// Greatest common divisor by the Euclidean algorithm; gcd(a, 0) == a.
int gcd(int a, int b){
    while (b != 0){
        int r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Reads the tree, runs the divide and conquer, and prints as a reduced
// fraction the share of ordered pairs (x, y), x == y allowed, whose path
// length is divisible by 3.
signed main(){
    n = read();
    for (int i = 1; i < n; i++){
        int u = read();
        int v = read();
        int w = read();
        add(u, v, w);
        add(v, u, w);
    }
    maxn = n;
    mx[rt] = n;   // rt is still 0: mx[0] is the "no candidate yet" value in find
    find(1, 0);
    solve(rt);
    // calc counted unordered pairs of distinct vertices: double them for
    // ordered pairs and add the n pairs (x, x), whose length 0 qualifies
    ans1 = 2 * ans1 + n;
    ans2 = n * n;
    const int g = gcd(ans1, ans2);
    cout << ans1 / g << '/' << ans2 / g;
    return 0;
}
P3714 [BJOI2017] 樹的難題
思路
是一道很難的點分治題,難點在於統計答案,而不是套用板子就行。在點分治中,一條鏈需要維護權值、邊數、這條鏈連向根的那條邊的顏色,以及它屬於根的第幾棵子樹。
如何統計答案?對於一個根,將鏈按照連向根的那條邊的顏色排序,顏色相同時再按鏈所屬子樹的編號排序。開兩棵線段樹(以鏈的邊數為下標):一棵存與當前鏈顏色相同的鏈的權值最大值,另一棵存顏色不同的鏈的權值最大值;與相同顏色的鏈拼接時,顯然要減去一次該顏色的權值。排序的作用是讓同種顏色、同個子樹的鏈連續處理;由於後面的鏈會與前面的匹配,所以前面的鏈不必再和後面的匹配。注意當顏色切換時,要把上一種顏色的所有鏈(而不只是最後一棵子樹的鏈)都從第一棵樹清除並加入第二棵樹。
程式碼
#include<iostream>
#include<climits>
#include<algorithm>
#include<cmath>
#define int long long
using namespace std;
// Fast reader for a decimal integer on stdin. Non-digit characters before the
// number are skipped; any '-' seen while skipping makes the result negative.
// (`register` was removed: that storage class is ill-formed since C++17.)
inline int read(){
    int x = 0, f = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c ^ 48);   // x * 10 + digit
        c = getchar();
    }
    return x * f;
}
// Writes x in decimal to stdout, with no trailing separator.
inline void write(int x){
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
const int N = 2e5 + 10;
// n vertices, m colours, a path's edge count must lie in [L, R].
// ans starts at INT_MIN as "no path found yet". NOTE(review): int is long long
// in this program, so LLONG_MIN would be the more natural sentinel.
int n, m, L, R, rt, ans = INT_MIN;
// c[col]: weight of colour col
int c[N];
// Two range-max segment trees (their roles are described in calc).
// LLONG_MIN marks an empty position.
int t1[N << 2], t2[N << 2];
inline void pushup(int now, int t[]){
t[now] = max(t[now << 1], t[now << 1 | 1]);
}
// Builds tree t over positions [l, r] with every leaf empty (LLONG_MIN).
inline void build(int now, int l, int r, int t[]){
if (l == r){
t[now] = LLONG_MIN;
return;
}
int mid = (l + r) >> 1;
build(now << 1, l, mid, t);
build(now << 1 | 1, mid + 1, r, t);
pushup(now, t);
}
// Point update at x: leaf = max(leaf, k). Passing k == LLONG_MIN instead resets
// the leaf to empty, which is how calc removes entries.
inline void modify(int now, int l, int r, int x, int k, int t[]){
if (l == r){
if (k == LLONG_MIN) t[now] = k;
else t[now] = max(t[now], k);
return;
}
int mid = (l + r) >> 1;
if (x <= mid) modify(now << 1, l, mid, x, k, t);
else modify(now << 1 | 1, mid + 1, r, x, k, t);
pushup(now, t);
}
// Maximum over positions [x, y] (LLONG_MIN if all are empty).
inline int query(int now, int l, int r, int x, int y, int t[]){
if (x <= l && r <= y) return t[now];
int mid = (l + r) >> 1, res = LLONG_MIN;
if (x <= mid) res = max(res, query(now << 1, l, mid, x, y, t));
if (mid + 1 <= y) res = max(res, query(now << 1 | 1, mid + 1, r, x, y, t));
return res;
}
// Adjacency-list edge: target v, weight w (main passes c[col]), colour col,
// and the next edge leaving the same vertex (0 ends the list).
struct edge{
    int v, w, col, nxt;
}e[N << 1];
int head[N], cnt;
// Prepends the directed edge u -> v to u's list.
inline void add(int u, int v, int w, int col){
    ++cnt;
    e[cnt].v = v;
    e[cnt].w = w;
    e[cnt].col = col;
    e[cnt].nxt = head[u];
    head[u] = cnt;
}
// A chain from the current centroid: weight dis, edge count l, colour c of its
// first edge, and id of the child subtree it lies in. Ordered by (c, id).
struct node{
    int dis, l, c, id;
    bool operator < (const node &b) const{
        return c == b.c ? id < b.id : c < b.c;
    }
}stk[N];
int sz[N], mx[N], maxn, top, dis[N], l[N];
bool vis[N];
// Centroid search over the current component (vertices with vis == 0).
// Fills sz[] and mx[u] = size of the largest piece left if u were removed,
// using maxn as the component size. rt is replaced only by a strictly smaller
// mx, so callers reset rt to 0 first (mx[0] is set to n in main).
inline void find(int u, int fa){
sz[u] = 1, mx[u] = 0;
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (v == fa || vis[v]) continue;
find(v, u);
sz[u] += sz[v];
mx[u] = max(mx[u], sz[v]);
}
// the rest of the component "above" u is also a piece once u is removed
mx[u] = max(mx[u], maxn - sz[u]);
if (mx[u] < mx[rt]) rt = u;
}
// Pushes every vertex reachable from u inside the current component onto stk
// as (dis, l, col, id). col is the colour of the chain's first edge and id its
// subtree index; both are copied unchanged. last is the colour of the edge
// entering u. A child's weight grows by the edge weight only when the colour
// changes, so each maximal same-colour run is counted once.
inline void dfs(int u, int fa, int col, int last, int id){
stk[++top] = (node){dis[u], l[u], col, id};
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (v == fa || vis[v]) continue;
if (e[i].col != last) dis[v] = dis[u] + e[i].w;
else dis[v] = dis[u];
l[v] = l[u] + 1;
dfs(v, u, col, e[i].col, id);
}
}
// Updates ans with the best weight of a path through centroid u whose edge
// count lies in [L, R].
// Every chain from u is collected as (dis = weight, l = edge count,
// c = colour of its first edge, id = child subtree index). Chains are sorted
// by (c, id), so chains of one colour are contiguous and, within that, chains
// of one subtree are contiguous.
// Two range-max segment trees are indexed by (edge count + 1):
//   t1 - chains from earlier subtrees of the SAME first colour as the current
//        run; joining two of them merges one colour segment, so c[colour] is
//        subtracted once;
//   t2 - chains from earlier colour runs, plus u itself at position 1
//        (length 0, weight 0).
// A subtree's chains are inserted only when the next subtree starts, so a
// chain is never paired with one from its own subtree.
inline void calc(int u){
    top = 0;
    int id = 0;
    for (int i = head[u]; i; i = e[i].nxt){
        int v = e[i].v;
        if (vis[v]) continue;
        dis[v] = e[i].w, l[v] = 1;
        dfs(v, u, e[i].col, e[i].col, ++id);
    }
    sort(stk + 1, stk + top + 1);
    build(1, 1, top + 1, t1);
    build(1, 1, top + 1, t2);
    modify(1, 1, top + 1, 1, 0, t2);
    // [l, r]: chains of the current subtree; cs: first chain of the current colour run
    int l = 1, r = 1, cs = 1;
    for (int i = 1; i <= top; i++){
        if (i != 1 && stk[i].id != stk[i - 1].id){
            if (stk[i].c == stk[i - 1].c){
                // same colour, new subtree: the finished subtree becomes a same-colour partner
                for (int j = l; j <= r; j++) modify(1, 1, top + 1, stk[j].l + 1, stk[j].dis, t1);
            }else{
                // Colour changes: EVERY chain of the finished colour run, not only its
                // last subtree, must leave t1 and move into t2. Otherwise earlier
                // subtrees of that colour would linger in t1 (and be joined with the
                // wrong colour's weight subtracted) and never become t2 partners.
                for (int j = cs; j <= r; j++){
                    modify(1, 1, top + 1, stk[j].l + 1, LLONG_MIN, t1);
                    modify(1, 1, top + 1, stk[j].l + 1, stk[j].dis, t2);
                }
                cs = i;
            }
            l = r = i;
        }else r = i;
        // no partner can keep the total edge count within R
        if (min(R - stk[i].l, R) + 1 <= 0) continue;
        int x = query(1, 1, top + 1, max(min(L - stk[i].l, top) + 1, 1ll), max(min(R - stk[i].l, top) + 1, 1ll), t1);
        int y = query(1, 1, top + 1, max(min(L - stk[i].l, top) + 1, 1ll), max(min(R - stk[i].l, top) + 1, 1ll), t2);
        if (x != LLONG_MIN) ans = max(ans, x - c[stk[i].c] + stk[i].dis);
        if (y != LLONG_MIN) ans = max(ans, y + stk[i].dis);
    }
}
// Divide and conquer: mark centroid u as removed, evaluate the paths through
// u, then recurse into the centroid of every remaining piece.
inline void solve(int u){
vis[u] = 1;
calc(u);
for (int i = head[u]; i; i = e[i].nxt){
int v = e[i].v;
if (vis[v]) continue;
rt = 0;
// sz[v] comes from an earlier centroid search and serves as the piece size
maxn = sz[v];
find(v, 0);
solve(rt);
}
}
// Input: n, m, L, R, the m colour weights c[1..m], then n-1 edges (u, v, col).
// Each edge is stored with weight c[col]. Output: the best path weight found.
signed main(){
n = read(), m = read(), L = read(), R = read();
for (int i = 1; i <= m; i++) c[i] = read();
for (int i = 1; i < n; i++){
int u = read(), v = read(), col = read(), w = c[col];
add(u, v, w, col), add(v, u, w, col);
}
// rt is 0 here; mx[0] = n acts as the "no candidate yet" value in find
maxn = mx[rt] = n;
find(1, 0);
solve(rt);
cout << ans;
return 0;
}