個人線段樹寫法 & 注意逝項

Fun_Strawberry發表於2024-06-13

前言

眾所周知由於一些原因,我們有時候需要寫一些維護較多東西的線段樹,如 P4513 小白逛公園 這種。

這個過程中,不妙的實現(比如說像某位李姓,名字最後一個字是木字旁的性感同學的常見實現),比如隨意多開線段樹,大量使用 if,大量複製貼上來完成的,難寫難調,很容易爆炸。

那麼相反的,合理的實現,就會有清晰的架構,簡易的程式碼。

起因是最近寫一道樹剖題目,其中的線段樹類似小白逛公園類,但是有時需要區間取反以及合併答案時需要整體翻轉,涵蓋範圍相對全。

舉例

具體問題:樹上,點權為 \(1\) 或者 \(-1\),操作包含:路徑乘 \(-1\),路徑問長度減字首 \(\max\) 加字尾 \(\min\)

分析:類似小白逛公園,由於區間乘 \(-1\) 的操作,我還需要維護字首 \(\min\) 和字尾 \(\max\)

另外為了合併兩個區間,需要維護區間 \(\text{sum}\)

分析一下,需要複雜合併 / 下傳 / 維護答案。

採用如下實現方式:

  1. node 開節點,原因見下一條:

  2. node operator + 這樣 build()change() 裡面的 pushup(),合併答案都可以解決。

  3. node chg(node x)pushdown() 內對兩個子節點分別的操作,和 change() 返回前的修改操作統一。

  4. node rev(node x) 樹上的問題中,鏈拼起來可能需要翻轉順序,這個就適用。

放個示例看看。

#include<bits/stdc++.h>
using namespace std;
struct node{
	long long l,r,len,le,lx,ln,rx,rn,tg;
}t[800110];
node operator + (const node &x,const node &y)
{
	return {x.l,y.r,x.len+y.len,x.le+y.le,max(x.lx,x.len+y.lx),min(x.ln,x.len+y.ln),max(y.rx,y.len+x.rx),min(y.rn,y.len+x.rn),0};
}
long long a[100050];
#define mid ((t[o].l+t[o].r)>>1)
#define ls (o<<1)
#define rs ((o<<1)^1)
node chg(node x)
{
	return {x.l,x.r,-x.len,x.le,-x.ln,-x.lx,-x.rn,-x.rx,x.tg^1};
}
node rev(node x)
{
	return {x.l,x.r,x.len,x.le,x.rx,x.rn,x.lx,x.ln,0};
}
void build(int l,int r,int o)
{
	t[o].l=l,t[o].r=r;
	if(l==r)
	{
		if(a[l]==1) t[o].lx=t[o].rx=t[o].len=1;
		else t[o].ln=t[o].rn=t[o].len=-1;
		t[o].le=1;
		return;
	}
	build(l,mid,ls);
	build(mid+1,r,rs);
	t[o]=t[ls]+t[rs];
}
void pushdown(int o)	//why not spread lol
{
	if(t[o].tg)
	{
		t[o].tg=0;
		t[ls]=chg(t[ls]);
		t[rs]=chg(t[rs]);
	}
}
void change(int l,int r,int o)
{
	if(l<=t[o].l&&t[o].r<=r)
	{
		t[o]=chg(t[o]);
		return;
	}
	pushdown(o);
	if(l<=mid) change(l,r,ls);
	if(r>mid) change(l,r,rs);
	t[o]=t[ls]+t[rs];
}
node ask(int l,int r,int o)
{
	if(l<=t[o].l&&t[o].r<=r) return t[o];
	pushdown(o);
	if(l<=mid&&r>mid) return ask(l,r,ls)+ask(l,r,rs);
	if(l<=mid) return ask(l,r,ls);
	if(r>mid) return ask(l,r,rs);
}

vector<int> e[114514];
int fa[114514],si[114514],son[114514],dep[114514],ms[114514],dfn[114514],cnt,da[114514],top[114514];
void dfs1(int u)
{
	si[u]=1;
	for(auto v:e[u]) if(v!=fa[u])
	{
		fa[v]=u;
		dep[v]=dep[u]+1;
		dfs1(v);
		si[u]+=si[v];
		if(si[v]>ms[u])
		{
			ms[u]=si[v];
			son[u]=v;
		}
	}
}
void dfs2(int u,int topf)
{
	cnt++;
	dfn[u]=cnt;
	a[cnt]=da[u];
	top[u]=topf;
	if(!son[u]) return;
	dfs2(son[u],topf);
	for(auto v:e[u])
	if(v!=fa[u]&&v!=son[u])
		dfs2(v,v);
}

int n,m,i,j,r,u,v,w,x,y,z,in;
signed main()
{
	freopen("loser.in","r",stdin);
	freopen("loser.out","w",stdout);
	cin>>n>>m;r=1;
	for(i=1;i<n;i++)
	{
		cin>>u>>v;
		e[u].push_back({v});e[v].push_back({u});
	}
	
	for(i=1;i<=n;i++)
	cin>>da[i],da[i]=(da[i]?1:-1);
	memset(ms,-1,sizeof(ms));
	dfs1(r);dfs2(r,r);
	build(1,n,1);
	
	for(i=1;i<=m;i++)
	{
		cin>>in>>x>>y;
		if(in==2)
		{
			node ansx,ansy;
			ansx=ansy={0,0,0,0,0,0,0,0};
			int rv=0;
			while(top[x]!=top[y])
			{
				if(dep[top[x]]<dep[top[y]]) swap(x,y),swap(ansx,ansy),rv^=1;
				ansx=ask(dfn[top[x]],dfn[x],1)+ansx;
				x=fa[top[x]];
			}
			if(dep[x]>dep[y]) swap(x,y),swap(ansx,ansy),rv^=1;
			ansx=(rev(ansx)+ask(dfn[x],dfn[y],1))+ansy;
			if(rv) ansx=rev(ansx);
			cout<<ansx.le-ansx.lx+ansx.rn<<endl;
		}
		if(in==1)
		{
			while(top[x]!=top[y])
			{
				if(dep[top[x]]<dep[top[y]]) swap(x,y);
				change(dfn[top[x]],dfn[x],1);
				x=fa[top[x]];
			}
			if(dep[x]>dep[y]) swap(x,y);
			change(dfn[x],dfn[y],1);
		}
	}
	return 0;
}