SDOI2018 原題識別(主席樹)

WAautomaton發表於2019-02-14

題目連結

題目大意

給定nn個節點的樹,其中包含一條非隨機生成的長度為kk的鏈,剩下的節點均隨機父節點連邊。每個節點有一個隨機的顏色,維護:
1.給定x,yx,y,求x,yx,y之間不同顏色數。
2.給定x,yx,y,對於所有滿足分別在x,yx,y到根的路徑上的點a,ba,b,求其詢問1的答案之和。
n105,m2×105n\le 10^5,m\le 2\times 10^5

題解

碼量比較大qwq……
我們先從鏈上的情況入手考慮。

鏈的情況

對於第一問,這是經典二維數點題。考慮pip_i表示ii之前第一個和它顏色相同的位置。我們以(i,pi)(i,p_i)為座標建點,詢問不同顏色數就相當於詢問xx座標位於[l,r][l,r]yy座標小於ll的點個數。直接主席樹維護即可。
對於第二問,我們考慮點ii對答案的貢獻。不妨設xyx\le y,我們分三種情況討論:
1.x<iyx<i\le y,此時貢獻應該是[pix](xpi)(yi+1)[p_i\le x](x-p_i)(y-i+1)
2.1ix1\le i\le xaba\le b,此時貢獻應該是(ipi)(yi+1)(i-p_i)(y-i+1)
3.1ix1\le i\le xaba\ge b,此時貢獻應該是(ipi)(xi+1)(i-p_i)(x-i+1)
如果直接把三種答案加起來的話會發現2,3兩種情況中a=ba=b的部分算重了,減1即可。於是我們就需要維護上面的東西(2,3兩個情況其實可以合起來):
第一種是
i=x+1,pixy(xpi)(yi+1)=i=x+1,pixyx(y+1)pi(y+1)xi+pii\sum_{i=x+1,p_i\le x}^y (x-p_i)(y-i+1)\\ =\sum_{i=x+1,p_i\le x}^y x(y+1)-p_i(y+1)-xi+p_ii
這個東西可以通過主席樹維護四個值來計算:個數,ii的和,pip_i的和,piip_ii的和。
我們再來看第二種。
i=1x(ipi)(x+y+22i)=i=1x(x+y+2)(ipi)2i(ipi)\sum_{i=1}^x(i-p_i)(x+y+2-2i)\\ =\sum_{i=1}^x(x+y+2)(i-p_i)-2i(i-p_i)
這個東西沒有了對pip_i的限制條件,因此直接字首和維護即可。(當然如果你非要主席樹的話我也不能攔著qwq)
到此為止,鏈的情況被我們在O(nlogn)O(nlogn)的時間內做完了。

推廣到樹

注意到樹除了那條鏈其它都是隨機的,因此每個點到鏈距離的期望是O(logn)O(logn)的。每個顏色也是隨機的,因此每個顏色出現次數的期望是O(1)O(1)的。
也就是說,對於兩個點x,yx,y的LCA,記為ll,必有一個點到其距離為O(logn)O(logn)。不妨就設這個點為xx,考慮第一問怎麼做。
我們先計算出[l,y][l,y]中不同的顏色數(注意下面的區間都指的是一條鏈),這個可以直接主席樹。接下來做的事就是暴力列舉[x,l)[x,l)中的每個顏色,看看它是否在[l,y][l,y]中出現了,直接統計。判斷方法就是暴力列舉所有顏色和它相同的點即可。
因此第一問的複雜度也是O(nlogn)O(nlogn)的。
考慮第二問,我們可以劃分成如下三個子問題:
1.a[1,l),b[1,y]a\in [1,l),b\in [1,y]。這實際上就是鏈的情況,主席樹統計即可。
2.a[l,x],b[1,l)a\in [l,x],b\in [1,l)。這其實也是一條鏈,我們可以稍微轉化一下,先求出a[1,x],b[1,l)a\in [1,x],b\in [1,l)的答案,然後減去多算的。
多算的東西是2(ipi)(li)1\sum 2(i-p_i)(l-i)-1,直接字首和就能維護。
3.a[l,x],b[l,y]a\in [l,x],b\in [l,y]。這個情況很難算,我們也考慮分開計算貢獻。考慮存在於[l,y][l,y]中的點ii的貢獻為[pi<l](yi+1)(xl+1)[p_i<l](y-i+1)(x-l+1),主席樹維護即可。
再考慮存在於[l,x][l,x]中點ii的貢獻,首先它必須是所有與它顏色相同的點中第一個在[l,x][l,x]中出現的,它不能在[l,y][l,y]中包含和它顏色相同的點。不妨令jj[l,y][l,y]中第一個和它顏色相同的點(如果不存在則為y+1y+1),那麼其貢獻為[pi<l](xi+1)(jl)[p_i<l](x-i+1)(j-l)
暴力列舉點是O(logn)O(logn)的,找第一次出現時O(1)O(1)的,因此總複雜度還是O(nlogn)O(nlogn)的,只是常數比較大。

#include <bits/stdc++.h>
namespace IOStream {
	const int MAXR = 1 << 23;
	char _READ_[MAXR], _PRINT_[MAXR];
	int _READ_POS_, _PRINT_POS_, _READ_LEN_;
	inline char readc() {
	#ifndef ONLINE_JUDGE
		return getchar();
	#endif
		if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
		char c = _READ_[_READ_POS_++];
		if (_READ_POS_ == MAXR) _READ_POS_ = 0;
		if (_READ_POS_ > _READ_LEN_) return 0;
		return c;
	}
	template<typename T> inline void read(T &x) {
		x = 0; register int flag = 1, c;
		while (((c = readc()) < '0' || c > '9') && c != '-');
		if (c == '-') flag = -1; else x = c - '0';
		while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
		x *= flag;
	}
	template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
		read(a), read(x...);
	}
	inline int reads(char *s) {
		register int len = 0, c;
		while (isspace(c = readc()) || !c);
		s[len++] = c;
		while (!isspace(c = readc()) && c) s[len++] = c;
		s[len] = 0;
		return len;
	}
	inline void ioflush() {
		fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
		fflush(stdout);
	}
	inline void printc(char c) {
		_PRINT_[_PRINT_POS_++] = c;
		if (_PRINT_POS_ == MAXR) ioflush();
	}
	inline void prints(char *s) {
		for (int i = 0; s[i]; i++) printc(s[i]);
	}
	template<typename T> inline void print(T x, char c = '\n') {
		if (x < 0) printc('-'), x = -x;
		if (x) {
			static char sta[20];
			register int tp = 0;
			for (; x; x /= 10) sta[tp++] = x % 10 + '0';
			while (tp > 0) printc(sta[--tp]);
		} else printc('0');
		printc(c);
	}
	template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
		print(x, ' '), print(y...);
	}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
#define cls(a) memset(a, 0, sizeof(a))

const int MAXN = 200005, MAXT = 2000005;
struct Edge { int to, next; } edge[MAXN];
int dfn[MAXN], st[20][MAXN], head[MAXN], lg[MAXN], tot, n, m, K, T;
void addedge(int u, int v) {
	edge[++tot] = (Edge) { v, head[u] };
	head[u] = tot;
}
int lst[MAXN], col[MAXN], dep[MAXN], rt[MAXN], app[MAXN], ed[MAXN];
struct Value {
	ll sum1, sum2, sum3, sum4;
	Value() { sum1 = sum2 = sum3 = sum4 = 0; }
	Value& operator+=(const Value &v) {
		sum1 += v.sum1, sum2 += v.sum2, sum3 += v.sum3, sum4 += v.sum4;
		return *this;
	}
	Value& operator-=(const Value &v) {
		sum1 -= v.sum1, sum2 -= v.sum2, sum3 -= v.sum3, sum4 -= v.sum4;
		return *this;
	}
} nd[MAXT];
//sum1=1,sum2=p[i],sum3=i,sum4=p[i]*i
ll pre2[MAXN], pre3[MAXN]; int ptot;
//pre1=1,pre2=i-p[i],pre3=i(i-p[i])
int ls[MAXT], rs[MAXT], par[MAXN], vis[MAXN];

//presistence segment tree
int update(int p, int x, int y, int l = 0, int r = n) {
	int q = ++ptot; nd[q] = nd[p];
	++nd[q].sum1, nd[q].sum2 += y, nd[q].sum3 += x, nd[q].sum4 += (ll)x * y;
	if (l == r) return q;
	int mid = (l + r) >> 1;
	if (y <= mid) ls[q] = update(ls[p], x, y, l, mid), rs[q] = rs[p];
	else rs[q] = update(rs[p], x, y, mid + 1, r), ls[q] = ls[p];
	return q;
}
void query(Value &v, int p, int q, int a, int b, int l = 0, int r = n) {//x in (p,q],y in [a,b]
	if (a > r || b < l || p == q) return;
	if (a <= l && b >= r) { v += nd[q], v -= nd[p]; return; }
	int mid = (l + r) >> 1;
	query(v, ls[p], ls[q], a, b, l, mid);
	query(v, rs[p], rs[q], a, b, mid + 1, r);
}

vector<int> pla[MAXN];
void dfs(int u, int fa) {
	st[0][dfn[u] = ++tot] = u, dep[u] = dep[fa] + 1;
	pla[col[u]].push_back(u), lst[u] = app[col[u]];
	int t = app[col[u]]; app[col[u]] = u;
	pre2[u] = pre2[fa] + dep[u] - dep[lst[u]];
	pre3[u] = pre3[fa] + (ll)(dep[u] - dep[lst[u]]) * dep[u];
	rt[u] = update(rt[fa], dep[u], dep[lst[u]]);
	for (int i = head[u]; i; i = edge[i].next) {
		dfs(edge[i].to, u);
		st[0][++tot] = u;
	}
	app[col[u]] = t, ed[u] = tot;
}
int get_min(int x, int y) { return dep[x] < dep[y] ? x : y; }
int get_lca(int x, int y) {
	x = dfn[x], y = dfn[y];
	if (x > y) swap(x, y);
	int l = lg[y - x + 1];
	return get_min(st[l][x], st[l][y - (1 << l) + 1]);
}
int on_link(int x, int y, int p) {//x is ancestor of y
	return dfn[x] <= dfn[p] && ed[x] >= dfn[p] &&
		dfn[p] <= dfn[y] && ed[p] >= dfn[y];
}

int solve1(int x, int y) {
	++tot;
	int la = get_lca(x, K), lb = get_lca(y, K);
	if (la > lb) swap(la, lb), swap(x, y);
	int l = get_lca(x, y);
	Value v; query(v, rt[par[l]], rt[y], 0, dep[l] - 1);
	int res = v.sum1;
	for (int i = x; i != l; i = par[i]) if (vis[col[i]] != tot) {
		vis[col[i]] = tot;
		int flag = 1;
		for (int j : pla[col[i]])
			if (on_link(l, y, j)) { flag = 0; break; }
		res += flag;
	}
	return res;
}
ll calc_link(int x, int y, const Value &v) {//x is ancestor of y
	int a = dep[x], b = dep[y];
	ll res = (v.sum1 * a - v.sum2) * (b + 1) - v.sum3 * a + v.sum4;
	return res + (a + b + 2) * pre2[x] - 2 * pre3[x] - a;
}
ll solve2(int x, int y) {
	++tot;
	int la = get_lca(x, K), lb = get_lca(y, K);
	if (la > lb) swap(la, lb), swap(x, y);
	int l = get_lca(x, y), pl = par[l], dl = dep[l];
	Value v1, v2;
	query(v1, rt[pl], rt[y], 0, dl - 1);
	query(v2, rt[pl], rt[x], 0, dl - 1);
	ll res = calc_link(pl, y, v1) + calc_link(pl, x, v2) -
		2 * (dl * pre2[pl] - pre3[pl]) + dl - 1;
	res += ((dep[y] + 1) * v1.sum1 - v1.sum3) * (dep[x] - dl + 1);
	int tp = 0;
	for (int i = x; i != l; i = par[i]) app[++tp] = i;
	app[++tp] = l;
	while (tp > 0) {
		int i = app[tp--];
		if (vis[col[i]] != tot) {
			vis[col[i]] = tot;
			int mn = dep[y] + 1;
			for (int j : pla[col[i]])
				if (on_link(l, y, j) && mn > dep[j]) mn = dep[j];
			res += (ll)(mn - dl) * (dep[x] - dep[i] + 1);
		}
	}
	return res;
}

unsigned int SA, SB, SC;
unsigned int rng61(){
    SA ^= SA << 16;
    SA ^= SA >> 5;
    SA ^= SA << 1;
    unsigned int t = SA;
    SA = SB;
    SB = SC;
    SC ^= t ^ SA;
    return SC;
}
void gen(){
    read(n, K, SA, SB, SC);
    for(int i = 2; i <= K; i++) addedge(par[i] = i - 1, i);
    for(int i = K + 1; i <= n; i++)
        addedge(par[i] = rng61() % (i - 1) + 1, i);
    for(int i = 1; i <= n; i++) col[i] = rng61() % n + 1;
}
int main() {
	for (read(T); T--;) {
		tot = 0, cls(head), cls(vis), cls(app);
		gen();
		for (int i = 1; i <= n; i++) pla[i].clear();
		dfs(1, ptot = tot = 0);
		for (int i = 2; i <= tot; i++) lg[i] = lg[i >> 1] + 1;
		for (int i = 1; i < 20; i++)
		for (int j = 1; j + (1 << i) - 1 <= tot; j++)
			st[i][j] = get_min(st[i - 1][j], st[i - 1][j + (1 << i >> 1)]);
		tot = 0;
		for (read(m); m--;) {
			int a, b, c; read(a, b, c);
			if (a == 1) print(solve1(b, c));
			else print(solve2(b, c));
		}
	}
	ioflush();
	return 0;
}

相關文章