HNOI2018排列(堆+並查集)

WAautomaton發表於2019-03-02

題目連結

題目大意

給定數列a(0ain)a(0\le a_i\le n),設其排列後的數列為ap1,ap2,...,apna_{p_1},a_{p_2},...,a_{p_n},要求對於任意的kk,滿足1jk,apjpk1\le j\le k,a_{p_j}\neq p_k。某個合法排列的價值為wp1+2wp2+...+nwpnw_{p_1}+2w_{p_2}+...+nw_{p_n},求最大價值。
n5×105n\le 5\times 10^5

題解

神題qaq。
首先,好好思考題目中那個亂七八糟的條件是啥,也就是說讓=pk=p_kaa值出現在第kk個之後。我們考慮最後pp的排列方式,顯然對於任意aia_i,使=ai=a_ipp值出現在=i=ipp值之前。
於是我們就得到了若干個限制pp數列的關係,於是對於每個aia_i,從aia_iii連一條邊,得到了一個圖。如果這個圖中存在環的話顯然無解,否則這個圖就是以0為根的樹。
考慮ww最小的那個點,顯然如果它的father被選了,下一個選的必然是它。因此它倆必然在數列中連續,於是把它們用並查集合起來,計算貢獻。
然而這樣操作之後就變成了多個連通塊比較,我們如何選擇最小的連通塊呢?
考慮任意兩個連通塊a,ba,b,如果aabb之前更優,則必然有size[a]sum[b]+sum[a]+sum[b]>sum[a]size[b]+sum[a]+sum[b]size[a]\cdot sum[b]+sum[a]+sum[b]>sum[a]\cdot size[b]+sum[a]+sum[b],消一下就變成了size[a]sum[b]>size[b]sum[a]size[a]\cdot sum[b]>size[b]\cdot sum[a]
於是用優先佇列維護最小值就行了,複雜度O(nlogn)O(nlogn)

#include <bits/stdc++.h>
namespace IOStream {
	const int MAXR = 10000000;
	char _READ_[MAXR], _PRINT_[MAXR];
	int _READ_POS_, _PRINT_POS_, _READ_LEN_;
	inline char readc() {
	#ifndef ONLINE_JUDGE
		return getchar();
	#endif
		if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
		char c = _READ_[_READ_POS_++];
		if (_READ_POS_ == MAXR) _READ_POS_ = 0;
		if (_READ_POS_ > _READ_LEN_) return 0;
		return c;
	}
	template<typename T> inline void read(T &x) {
		x = 0; register int flag = 1, c;
		while (((c = readc()) < '0' || c > '9') && c != '-');
		if (c == '-') flag = -1; else x = c - '0';
		while ((c = readc()) >= '0' && c <= '9') x = x * 10 - '0' + c;
		x *= flag;
	}
	template<typename T1, typename ...T2> inline void read(T1 &a, T2&... x) {
		read(a), read(x...);
	}
	inline int reads(char *s) {
		register int len = 0, c;
		while (isspace(c = readc()) || !c);
		s[len++] = c;
		while (!isspace(c = readc()) && c) s[len++] = c;
		s[len] = 0;
		return len;
	}
	inline void ioflush() { fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0; fflush(stdout); }
	inline void printc(char c) {
		if (!c) return;
		_PRINT_[_PRINT_POS_++] = c;
		if (_PRINT_POS_ == MAXR) ioflush();
	}
	inline void prints(const char *s, char c) {
		for (int i = 0; s[i]; i++) printc(s[i]);
		printc(c);
	}
	template<typename T> inline void print(T x, char c = '\n') {
		if (x < 0) printc('-'), x = -x;
		if (x) {
			static char sta[20];
			register int tp = 0;
			for (; x; x /= 10) sta[tp++] = x % 10 + '0';
			while (tp > 0) printc(sta[--tp]);
		} else printc('0');
		printc(c);
	}
	template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
		print(x, ' '), print(y...);
	}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
typedef pair<int, int> P;

const int MAXN = 500005;
struct Node {
	ll sum; int sz, rt;
	bool operator<(const Node &nd) const { return sum * nd.sz > sz * nd.sum; }
};
priority_queue<Node> pq;
int ww[MAXN], par[MAXN], sz[MAXN], fa[MAXN], n;
ll sum[MAXN];
int find(int x) { return x == par[x] ? x : par[x] = find(par[x]); }
void merge(int x, int y) {
	x = find(x), y = find(y);
	if (x == y) return;
	par[x] = y, sz[y] += sz[x], sum[y] += sum[x];
}
int main() {
	read(n);
	for (int i = 0; i <= n; i++) par[i] = i;
	for (int i = 1; i <= n; i++) {
		int t; read(t); fa[i] = t;
		if (find(i) == find(t)) return puts("-1") * 0;
		par[find(i)] = find(t);
	}
	ll res = 0;
	for (int i = 1; i <= n; i++) {
		par[i] = i, sz[i] = 1; read(sum[i]);
		pq.push((Node) { sum[i], 1, i });
	}
	par[0] = 0, sz[0] = 1;
	for (int i = 1; i <= n; i++) {
		Node d = pq.top(); pq.pop();
		if (!d.rt || sz[d.rt] != d.sz) { --i; continue; }
		int p = find(fa[d.rt]);
		res += sum[d.rt] * sz[p];
		merge(d.rt, p);
		pq.push((Node) { sum[p], sz[p], p });
	}
	printf("%lld\n", res);
	return 0;
}

相關文章