CF1575E 題解

yhddd發表於2024-06-06

CF1575E

思路

點分治,記錄當前子樹到分治中心的權值和和換車次數。將新子樹的答案合併時分類討論分治中心到子樹祖先 \(u\to v\) 的顏色。樹狀陣列維護字首和。複雜度 \(O(n\log^2 n)\)

code

int n,k,a[maxn],ans;
int head[maxn],tot;
struct nd{
	int nxt,to,fl;
}e[maxn<<1];
void add(int u,int v,int fl){e[++tot]={head[u],v,fl};head[u]=tot;}
int siz[maxn],w[maxn],rt,sum;
bool vis[maxn];
void getrt(int u,int fa){
	siz[u]=1,w[u]=0;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==fa||vis[v])continue;
		getrt(v,u);
		siz[u]+=siz[v],w[u]=max(w[u],siz[v]);
	}
	w[u]=max(w[u],sum-siz[u]);
	if(w[u]<=sum/2)rt=u;
}
int dis[maxn],num[maxn],col[maxn];
vector<int> id[maxn];
int inc(int u,int v){(u+=v)>=mod&&(u-=mod);return u;}
void dfs(int u,int fa,int tp){
	siz[u]=1;dis[u]=inc(dis[fa],a[u]);
	id[tp].push_back(u);
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==fa||vis[v])continue;
		col[v]=e[i].fl;num[v]=num[u]+(col[u]!=col[v]);
		dfs(v,u,tp);siz[u]+=siz[v];
	}
//	cout<<u<<" "<<tp<<" "<<dis[u]<<" "<<num[u]<<"\n";
}
struct bit{
	pii add(pii u,pii v){return {inc(u.fi,v.fi),inc(u.se,v.se)};}
	pii tree[maxn];
	#define lb(x) (x&(-x))
	void upd(int x,int m,pii w){
		while(x<=m)tree[x]=add(tree[x],w),x+=lb(x);
	}
	pii que(int x){
		pii res={0,0};
		while(x>0)res=add(res,tree[x]),x-=lb(x);
		return res;
	}
}t[2];
void sovle(int u){
	vis[u]=1;ans=inc(ans,a[u]);
	siz[u]=1;dis[u]=0;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(vis[v])continue;
		id[v].clear();num[v]=1;col[v]=e[i].fl;dfs(v,u,v);siz[u]+=siz[v];
	}
	for(int i=1;i<=siz[u];i++)t[0].tree[i]=t[1].tree[i]={0,0};
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(vis[v])continue;
		for(int p:id[v]){
			pii res1=t[e[i].fl].que(min(k-num[p]+1,siz[u])),res2=t[e[i].fl^1].que(min(k-num[p],siz[u]));
			(ans+=(dis[p]+a[u])*(res1.se+res2.se)+res1.fi+res2.fi)%=mod;
			if(num[p]<=k)ans=inc(ans,inc(dis[p],a[u]));
		}
		for(int p:id[v])t[e[i].fl].upd(num[p],siz[u],{dis[p],1});
	}
//	cout<<u<<" "<<ans<<"\n";
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(vis[v])continue;
		sum=siz[v];getrt(v,u);sovle(rt);
	}
}
void work(){
	n=read();k=read()+1;
	for(int i=1;i<=n;i++)a[i]=read();
	for(int i=1;i<n;i++){
		int u=read(),v=read(),fl=read();
		add(u,v,fl),add(v,u,fl);
	}
	sum=n;getrt(1,0);sovle(rt);
	printf("%lld\n",ans);
}