LG P3233 [HNOI2014]世界樹(虛樹,dp)

All_fade_away發表於2020-12-30

LG P3233 [HNOI2014]世界樹

Solution

看完題意,顯然是虛樹。

建出虛樹後,可以容易地求出虛樹上的點會被哪一個點管轄,關鍵在於不在虛樹上的點歸屬於哪個點,我們分類討論不在虛樹上的點的貢獻:

我們先假設虛樹上的點全是關鍵點,注意後文的子樹都是原樹的子樹。

  1. 在虛樹上點 x , y x,y x,y路徑上(不包含 x , y x,y x,y)的點(設依次為 v 1 , v 2 . . . v k v_1,v_2...v_k v1,v2...vk,它們不在虛樹上)及其子樹中的點:它們要麼屬於 x x x,要麼屬於 y y y,且必然存在一個 m i d mid mid,使得 v 1 , v 2 . . . v m i d − 1 v_1,v_2...v_{mid-1} v1,v2...vmid1屬於 x x x v m i d . . . v k v_{mid}...v_k vmid...vk屬於 y y y,而求解這個 m i d mid mid位置的判定條件是 d i s t x , m i d dist_{x,mid} distx,mid d i s t m i d , y dist_{mid,y} distmid,y的大小(大小相同看編號大小),這個可以通過二分簡單地得到。而對於那些 v i v_i vi子樹中的點,一定和 v i v_i vi的歸屬相同。
  2. 在虛樹上的點 x x x的子樹中不在虛樹上的兒子 v v v以及它的子樹中的點:也就是 v v v子樹中沒有關鍵點,那麼一定整個子樹歸屬於 x x x,直接統計即可。
  3. 完全不在虛樹上的點:它們一定不在虛樹的根的子樹內(可以理解為在虛樹的上面),它們一定歸屬於虛樹的根。

實現時,我們通過一個向上和一個向下的 d p dp dp求出虛樹上點的歸屬。
然後再對於每個點 x x x,列舉其出邊 v v v,求出 m i d mid mid,計算 x , v x,v x,v的新增貢獻。
並且記錄一個 g x g_x gx表示 2 , 3 2,3 2,3類的答案,初始為子樹大小,列舉出邊 v v v時,把 x x x包含 v v v的兒子的子樹結點個數去掉,最後讓 x x x的貢獻加上 g x g_x gx即可。

時間複雜度 O ( n l g n ) O(nlgn) O(nlgn)

有一個實現過程中的小 t r i c k trick trick是建虛樹時直接把 1 1 1結點放入虛樹,會大大減少一些不必要的分類討論。

Code

#include <vector>
#include <list>
#include <map>
#include <set>
#include <deque>
#include <queue>
#include <stack>
#include <bitset>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <string>
#include <cstring>
#include <ctime>
#include <cassert>
#include <string.h>
//#include <unordered_set>
//#include <unordered_map>
//#include <bits/stdc++.h>

#define MP(A,B) make_pair(A,B)
#define PB(A) push_back(A)
#define SIZE(A) ((int)A.size())
#define LEN(A) ((int)A.length())
#define FOR(i,a,b) for(int i=(a);i<(b);++i)
#define fi first
#define se second

using namespace std;

template<typename T>inline bool upmin(T &x,T y) { return y<x?x=y,1:0; }
template<typename T>inline bool upmax(T &x,T y) { return x<y?x=y,1:0; }

typedef long long ll;
typedef unsigned long long ull;
typedef long double lod;
typedef pair<int,int> PR;
typedef vector<int> VI;

const lod eps=1e-11;
const lod pi=acos(-1);
const int oo=1<<30;
const ll loo=1ll<<62;
const int mods=1e9+7;
const int MAXN=600005;
const int INF=0x3f3f3f3f;//1061109567
/*--------------------------------------------------------------------*/
inline int read()
{
	int f=1,x=0; char c=getchar();
	while (c<'0'||c>'9') { if (c=='-') f=-1; c=getchar(); }
	while (c>='0'&&c<='9') { x=(x<<3)+(x<<1)+(c^48); c=getchar(); }
	return x*f;
}
PR mn[MAXN];
vector<int> e[MAXN],E[MAXN];
int a[MAXN],b[MAXN],f[MAXN],g[MAXN],stk[MAXN],top=0,n,m;
int dep[MAXN],sz[MAXN],Log[MAXN],dfn[MAXN],fa[MAXN][20],head[MAXN],flag[MAXN],DFN=0,edgenum;
int getlca(int x,int y)
{
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=Log[dep[x]];i>=0;i--)
		if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
	if (x==y) return x;
	for (int i=Log[dep[x]];i>=0;i--)
		if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
int jump(int x,int d)
{
	for (int i=Log[dep[x]];i>=0;i--)
		if (dep[fa[x][i]]>=d) x=fa[x][i];
	return x;
}
void dfs(int x,int father)
{
	fa[x][0]=father,sz[x]=1,dep[x]=dep[father]+1,dfn[x]=++DFN;
	for (int i=1;i<=Log[dep[x]];i++) fa[x][i]=fa[fa[x][i-1]][i-1];
	for (auto v:e[x]) if (v!=father) dfs(v,x),sz[x]+=sz[v];
}
void Init()
{
	dep[0]=-1,Log[1]=0;
	for (int i=1;i<=n;i++) Log[i]=Log[i>>1]+1;
	dfs(1,0);
}


void add(int u,int v) { E[u].PB(v); }
void build()
{
	sort(a+1,a+m+1,[&](int x,int y){ return dfn[x]<dfn[y]; });
	stk[top=1]=1;
	for (int i=1+(a[1]==1);i<=m;i++)
	{
		int lca=getlca(stk[top],a[i]);
		while (top>1&&dep[stk[top-1]]>dep[lca]) add(stk[top-1],stk[top]),top--;
		if (dep[stk[top]]>dep[lca]) add(lca,stk[top--]);
		if (!top||stk[top]!=lca) stk[++top]=lca;
		stk[++top]=a[i];
	}
	while (top>1) add(stk[top-1],stk[top]),top--;
}


void up(int x,int father)
{
	mn[x]=(flag[x]?MP(0,x):MP(INF,x));
	for (auto v:E[x])
	{
		if (v==father) continue;
		up(v,x),upmin(mn[x],MP(mn[v].fi+dep[v]-dep[x],mn[v].se));
	}
}
void down(int x,int father)
{
	for (auto v:E[x])
		if (v!=father) upmin(mn[v],MP(mn[x].fi+dep[v]-dep[x],mn[x].se)),down(v,x);
}

void tree_dp(int x,int father)
{
	for (auto v:E[x])
		if (v!=father) tree_dp(v,x);
	g[x]=sz[x];
	for (auto v:E[x])
	{
		int t=jump(v,dep[x]+1); g[x]-=sz[t];
		if (mn[x].se==mn[v].se) { f[mn[x].se]+=sz[t]-sz[v]; continue; }

		int mid=v;
		for (int i=Log[dep[v]];i>=0;i--)
		{
			int p=fa[mid][i];
			if (dep[p]<=dep[x]) continue;
			if (MP(dep[p]-dep[x]+mn[x].fi,mn[x].se)>MP(dep[v]-dep[p]+mn[v].fi,mn[v].se)) mid=p;
		}
		f[mn[x].se]+=sz[t]-sz[mid];
		f[mn[v].se]+=sz[mid]-sz[v];
	}
	f[mn[x].se]+=g[x];
}

void clean(int x,int father)
{
	for (auto v:E[x]) if (v!=father) clean(v,x);
	f[x]=g[x]=0,E[x].clear();
}
void clear()
{
	for (int i=1;i<=m;i++) flag[a[i]]=0;
	clean(1,0),top=0;
}

signed main()
{
	n=read();
	for (int i=1,u,v;i<n;i++) u=read(),v=read(),e[u].PB(v),e[v].PB(u);
	Init();
	int Case=read();
	while (Case--)
	{
		m=read();
		for (int i=1;i<=m;i++) a[i]=b[i]=read(),flag[a[i]]=1;
		build(),up(1,0),down(1,0),tree_dp(1,0);
		for (int i=1;i<=m;i++) printf("%d ",f[b[i]]); puts("");
		clear();
	}
	return 0;
}

相關文章