題意
給定一顆樹,每個點有權值 \(1\) 和 \(-1\),稱一條路徑是好的當且僅當路徑上所有點的權值和為 \(0\)。
求連續編號區間 \([l, r]\) 使得兩個點都在 \([l, r]\) 的好路徑比兩個點都不在 \([l, r]\) 的好路徑數嚴格多的方案數。
\(n \le 10 ^ 5\)。
Sol
兩個端點都在區間內不好做,設一個區間的權值為 \(f_{[l, r]}\)。
因此答案為 \(\sum [f_{[l, r]} > f_{[1, l - 1] \cup [r + 1, n]}]\)。
集中注意力,考慮至少一個端點在區間內的情況,發現好像兩邊可以約掉!
具體地,至少一個端點在 \([l, r]\) 的方案數 等於 \(f_{[l, r]}\) 加上 有一個端點在 \([l, r]\) 一個端點在 \([1, l- 1] \cup [r + 1, n]\) 的方案數,於是直接約掉了。
考慮我們現在可以求出什麼,設 \(g_i\) 表示一個端點為 \(i\) 的合法路徑數,這個東西可以簡單使用點分治求得。
最後因為合法區間具有單調性,直接雙指標計算最小的合法區間即可。
複雜度 \(O(n \log n)\)。
Code
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <bitset>
#define ll long long
#define pii pair <int, int>
using namespace std;
#ifdef ONLINE_JUDGE
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;
#endif
int read() {
int p = 0, flg = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') flg = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
p = p * 10 + c - '0';
c = getchar();
}
return p * flg;
}
void write(ll x) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x > 9) {
write(x / 10);
}
putchar(x % 10 + '0');
}
bool _stmer;
#define fi first
#define se second
const int N = 1e5 + 5, M = 2e5 + 5;
namespace G {
array <int, N> fir;
array <int, M> nex, to;
int cnt = 1;
void add(int x, int y) {
cnt++;
nex[cnt] = fir[x];
to[cnt] = y;
fir[x] = cnt;
}
} //namespace G
array <int, N> len, siz;
bitset <N> vis;
void dfs1(int x, int fa) {
siz[x] = 1;
for (int i = G::fir[x]; i; i = G::nex[i]) {
if (vis[G::to[i]] || G::to[i] == fa) continue;
dfs1(G::to[i], x), siz[x] += siz[G::to[i]];
}
}
pii rt;
void dfs2(int x, int fa, int Rt) {
int tp = 0;
for (int i = G::fir[x]; i; i = G::nex[i]) {
if (vis[G::to[i]] || G::to[i] == fa) continue;
dfs2(G::to[i], x, Rt), tp = max(tp, siz[G::to[i]]);
}
tp = max(tp, siz[Rt] - siz[x]);
if (tp < rt.fi) rt = make_pair(tp, x);
}
array <int, M> isl;
array <int, N> dis;
void dfs3(int x, int pl, int fa) {
isl[dis[x]] += pl;
for (int i = G::fir[x]; i; i = G::nex[i]) {
if (vis[G::to[i]] || G::to[i] == fa) continue;
dis[G::to[i]] = dis[x] + len[G::to[i]];
dfs3(G::to[i], pl, x);
}
}
array <int, N> f;
void dfs4(int x, int fa, int sum) {
f[x] += isl[1e5 - sum];
for (int i = G::fir[x]; i; i = G::nex[i]) {
if (vis[G::to[i]] || G::to[i] == fa) continue;
dfs4(G::to[i], x, sum + len[G::to[i]]);
}
}
void solve(int x) {
dis[x] = 1e5 + len[x];
dfs3(x, 1, 0);
f[x] += isl[1e5];
for (int i = G::fir[x]; i; i = G::nex[i])
if (!vis[G::to[i]])
dfs3(G::to[i], -1, x), dfs4(G::to[i], x, len[G::to[i]]), dfs3(G::to[i], 1, x);
dfs3(x, -1, 0);
}
void divide(int x) {
rt = make_pair(2e9, 0);
dfs1(x, 0), dfs2(x, 0, x);
x = rt.se, vis[x] = 1, solve(x);
for (int i = G::fir[x]; i; i = G::nex[i])
if (!vis[G::to[i]]) divide(G::to[i]);
}
bool _edmer;
int main() {
cerr << (&_stmer - &_edmer) / 1024.0 / 1024.0 << "MB\n";
int n = read();
for (int i = 1, x; i <= n; i++)
x = read(), len[i] = x ? 1 : -1;
for (int i = 2, x, y; i <= n; i++)
x = read(), y = read(), G::add(x, y), G::add(y, x);
divide(1);
ll res = 0, sum = 0, ans = 0;
for (int i = 1; i <= n; i++) res += f[i];
for (int i = 1, lst = 1; i <= n; i++) {
sum += f[i];
while (lst < i && sum > res - sum) sum -= f[lst++];
ans += lst - 1;
}
write(ans), puts("");
return 0;
}