這道題目卡常卡了兩個半小時仍然沒有卡過。。。等進隊了讓隊友幫忙卡一下吧
主要想一下思路
最主要的就是在計算路徑長度的時候,假設當前遞迴到了點\(i\),那麼從點\(i\)出發的兩條路徑合併在一起,如果第一條邊的顏色相同的話就會重複計算,為了解決這個問題,我們只用對每個點進行排序,將相同顏色的點放在一起,處理相同顏色的點內部,以及不同顏色的點之間的路徑就好了,程式碼如下
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+10;
const ll INF=1e12;
int n,m,L,R;
ll c[N],maxdis[N];
ll Max[N<<2][2];
int siz[N],maxx[N];
struct node
{
int num,color;
}temp;
vector<node> G[N];
int tag[2][N],New[N],TAG[2],NEW;
bool vis[N],ins[N],tf[2][N];
int sum,rt;
ll ans;
ll dist[N];
int len[N];
inline ll mymax(ll a,ll b)
{
return a>b?a:b;
}
void calcsiz(int x,int fa)
{
siz[x]=1;
maxx[x]=0;
for(register int j=0;j<G[x].size();j++)
if(G[x][j].num!=fa&&!vis[G[x][j].num])
{
calcsiz(G[x][j].num,x);
maxx[x]=mymax(maxx[x],siz[G[x][j].num]);
siz[x]+=siz[G[x][j].num];
}
maxx[x]=mymax(maxx[x],sum-siz[x]);
if(maxx[x]<maxx[rt]) rt=x;
}
ll ask(int p,int l,int r,bool op,int x,int y)
{
if(l>y||r<x) return -INF;
if(l>=x&&r<=y) return Max[p][op];
int mid=l+r>>1;
return mymax(ask(p<<1,l,mid,op,x,y),ask(p<<1|1,mid+1,r,op,x,y));
}
void calcdist(int x,int fa,int col)
{
maxdis[len[x]]=mymax(maxdis[len[x]],dist[x]);
//這裡卡常,然而卻沒有什麼用
if(!ins[len[x]])
{
New[NEW++]=len[x];
ins[len[x]]=1;
if(!tf[0][len[x]])
{
tf[0][len[x]]=1;
tag[0][TAG[0]++]=len[x];
}
}
if(len[x]<=R)
ans=mymax(ans,dist[x]+ask(1,1,n+1,0,mymax(L-len[x],0)+1,R-len[x]+1));
for(register int j=0;j<G[x].size();j++)
if(G[x][j].num!=fa&&!vis[G[x][j].num])
{
len[G[x][j].num]=len[x]+1;
if(col!=G[x][j].color) dist[G[x][j].num]=dist[x]+c[G[x][j].color];
else dist[G[x][j].num]=dist[x];
calcdist(G[x][j].num,x,G[x][j].color);
}
}
void modify(int p,int l,int r,bool op,int x,ll d)
{
if(l>x||r<x) return;
if(l==r)
{
if(d!=-INF) Max[p][op]=mymax(d,Max[p][op]);
else Max[p][op]=d;
//這裡千萬注意,如果是還原的話可以直接賦值,但是如果是更新的話一定要先比較
return;
}
int mid=l+r>>1;
modify(p<<1,l,mid,op,x,d);
modify(p<<1|1,mid+1,r,op,x,d);
Max[p][op]=mymax(Max[p<<1][op],Max[p<<1|1][op]);
}
bool cmp(node i,node j)
{
if(vis[i.num]==vis[j.num]) return i.color<j.color;
else return vis[i.num]<vis[j.num];
}
void dp(int x,int fa,int col)
{
maxdis[len[x]]=mymax(maxdis[len[x]],dist[x]);
if(!ins[len[x]])
{
New[NEW++]=len[x];
ins[len[x]]=1;
}
if(len[x]<=R)
ans=mymax(ans,dist[x]+ask(1,1,n+1,1,mymax(L-len[x],0)+1,R-len[x]+1)-c[col]);
for(register int j=0;j<G[x].size();j++)
if(G[x][j].num!=fa&&!vis[G[x][j].num])
dp(G[x][j].num,x,col);
}
void dfs(int x,int fa)
{
tf[0][0]=1;
tag[0][TAG[0]++]=0;
modify(1,1,n+1,0,1,0);
vis[x]=1;
sort(G[x].begin(),G[x].end(),cmp);
//將顏色相同的放在一起
for(register int j=0;j<G[x].size();)
if(G[x][j].num!=fa&&!vis[G[x][j].num])
{
int k=j;
while(j<G[x].size()&&!vis[G[x][j].num]&&G[x][k].color==G[x][j].color)
{
len[G[x][j].num]=1,dist[G[x][j].num]=c[G[x][j].color];
calcdist(G[x][j].num,x,G[x][j].color);
j++;
}//找出同一顏色的兒子
for(register int o=0;o<NEW;o++) modify(1,1,n+1,0,New[o]+1,maxdis[New[o]]);
for(register int o=0;o<NEW;o++)
maxdis[New[o]]=-INF;
for(register int o=0;o<NEW;o++)
ins[New[o]]=0;
NEW=0;
for(register int o=k;o<j;o++)//處理同一顏色的點
{
dp(G[x][o].num,x,G[x][o].color);
for(register int w=0;w<NEW;w++) modify(1,1,n+1,1,New[w]+1,maxdis[New[w]]);
for(register int w=0;w<NEW;w++)
if(!tf[1][New[w]])
{
tf[1][New[w]]=1;
tag[1][TAG[1]++]=New[w];
}
for(register int w=0;w<NEW;w++)
maxdis[New[w]]=-INF;
for(register int w=0;w<NEW;w++)
ins[New[w]]=0;
NEW=0;
}
for(register int o=0;o<TAG[1];o++)
modify(1,1,n+1,1,tag[1][o]+1,-INF);
for(register int o=0;o<TAG[1];o++)
tf[1][tag[1][o]]=0;
TAG[1]=0;
}
else break;
for(register int j=0;j<TAG[0];j++)
modify(1,1,n+1,0,tag[0][j]+1,-INF);
for(register int j=0;j<TAG[0];j++)
tf[0][tag[0][j]]=0;
TAG[0]=0;
for(register int j=0;j<G[x].size();j++)
if(G[x][j].num!=fa&&!vis[G[x][j].num])
{
sum=siz[G[x][j].num];
rt=0;
maxx[rt]=n+1;
calcsiz(G[x][j].num,x);
calcsiz(rt,0);
dfs(rt,x);
}
}
void build(int p,int l,int r)
{
if(l==r)
{
for(register int i=0;i<=1;i++)
Max[p][i]=-INF;
return;
}
int mid=l+r>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
for(register int i=0;i<=1;i++)
Max[p][i]=-INF;
}
int read()
{
int x=0,f=1;char s=getchar();
while(s<'0'||s>'9'){if(s=='-')f=-f;s=getchar();}
while(s>='0'&&s<='9'){x=x*10+s-48;s=getchar();}
return x*f;
}
int main()
{
n=read(),m=read(),L=read(),R=read();
for(register int i=1;i<=n;i++) maxdis[i]=-INF;
build(1,1,n+1);
//注意,線段樹的下標表示長度,而且整體加了一
//0表示不同顏色的線段樹,1表示相同顏色的線段樹
for(register int i=1;i<=m;i++)
c[i]=read();
for(register int i=1,a,b,c;i<n;i++)
{
a=read(),b=read(),c=read();
temp.num=b,temp.color=c;
G[a].push_back(temp);
temp.num=a;
G[b].push_back(temp);
}
rt=0;
maxx[rt]=n+1;
sum=n;
calcsiz(1,0);
calcsiz(rt,0);
ans=-INF;
dfs(rt,0);
printf("%lld",ans);
return 0;
}