[BJWC2010] 嚴格次小生成樹

神眷の櫻花發表於2021-05-29

題面

嚴格次小生成樹

題解

小藍書 + 我自己的補充

做法

題意很好理解吧。
設最小生成樹的邊權之和為 \(sum\)
我們要找嚴格次小生成樹,就是要找到這樣的一條非最小生成樹上的邊,滿足:

  • 將最小生成樹上的某條邊替換成這條邊後,樹依然聯通。
  • 這條邊與被替換邊的權值之差最小,且大於 \(0\)

所以我們進行如下操作:

  • 選擇一條非最小生成樹上的邊 \((x,y,z)\)
  • 將它加入樹中,顯然會形成一個環。
  • \(Kruskal\) 的證明過程我們可以得到,\(z \geq dis_{x,y}\)
  • 所以我們可以將 \(x\) - \(y\) 路徑上的某條邊替換成邊 \((x,y,z)\) ,顯然樹依然聯通。
  • \(x\) - \(y\) 路徑上的權值最大邊的邊權為 \(val_1\) ,次大邊的邊權為 \(val_2\)
  • 根據上述嚴格次小生成樹的找法的定義, 當 \(z > val_1\) 時 ,將 \(val_1\) 的這條邊替換成 \((x,y,z)\) 肯定是最優的,得到候選答案 \(sum - val_1 + z\)。當 \(z = val_1\) 時,將 \(val_2\) 替換成 \((x,y,z)\) 肯定是最優的,因為 \(val_1\) 所在的邊不能被替換,得到候選答案 \(sum - val_2 + z\)
  • 對所有的非樹邊執行上述操作,記錄最小的範圍值,得到最終答案。

優化

顯然每次暴查 \(val_1,val_2\) 明顯會炸。
所以我們可以運用倍增的思想預處理出點 \(x\) 向上跳 \(2^k\) 次的路徑中的 \(val_1\)\(val_2\)
做法類似與 \(lca\) 倍增做法時維護祖先的做法,這到題我們在找 \(x\) - \(y\) 的路徑時也需要 \(lca\) ,所以這兩個我們也需要維護出來。
\(g[x][k][0/1]\) 表示點 \(x\),向上跳 \(2^k\) 次的路徑中的 \(val_1\)\(val_2\)

  • 初始化:

\[g[x][0][0] = w_{x,fa_x},g[x][0][1] = -INF(設為負無窮不為影響到其它值的維護) \]

  • \(Fa 表示 x的2^{k - 1}輩祖先\)

\[g[x][k][0] = max(g[x][k - 1][0],g[][k - 1][0] ) \]

\[g[x][k][1] = \left\{ \begin{array}{lcl} max(g[x][k - 1][1],g[Fa][k - 1][1])\ (g[x][k - 1][0] = g[Fa][k -1][0]) \\ max(g[x][k - 1][0],g[Fa][k - 1][1])\ (g[x][k - 1][0] < g[Fa][k -1][0]) \\ max(g[x][k - 1][1],g[Fa][k - 1][0])\ (g[x][k - 1][0] > g[Fa][k -1][0]) \\ \end{array} \right\}\]

轉移過程應該很好想,最大值相等從次小值裡找最小值,一個更大,從小的最大和大的次小裡找次小值。
注意要開 \(long\ long\)

程式碼

#include<cstdio>
#include<algorithm>
#include<queue>
#include<cmath>
#include<iostream>
#define LL long long

using namespace std;

const int N = 1e5 + 5;
const int M = 3e5 + 5;
const LL INF = 1e16;

int n,m,f[N][18],t;
LL ans = 0,g[N][18][2],res = INF;
struct E {
	int from,to,w; bool is;
	E () {
		is = false;
	}
	bool operator < (const E & x) const {
		return w < x.w;
	}
}e[M];

struct edge {
	int head[N],next[M],to[M],w[M],size;
	inline void add(int u,int v,int W) {
		next[++size] = head[u]; w[size] = W;
		to[size] = v; head[u] = size;
		next[++size] = head[v]; w[size] = W;
		to[size] = u; head[v] = size;
	} 
	inline LL Get_val_2(int y,int j) {
		int Fa = f[y][j - 1]; LL Ans;
		if(g[y][j - 1][0] == g[Fa][j - 1][0])
			Ans = max(g[y][j - 1][1],g[Fa][j - 1][1]);	
		else if(g[y][j - 1][0] < g[Fa][j - 1][0])
			Ans = max(g[y][j - 1][0],g[Fa][j - 1][1]);
		else Ans = max(g[y][j - 1][1],g[Fa][j - 1][0]);
		return Ans;
	}
	queue<int> q; int dep[N]; LL dis[N];
	void bfs(int s) {
		q.push(s); dep[s] = 1;
		while(!q.empty()) {
			int x = q.front(); q.pop();
			for(int i = head[x]; i; i = next[i]) {
				int y = to[i];
				if(dep[y]) continue;
				dep[y] = dep[x] + 1;
				dis[y] = dis[x] + w[i];
				f[y][0] = x;
				g[y][0][0] = w[i]; g[y][0][1] = -INF;
				for(int j = 1; j <= t; j++) {
					f[y][j] = f[f[y][j - 1]][j - 1];
					g[y][j][0] = max(g[y][j - 1][0],g[f[y][j - 1]][j - 1][0]);
					g[y][j][1] = max(g[y][j][1],Get_val_2(y,j));
				}
				q.push(y);
			}	
		}
	} 
	inline LL Get(int x,int y,int w) {
		if(dep[x] > dep[y]) swap(x,y);
		LL val_1 = 0,val_2 = 0;
		for(int i = t; i >= 0; i--)
			if(dep[f[y][i]] >= dep[x]) {
				val_1 = max(val_1,g[y][i][0]);
				if(i > 0) val_2 = max(val_2,Get_val_2(y,i));
				y = f[y][i];
			}
		if(x == y) {
			if(w >  val_1) return ans - val_1 + w;
			if(w == val_1) return ans - val_2 + w; 
		}
		for(int i = t; i >= 0; i--)
			if(f[x][i] != f[y][i]) {
				val_1 = max(val_1,g[y][i][0]);
				if(i > 0) val_2 = max(val_2,Get_val_2(y,i));
				val_1 = max(val_1,g[x][i][0]);
				if(i > 0) val_2 = max(val_2,Get_val_2(x,i));
				x = f[x][i],y = f[y][i];
			}
		val_1 = max(val_1,g[x][0][0]);
		val_1 = max(val_1,g[y][0][0]);
		val_2 = max(val_2,max(g[x][0][0],g[y][0][0]));
		if(w >  val_1) return ans - val_1 + w;
		if(w == val_1) return ans - val_2 + w; 
	}
}a;

int fa[N];
inline int Find(int x) {
	return fa[x] == x ? x : fa[x] = Find(fa[x]);
}

inline int read() {
	int x = 0,flag = 1;
	char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-')flag = -1;ch = getchar();}
	while(ch >='0' && ch <='9'){x = (x << 3) + (x << 1) + ch - 48;ch = getchar();}
	return x * flag;
}
int main() {
	n = read(),m = read();
	t = log(1.0 * n) / log(2.0);
	for(int i = 1; i <= m; i++) {
		e[i].from = read();
		e[i].to = read();
		e[i].w = read();
	}
	for(int i = 1; i <= n; i++) fa[i] = i;
	sort(e + 1,e + 1 + m); a.size = 1;
	for(int i = 1; i <= m; i++) {
		int x = Find(e[i].from);
		int y = Find(e[i].to);
		if(x == y) continue;
		fa[x] = y; ans += e[i].w;
		a.add(e[i].from,e[i].to,e[i].w); e[i].is = true;
	}
	a.bfs(1);
	for(int i = 1; i <= m; i++) {
		if(e[i].is) continue;
		res = min(res,a.Get(e[i].from,e[i].to,e[i].w));
	}
	printf("%lld\n",res);
	return 0;
}

相關文章