link
一眼看上去沒什麼思路,手摸一下樣例,發現有不同性質的點對求解想法很不一樣,考慮先分類討論看看。
從簡單的約束到強的約束分類討論,這樣更可做,也更好討論,
比如首先我就想到兩點是否重合,然後所求點一定要到兩點的距離相等,我就想到路徑長度的奇偶性,接著就考慮複雜的深度關係 ......
對於當前點對 (u,v),
-
如果 u = v,則結果就是 n;
-
如果 u ≠ v
-
如果 u -> v 的簡單路徑長度為奇數,則結果為 0;
-
如果 u -> v 的簡單路徑長度為偶數
- 如果 \(dep[u] = dep[v]\),可以發現路徑上到 lca 肯定是符合要求的中點 mid,再往以 lca 這顆子樹往外走就步步重合了,也就是說都成立,如果 lca 子樹中到 u、v 的兩顆子樹的根節點分別為 fu、fv,則有:
\[size[1] - size[fu] - size[fv] \]
-
-
-
- 如果 \(dep[u] \not = dep[v]\),假設令 \(dep[u] > dep[v]\),肯定還是要找路徑上的中點 mid,但就一定不是 lca 了,但一定在較深的那顆子樹裡,所以還得從 u 往上跳一半的距離到 mid,顯然滿足的點肯定在 mid 延伸的子樹上,且不是 u、v 所在的子樹,這裡只用考慮減掉 u 所在的子樹大小就可以了,因為 v 肯定在 mid 往上走的路上,即:\[size[mid] - size[fu] \]
- 如果 \(dep[u] \not = dep[v]\),假設令 \(dep[u] > dep[v]\),肯定還是要找路徑上的中點 mid,但就一定不是 lca 了,但一定在較深的那顆子樹裡,所以還得從 u 往上跳一半的距離到 mid,顯然滿足的點肯定在 mid 延伸的子樹上,且不是 u、v 所在的子樹,這裡只用考慮減掉 u 所在的子樹大小就可以了,因為 v 肯定在 mid 往上走的路上,即:
-
一開始實現後兩部分討論,我只用了普通的暴力上跳,\(O(mn)\) 的複雜度就 T 了(太投入了,我就老是會把想到腦邊的做法先實現了再說)
code
#include <bits/stdc++.h>
#define re register int
using namespace std;
const int N = 1e5 + 10, logN = 50;
struct edge
{
int to, next;
}e[N << 1];
int top, h[N], dep[N], f[N][logN], size[N];
int n, m;
inline void add(int x, int y)
{
e[++ top] = (edge){y, h[x]};
h[x] = top;
}
void dfs(int u, int fa)
{
dep[u] = dep[fa] + 1;
f[u][0] = fa;
for (re i = 1; i <= log2(n); i ++)
f[u][i] = f[f[u][i - 1]][i - 1];
for (re i = h[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == fa) continue;
dfs(v, u);
size[u] += size[v];
}
}
inline int lca(int u, int v)
{
if (dep[u] < dep[v]) swap(u, v);
for (re i = log2(n); i >= 0; i --)
if (dep[f[u][i]] >= dep[v]) u = f[u][i];
if (u == v) return u;
for (re i = log2(n); i >= 0; i --)
if (f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
return f[u][0];
}
inline int work(int u, int v, int p)
{
if (u == v) return n;
int dist = dep[u] + dep[v] - 2 * dep[p];
if (dist % 2) return 0;
if (dep[u] == dep[v])
{
// return size[1] - size[p] + 1; 當 p = 1 時不成立
int fu = u, fv = v, du, dv;
du = dv = dist / 2 - 1;
while (du)
{
fu = f[fu][0];
du --;
}
while (dv)
{
fv = f[fv][0];
dv --;
}
return size[1] - size[fu] - size[fv];
}
else
{
if (dep[u] < dep[v]) swap(u, v);
int fu = u, mid;
dist = dist / 2 - 1;
while (dist)
{
fu = f[fu][0];
dist --;
}
mid = f[fu][0];
return size[mid] - size[fu];
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n;
for (re i = 1; i < n; i ++)
{
int x, y; cin >> x >> y;
add(x, y); add(y, x);
}
for (re i = 1; i <= n; i ++) size[i] = 1;
dfs(1, 0);
cin >> m;
while (m --)
{
int x, y; cin >> x >> y;
cout << work(x, y, lca(x, y)) << '\n';
}
return 0;
}
所以顯然是要倍增上跳就可以了。\(O(m\log n)\)
#include <bits/stdc++.h>
#define re register int
using namespace std;
const int N = 1e5 + 10, logN = 50;
struct edge
{
int to, next;
}e[N << 1];
int top, h[N], dep[N], f[N][logN], size[N];
int n, m;
inline void add(int x, int y)
{
e[++ top] = (edge){y, h[x]};
h[x] = top;
}
void dfs(int u, int fa)
{
dep[u] = dep[fa] + 1;
f[u][0] = fa;
for (re i = 1; i <= log2(n); i ++)
f[u][i] = f[f[u][i - 1]][i - 1];
for (re i = h[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == fa) continue;
dfs(v, u);
size[u] += size[v];
}
}
inline int lca(int u, int v, int type)
{
if (dep[u] < dep[v]) swap(u, v);
for (re i = log2(n); i >= 0; i --)
if (dep[f[u][i]] >= dep[v]) u = f[u][i];
if (u == v) return u;
for (re i = log2(n); i >= 0; i --)
if (f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
if (!type) return f[u][0];
else return size[1] - size[u] - size[v];
}
inline int work(int u, int v, int p)
{
if (u == v) return n;
int dist = dep[u] + dep[v] - 2 * dep[p];
if (dist % 2) return 0;
if (dep[u] == dep[v])
{
// return size[1] - size[p] + 1; 當 p = 1 時不成立
return lca(u, v, 1);
}
else
{
if (dep[u] < dep[v]) swap(u, v);
int step = dist / 2;
int cnt = size[u];
for(int i = log2(n); i >= 0; i--)
if(step - (1 << i) > 0)
{
u = f[u][i];
cnt += size[u] - cnt;
step -= (1 << i);
}
u = f[u][0];
return size[u] - cnt;
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n;
for (re i = 1; i < n; i ++)
{
int x, y; cin >> x >> y;
add(x, y); add(y, x);
}
for (re i = 1; i <= n; i ++) size[i] = 1;
dfs(1, 0);
cin >> m;
while (m --)
{
int x, y; cin >> x >> y;
cout << work(x, y, lca(x, y, 0)) << '\n';
}
return 0;
}