CF519E A and B and Lecture Rooms(樹上倍增 + 分類討論)

Zhang_Wenjie發表於2024-06-21

link

一眼看上去沒什麼思路,手摸一下樣例,發現有不同性質的點對求解想法很不一樣,考慮先分類討論看看。

從簡單的約束到強的約束分類討論,這樣更可做,也更好討論,

比如首先我就想到兩點是否重合,然後所求點一定要到兩點的距離相等,我就想到路徑長度的奇偶性,接著就考慮複雜的深度關係 ......

對於當前點對 (u,v),

  • 如果 u = v,則結果就是 n;

  • 如果 u ≠ v

    • 如果 u -> v 的簡單路徑長度為奇數,則結果為 0;

    • 如果 u -> v 的簡單路徑長度為偶數

      • 如果 \(dep[u] = dep[v]\),可以發現路徑上到 lca 肯定是符合要求的中點 mid,再往以 lca 這顆子樹往外走就步步重合了,也就是說都成立,如果 lca 子樹中到 u、v 的兩顆子樹的根節點分別為 fu、fv,則有:

      \[size[1] - size[fu] - size[fv] \]

      • 如果 \(dep[u] \not = dep[v]\),假設令 \(dep[u] > dep[v]\),肯定還是要找路徑上的中點 mid,但就一定不是 lca 了,但一定在較深的那顆子樹裡,所以還得從 u 往上跳一半的距離到 mid,顯然滿足的點肯定在 mid 延伸的子樹上,且不是 u、v 所在的子樹,這裡只用考慮減掉 u 所在的子樹大小就可以了,因為 v 肯定在 mid 往上走的路上,即:

        \[size[mid] - size[fu] \]

一開始實現後兩部分討論,我只用了普通的暴力上跳,\(O(mn)\) 的複雜度就 T 了(太投入了,我就老是會把想到腦邊的做法先實現了再說)

code
#include <bits/stdc++.h>
#define re register int 

using namespace std;
const int N = 1e5 + 10, logN = 50; 

struct edge
{
	int to, next; 
}e[N << 1];
int top, h[N], dep[N], f[N][logN], size[N];
int n, m;

inline void add(int x, int y)
{
	e[++ top] = (edge){y, h[x]};
	h[x] = top;
}

void dfs(int u, int fa)
{
	dep[u] = dep[fa] + 1;
	
	f[u][0] = fa;
	for (re i = 1; i <= log2(n); i ++)
		f[u][i] = f[f[u][i - 1]][i - 1];
		
	for (re i = h[u]; i; i = e[i].next)
	{
		int v = e[i].to;
		
		if (v == fa) continue;
		dfs(v, u);
		
		size[u] += size[v];
	}
}

inline int lca(int u, int v)
{
	if (dep[u] < dep[v]) swap(u, v);
	
	for (re i = log2(n); i >= 0; i --)
		if (dep[f[u][i]] >= dep[v]) u = f[u][i];
		
	if (u == v) return u;
	
	for (re i = log2(n); i >= 0; i --)
		if (f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
		
	return f[u][0];
}

inline int work(int u, int v, int p)
{
	if (u == v) return n;
	
	int dist = dep[u] + dep[v] - 2 * dep[p];
	if (dist % 2) return 0;
	
	if (dep[u] == dep[v]) 
	{
//		return size[1] - size[p] + 1; 當 p = 1 時不成立
		
		int fu = u, fv = v, du, dv;
		 
		du = dv = dist / 2 - 1;
		while (du)
		{
			fu = f[fu][0];
			du --;	
		}  
		while (dv)
		{
			fv = f[fv][0];
			dv --;
		}
		
		return size[1] - size[fu] - size[fv];
	}
	else
	{
		if (dep[u] < dep[v]) swap(u, v);
		
		int fu = u, mid;
		dist = dist / 2 - 1;
		while (dist)
		{
			fu = f[fu][0];
			dist --;
		}
		mid = f[fu][0];
		
		return size[mid] - size[fu];
	}
}

int main()
{
	ios::sync_with_stdio(false);
	cin.tie(0); cout.tie(0);
	
	cin >> n;
	for (re i = 1; i < n; i ++)
	{
		int x, y; cin >> x >> y;
		add(x, y); add(y, x);
	}
	for (re i = 1; i <= n; i ++) size[i] = 1;
	dfs(1, 0);
	
	cin >> m;
	while (m --)
	{
		int x, y; cin >> x >> y;
		
		cout << work(x, y, lca(x, y)) << '\n';
	}
	
	return 0;
}

所以顯然是要倍增上跳就可以了。\(O(m\log n)\)

#include <bits/stdc++.h>
#define re register int 

using namespace std;
const int N = 1e5 + 10, logN = 50; 

struct edge
{
	int to, next; 
}e[N << 1];
int top, h[N], dep[N], f[N][logN], size[N];
int n, m;

inline void add(int x, int y)
{
	e[++ top] = (edge){y, h[x]};
	h[x] = top;
}

void dfs(int u, int fa)
{
	dep[u] = dep[fa] + 1;
	
	f[u][0] = fa;
	for (re i = 1; i <= log2(n); i ++)
		f[u][i] = f[f[u][i - 1]][i - 1];
		
	for (re i = h[u]; i; i = e[i].next)
	{
		int v = e[i].to;
		
		if (v == fa) continue;
		dfs(v, u);
		
		size[u] += size[v];
	}
}

inline int lca(int u, int v, int type)
{
	if (dep[u] < dep[v]) swap(u, v);
	
	for (re i = log2(n); i >= 0; i --)
		if (dep[f[u][i]] >= dep[v]) u = f[u][i];
		
	if (u == v) return u;
	
	for (re i = log2(n); i >= 0; i --)
		if (f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
		
	
	if (!type) return f[u][0];
	else return size[1] - size[u] - size[v];
}

inline int work(int u, int v, int p)
{
	if (u == v) return n;
	
	int dist = dep[u] + dep[v] - 2 * dep[p];
	if (dist % 2) return 0;
	
	if (dep[u] == dep[v]) 
	{
//		return size[1] - size[p] + 1; 當 p = 1 時不成立

		return lca(u, v, 1);
	}
	else
	{
		if (dep[u] < dep[v]) swap(u, v);
	
		int step = dist / 2;
		int cnt = size[u];
		
		for(int i = log2(n); i >= 0; i--)
			if(step - (1 << i) > 0)
			{
				u = f[u][i];
				cnt += size[u] - cnt;
				
				step -= (1 << i);
			}
		u = f[u][0];
		
		return size[u] - cnt;
	}
}

int main()
{
	ios::sync_with_stdio(false);
	cin.tie(0); cout.tie(0);
	
	cin >> n;
	for (re i = 1; i < n; i ++)
	{
		int x, y; cin >> x >> y;
		add(x, y); add(y, x);
	}
	for (re i = 1; i <= n; i ++) size[i] = 1;
	dfs(1, 0);
	
	cin >> m;
	while (m --)
	{
		int x, y; cin >> x >> y;
		
		cout << work(x, y, lca(x, y, 0)) << '\n';
	}
	
	return 0;
}

相關文章