Tree

liuboom發表於2024-11-05

P4178 Tree

題目描述:

給定一棵 n 個節點的樹,每條邊有邊權,求出樹上兩點距離小於等於 k 的點對數量。

資料範圍:

\(1≤n≤4×10^4\)
\(1≤k≤2×10^4\)

說句閒話

感謝著名CB大師red_fire傾情推薦%%%
在機房三個人唇槍舌劍了一小會,我們的CB大師直接開搞線段樹,而仔細研讀了資料範圍的本蒟蒻認為sort十分的清真,於是就有了這場對決

Solution:

首先我們不難想到像這樣
“樹上兩點距離小於等於 k的點對數量。”
的題顯然可以用澱粉質解決
廢話今天可是點分治專題啊:(

我們考慮如何統計答案:對於一個點cent:把它的所有兒子vcentdis全部求出來並記錄到一個陣列A中,然後對A排序,排序後在上面跑一個雙指標,求出對於每個l,從右往左數第一個滿足\(A[l]+A[r]\le k\)的,然後這部分的貢獻就是r-l

注意,儘管區間長度是r-l+1,但是[l,l]並不是一個點對,所以不能將[l,l]統計進答案

但是我們會發現這樣寫的話同一個y內的點對也有可能被統計進答案,我們只需要對於x的每一個兒子y,去一下y內的重就好了。

Code:

#include<bits/stdc++.h>
#define int long long
const int N=4e4+5;
using namespace std;
long long ans;
int n,m,e_cnt,k;
int head[N<<1],nxt[N<<1],to[N<<1],w[N<<1];
void add(int x,int y,int z)
{
    e_cnt++;
    to[e_cnt]=y;nxt[e_cnt]=head[x];w[e_cnt]=z;
    head[x]=e_cnt;
}
int tot,cent;
int vis[N],siz[N],mx[N],A[N],dis[N];
void get_cent(int x,int fa)
{
    siz[x]=1;mx[x]=0;
    for(int i=head[x],y;i;i=nxt[i])
    {
        y=to[i];
        if(y==fa||vis[y])continue;
        get_cent(y,x);
        siz[x]+=siz[y];
        mx[x] = mx[x] > siz[y] ? mx[x] : siz[y];
    }
    mx[x]  = mx[x] > tot-siz[x] ? mx[x] : tot-siz[x];
    cent = mx[cent] < mx[x] ? cent : x;
}
void get_ans(bool tag)
{
    sort(A+1,A+1+A[0]);
    for(int l=1,r=A[0];l<=A[0];l++)
    {
        while(A[l]+A[r]>k)r--;
        if(l<r)
        ans+= (tag ? l-r : r-l);
        if(r<l)break;
    }
    A[0]=0;
}
inline void get_dis(int x,int fa)
{
    A[++A[0]]=dis[x];
    for(int i=head[x],y;i;i=nxt[i])
    {
        y=to[i];
        if(y==fa||vis[y])continue;
        dis[y]=dis[x]+w[i];
        get_dis(y,x);
    }
}
void calc(int x)
{
    dis[x]=0;
    get_dis(x,0);
    vis[x]=1;
    get_ans(0);
    for(int i=head[x],y;i;i=nxt[i])
    {
        y=to[i];
        if(vis[y])continue;
        dis[y]=dis[x]+w[i];
        get_dis(y,x);
        get_ans(1);
    }
}
void solve(int x)
{
    calc(x);
    for(int i=head[x],y;i;i=nxt[i])
    {
        y=to[i];
        if(vis[y])continue;
        cent=0;tot=siz[y];
        get_cent(y,x);
        solve(cent);
    }
}
void work()
{
    cin>>n;
    for(int i=1,x,y,z;i<n;i++)
    {
        scanf("%lld%lld%lld",&x,&y,&z);
        add(x,y,z);
        add(y,x,z);
    }
    cin>>k;
    mx[cent=0]=n;
    tot=n;
    get_cent(1,0);
    solve(cent);
    printf("%lld",ans);
}
#undef int
int main()
{
    //freopen("P4178_1.in","r",stdin);
    //freopen("Tree.out","w",stdout);
    work();
    return 0;
}

相關文章