CF1575E
思路
點分治,記錄當前子樹到分治中心的權值和和換車次數。將新子樹的答案合併時分類討論分治中心到子樹祖先 \(u\to v\) 的顏色。樹狀陣列維護字首和。複雜度 \(O(n\log^2 n)\)。
code
int n,k,a[maxn],ans;
int head[maxn],tot;
struct nd{
int nxt,to,fl;
}e[maxn<<1];
void add(int u,int v,int fl){e[++tot]={head[u],v,fl};head[u]=tot;}
int siz[maxn],w[maxn],rt,sum;
bool vis[maxn];
void getrt(int u,int fa){
siz[u]=1,w[u]=0;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(v==fa||vis[v])continue;
getrt(v,u);
siz[u]+=siz[v],w[u]=max(w[u],siz[v]);
}
w[u]=max(w[u],sum-siz[u]);
if(w[u]<=sum/2)rt=u;
}
int dis[maxn],num[maxn],col[maxn];
vector<int> id[maxn];
int inc(int u,int v){(u+=v)>=mod&&(u-=mod);return u;}
void dfs(int u,int fa,int tp){
siz[u]=1;dis[u]=inc(dis[fa],a[u]);
id[tp].push_back(u);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(v==fa||vis[v])continue;
col[v]=e[i].fl;num[v]=num[u]+(col[u]!=col[v]);
dfs(v,u,tp);siz[u]+=siz[v];
}
// cout<<u<<" "<<tp<<" "<<dis[u]<<" "<<num[u]<<"\n";
}
struct bit{
pii add(pii u,pii v){return {inc(u.fi,v.fi),inc(u.se,v.se)};}
pii tree[maxn];
#define lb(x) (x&(-x))
void upd(int x,int m,pii w){
while(x<=m)tree[x]=add(tree[x],w),x+=lb(x);
}
pii que(int x){
pii res={0,0};
while(x>0)res=add(res,tree[x]),x-=lb(x);
return res;
}
}t[2];
void sovle(int u){
vis[u]=1;ans=inc(ans,a[u]);
siz[u]=1;dis[u]=0;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(vis[v])continue;
id[v].clear();num[v]=1;col[v]=e[i].fl;dfs(v,u,v);siz[u]+=siz[v];
}
for(int i=1;i<=siz[u];i++)t[0].tree[i]=t[1].tree[i]={0,0};
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(vis[v])continue;
for(int p:id[v]){
pii res1=t[e[i].fl].que(min(k-num[p]+1,siz[u])),res2=t[e[i].fl^1].que(min(k-num[p],siz[u]));
(ans+=(dis[p]+a[u])*(res1.se+res2.se)+res1.fi+res2.fi)%=mod;
if(num[p]<=k)ans=inc(ans,inc(dis[p],a[u]));
}
for(int p:id[v])t[e[i].fl].upd(num[p],siz[u],{dis[p],1});
}
// cout<<u<<" "<<ans<<"\n";
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(vis[v])continue;
sum=siz[v];getrt(v,u);sovle(rt);
}
}
void work(){
n=read();k=read()+1;
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<n;i++){
int u=read(),v=read(),fl=read();
add(u,v,fl),add(v,u,fl);
}
sum=n;getrt(1,0);sovle(rt);
printf("%lld\n",ans);
}