2020 年第一屆遼寧省大學生程式設計競賽 D.開心消消樂(點分治)

Code92007發表於2020-10-28

題目

題解

wa了12發的點分治終於過了,就是xjb亂搞題……

維護了六個量,(u到v的鏈上出現的第一種顏色col,col的次數cnt,最後一種鏈的顏色las,las的次數num,鏈的長度len,鏈的權值w)

統計必過u的答案的時候,

用 任意兩個合法的 減去 任意兩個col相同的合法的,作異色答案

加上col相同且長度均為1的(不抵消的)合法的,作同色不抵消,

加上col相同且長度之和大於等於3的(抵消的)合法的,作同色抵消

減去全在v的答案,然後點分治下去即可

統計答案可以用雙指標統計,但是既然有sort的log,二分只是常數大一點…

程式碼

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<assert.h>
using namespace std;
#define pb push_back
typedef long long ll;
const int N=1e5+10;
int head[N],cnt;
struct edge{int v,nex;ll w;}e[2*N];
void add(int u,int v,ll w){e[++cnt]=edge{v,head[u],w};head[u]=cnt;} 
bool vis[N];
int n,r,u,v;
ll k,res,w;
int siz,f[N],sz[N],rt;
struct node{
	ll col,cnt,len,w,las,num;
}d[N],q[N],real; 
vector<node>now,tmp,my[4];
bool cmp1(const node &a,const node &b){
	return a.w-1ll*k*a.len<1ll*b.w-k*b.len;
}
bool cmp2(const node &a,const node &b){
	if(a.col!=b.col)return a.col<b.col;
	return a.cnt<b.cnt;
}
void init(int n){
	cnt=0;
	for(int i=1;i<=n;++i){
		head[i]=vis[i]=0;
	}
}
//找下一次的重心rt 
void getrt(int u,int fa,bool op){
	f[u]=0;sz[u]=1;
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].v;
		if(v==fa||vis[v])continue;
		getrt(v,u,op);
		f[u]=max(f[u],sz[v]);
		sz[u]+=sz[v];
	}
	if(op){
		f[u]=max(f[u],siz-sz[u]);
		if(f[u]<f[rt])rt=u;
	}
}
//計算重心u到子樹內每個點的距離 
void getdis(int u,int fa){
	q[++r]=d[u];
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].v;
		ll w=e[i].w;
		if(v==fa||vis[v])continue;
		d[v].len=d[u].len+1;
		d[v].las=w;
		d[v].w=d[u].w;
		if(d[u].col==0){
			d[v].col=w;
			d[v].cnt=1;
			d[v].num=1; 
		}
		else{
			d[v].col=d[u].col;
			d[v].num=(d[u].las==w)?(d[u].num+1):1;
			d[v].cnt=(d[v].num==d[v].len)?(d[u].cnt+1):d[u].cnt;
		}
		if(d[v].num<3){
			d[v].w+=w;
		}
		else if(d[v].num==3){
			d[v].w-=2ll*w;
		}
		getdis(v,u);
	}
}
//計算以u為根的子樹的答案
ll cal(int u,ll col,ll cnt,ll len,ll sum,ll las,ll num){
    r=0;d[u]={col,cnt,len,sum,las,num};
    getdis(u,0);
    ll ans=0;
    sort(q+1,q+r+1,cmp1);
    for(int i=1;i<=r;++i){
        int x=1,y=r;
        while(x<=y){
            int mid=(x+y)/2;
            if(q[i].w+q[mid].w>=1ll*k*(q[i].len+q[mid].len))y=mid-1;
            else x=mid+1;
        }
        ans+=(r-x+1);//[x,r]
        if(q[i].w+q[i].w>=1ll*k*(q[i].len+q[i].len))ans--;
    }
    sort(q+1,q+r+1,cmp2);
    for(int i=1;i<=r;){
        int j=i;
        now.clear();tmp.clear();
        for(int z=1;z<=3;++z){
			my[z].clear();
		}
        tmp.pb(q[i]);
        if(q[i].cnt==1)now.pb(q[i]);
        if(q[i].cnt){
			real=q[i];
			if(q[i].cnt<=2)real.w-=q[i].col*q[i].cnt;
			my[min(q[i].cnt,3ll)].pb(real);
		}
        while(j+1<=r && q[j+1].col==q[i].col){
            j++;
            tmp.pb(q[j]);
            if(q[j].cnt==1)now.pb(q[j]);
            if(q[j].cnt){
				real=q[j];
				if(q[j].cnt<=2)real.w-=q[j].col*q[j].cnt;
				my[min(q[j].cnt,3ll)].pb(real);
			}
        }
        sort(now.begin(),now.end(),cmp1);
        sort(tmp.begin(),tmp.end(),cmp1);
        for(int z=1;z<=3;++z){
        	sort(my[z].begin(),my[z].end(),cmp1);
        }
		int up=tmp.size();up--;
        for(int z=0;z<=up;++z){
            int x=0,y=up;
            while(x<=y){
                int mid=(x+y)/2;
                if(tmp[z].w+tmp[mid].w>=1ll*k*(tmp[z].len+tmp[mid].len))y=mid-1;
                else x=mid+1;
            }
            ans-=(up-x+1);
            if(tmp[z].w+tmp[z].w>=1ll*k*(tmp[z].len+tmp[z].len))ans++;
        }
        up=now.size();up--;
        for(int z=0;z<=up;++z){
            int x=0,y=up;
            while(x<=y){
                int mid=(x+y)/2;
                if(now[z].w+now[mid].w>=1ll*k*(now[z].len+now[mid].len))y=mid-1;
                else x=mid+1;
            }
            ans+=(up-x+1);
            if(now[z].w+now[z].w>=1ll*k*(now[z].len+now[z].len))ans--;
        }
        for(int s=1;s<=3;++s){
			int sz=my[s].size();
			for(int z=1;z<=3;++z){
				if(s+z<3)continue;
				int up=my[z].size();up--;
				for(int h=0;h<sz;++h){
					int x=0,y=up;
					while(x<=y){
						int mid=(x+y)/2;
						if(my[s][h].w+my[z][mid].w>=1ll*k*(my[s][h].len+my[z][mid].len))y=mid-1;
						else x=mid+1;
					}
					ans+=(up-x+1);
					if(s==z && my[s][h].w+my[z][h].w>=1ll*k*(my[s][h].len+my[z][h].len))ans--;
				}
			}
		}
        i=j+1;
    }
    return ans/2;
}
void dfs(int u){
	//每次用在u的子樹裡任取減去在v的子樹裡的答案
	//每次只計算 必經過u的答案 
	res+=cal(u,0,0,0,0,0,0);
	vis[u]=1;
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].v;
		ll w=e[i].w;
		if(vis[v])continue;
		res-=cal(v,w,1,1,w,w,1);
		getrt(v,u,0);//獲得正確的sz[v] 
		siz=sz[v];rt=0;
		getrt(v,u,1);
		dfs(rt);
	}
 } 
int main(){
	int T;
	scanf("%d",&T);
	while(T--){
		scanf("%d%lld",&n,&k);
		init(n);
		for(int i=1;i<n;++i){
			scanf("%d%d%lld",&u,&v,&w);
			add(u,v,w);add(v,u,w);
		}
		res=0;
		f[0]=siz=n;rt=0;
		getrt(1,0,1),dfs(rt);
		printf("%lld\n",res);
	}
	return 0;
}
/*
4
5 4
1 2 3
1 3 5
2 4 5
3 5 4
8 50809177
1 2 700805901
2 3 32145015
3 4 792263333
3 5 538420696
1 6 351870424
2 7 263716407
5 8 818097140
*/

 

相關文章