CF1618G Trader Problem 題解

DengStar發表於2024-04-13

CF1618G Trader Problem 題解

題目連結:CF|洛谷

提供一個線上做法。

分析1

我們不妨把 \(a\)\(b\) 合併為一個序列,稱合併後的序列為 \(c\),並將其不降序排序。把玩樣例後不難發現:對於一個物品序列 \(c_1, c_2, \cdots, c_l\),滿足 \(\forall i < l, c_{i+1} - c_i \le k\)(即任意兩個相鄰的物品都可以交換),要使最後的總價值最大,最優的方法顯然是把初始物品一直往後換。如果在這個序列中我們一開始有 \(cnt\) 個物品,那麼交換完之後我們的總價值就是 \(\sum_{i=l-cnt+1}^{l} c_i\),即最後 \(cnt\) 個物品的價值和。

分析2

上面的分析中,我們假設對於某個特定的 \(k\),序列滿足任意兩個相鄰的物品可以交換。但如果序列不滿足這個性質呢?不難想到,可以把序列斷開:對於所有滿足 \(c_{i+1} - c_i > k\)\(i\) (也就是說第 \(i\) 個物品不能與第 \(i+1\) 個物品交換),我們就讓 \(i\)\(i+1\) 分別屬於兩個鏈(子序列)。這樣對於每個條鏈,都滿足它任意兩個相鄰的物品可以交換。把所有鏈的答案加起來就是總答案。於是,對於每次詢問,我們都可以 \(O(n)\) 掃一遍解決。

單次詢問 \(O(n)\) 的程式碼如下:

// mark[i]: 第i個物品是否是一開始手裡擁有的
// cnt: 當前這條鏈中有多少個物品時一開始就有的
void solve(int k)
{
	int pos = 2, lst = 1, cnt = mark[1];
	ll ans = 0;
	for(; pos <= tot; pos++)
	{
		if(c[pos] - c[pos-1] > k) // 如果這兩個相鄰的物品不能交換,就斷開,統計上一段的貢獻
		{
			ans += sum[pos - 1] - sum[pos - cnt - 1];
			lst = pos, cnt = 0;
		}
		cnt += mark[pos];
	}
	cout << ans << endl;
}

分析3

上面我們一直預設 \(k\) 為定值,如果 \(k\) 不確定,有沒有什麼辦法能一次處理好所有可能的 \(k\) 對應的答案呢?

首先明確一個問題:雖然 \(k\) 有很多種取值(\(0 \le k \le 10^9\)),但並不是每個不同的 \(k\) 都對應不同的答案。可以這樣理解:假設現在我們已經對於某個特定的 \(k\) 把原序列斷成了幾條鏈,然後讓 \(k\) 不斷增大,只有當 \(k\) 可以“跨越”某兩條鏈之間的間隔時,這兩條鏈才會合併到一起,同時答案才可能發生改變。因此,實際上我們並不需要對所有的 \(k\) 都算一個答案。設所有的相鄰物品價值差的集合為 \(dif\),我們只需要讓 \(k\) 取遍 \(dif\) 中的每個值即可。將 \(dif\) 排序,設 \(k = dif_i\) 時的答案為 \(ans_i\)。對於每次詢問的 \(k\),二分找出一個 \(i\) 使得 \(dif_i \le k < dif_{i+1}\) ,答案便為 \(ans_i\)

接下來的做法便不難想到了。一開始令 \(k = 0\),這時每個物品都只能和與它價值相同的物品構成鏈,因為任意兩個價值不同的物品都不能交換。當 \(k \in [0, dif_1)\) 時,總不會有新的可以交換的物品對產生。只有當 \(k = dif_1\) 時,才可以進行合併。同理,我們繼續讓 \(k\) 取遍所有的 \(dif_i\),每次都合併所有可以合併的鏈,同時統計答案。

可以使用並查集維護鏈。同時維護每條鏈包含的初始物品數。合併時,新鏈的初始物品數即為原來兩條鏈的初始物品數之和。

時間複雜度

(因為 \(n\)\(m\) 數量級相同,下面把所有的 \(n + m\) 都當作 \(n\)。)

因為總共只會合併 \(n-1\) 次,所以統計答案的時間複雜度為 \(O(n)\)。而先前對 \(c\) 排序的時間複雜度為 \(O(n \log n)\),單次查詢的時間複雜度為 \(O(\log n)\),所以總的時間複雜度為 \(O(n \log n)\)

程式碼

// CF1618G Trader Problem
#include<bits/stdc++.h>

using namespace std;

typedef long long ll;
constexpr int MAXN = 2e5 + 10;
int n, m, q, p, a[MAXN], b[MAXN];
int tot, c[MAXN << 1], dif[MAXN << 1], cnt[MAXN << 1];
// cnt[i]: 以i結尾的鏈中的原始商品的數量 
ll sum[MAXN << 1], ans[MAXN << 1];
bool mark[MAXN << 1];
struct Node
{
	ll dif;
	int pos; // c[pos+1] - c[pos] = dif
	Node(int dif, int pos): dif(dif), pos(pos) {}
	bool operator > (const Node &rhs) const
	{
		return dif > rhs.dif;
	}
};
priority_queue<Node, vector<Node>, greater<Node>> que;

struct DSU
{
	int fa[MAXN << 1];
	void init()
	{
		for(int i = 1; i <= tot; i++)
		{
			fa[i] = i;
			cnt[i] = mark[i];
		}
	}
	int getfa(int u)
	{
		return fa[u] == u ? u : fa[u] = getfa(fa[u]);
	}
	void merge(int x, int y) // x < y
	{
		int fx = getfa(x), fy = getfa(y);
		fa[fx] = fy;
		cnt[fy] += cnt[fx], cnt[fx] = 0; 
	}
}dsu;

void merge() // 直接 sort 應該也行
{
	int i = 1, j = 1, k = 1;
	while(i <= n && j <= m)
	{
		if(a[i] <= b[j]) mark[k] = true, c[k++] = a[i++];
		else c[k++] = b[j++];
	}
	while(i <= n) mark[k] = true, c[k++] = a[i++];
	while(j <= m) c[k++] = b[j++];
	for(int i = 1; i <= tot; i++) sum[i] = sum[i-1] + c[i];
	for(int i = 1; i < tot; i++)
	{
		dif[i] = c[i+1] - c[i];
		que.push(Node(dif[i], i));
	}
	sort(dif + 1, dif + tot);
	dif[0] = unique(dif + 1, dif + tot) - dif - 1;
}

inline int getrk(int x)
{
	return lower_bound(dif + 1, dif + dif[0] + 1, x) - dif;
}

void solve()
{
	merge();
	dsu.init();
	for(int i = 1; i <= n; i++) ans[0] += a[i];
	while(!que.empty()) // 這裡用優先佇列維護 dif,實際上直接把 dif 排序再列舉應該也可以 
	{
		Node u = que.top();
		que.pop();
		int nowdif = u.dif, rk = getrk(nowdif);
		int ori = dsu.getfa(u.pos), tail = dsu.getfa(u.pos + 1); 
		if(!ans[rk]) ans[rk] = ans[rk - 1];
		// 因為沒有對 dif 去重,所以第一次遇到 dif 時要先取用上一個 dif 的 ans 
		ll d1 = (sum[ori] - sum[ori - cnt[ori]]) + (sum[tail] - sum[tail - cnt[tail]]);
		dsu.merge(u.pos, u.pos + 1);
		ll d2 = sum[tail] - sum[tail - cnt[tail]];
		ans[rk] = ans[rk] - d1 + d2;
		// 去除原先兩條鏈的貢獻,加上新鏈的貢獻 
	}
}

signed main()
{
	cin >> n >> m >> q;
	tot = n + m;
	for(int i = 1; i <= n; i++) cin >> a[i];
	for(int i = 1; i <= m; i++) cin >> b[i];
	sort(a + 1, a + n + 1), sort(b + 1, b + m + 1);
	solve();
	while(q--)
	{
		cin >> p;
		int rk = upper_bound(dif + 1, dif + dif[0] + 1, p) - dif - 1;
		cout << ans[rk] << endl;
	}
	return 0;
}

相關文章