[賽記] NOIP2024加賽8

Peppa_Even_Pig發表於2024-11-28

大抵是NOIP前寫的最後一篇題解了吧。。。

flandre 80pts

賽時打的錯解A了,然後證偽以後寫了個更錯的錯解80pts;

考慮我們最終要求的答案是 $ a $ 陣列從小到大排序後的一個字尾

考慮怎樣證明這個結論,感性理解一下就是儘量選大的然後挺對;

考慮比較嚴謹的證明

如果序列中沒有重複的元素,那麼正確性顯而易見;

考慮如果有重複的元素怎麼辦;

假設現在我們選了一個可重集合 $ A = { x, y, y, y, z } $,其中元素遞增給出,首先我們會選 $ x $,因為它最小,然後考慮加 $ y $ 的貢獻,不難發現,無論加多少 $ y $,它和加 $ x $ 的貢獻是一樣的,都是對原序列加了兩個可以 $ +k $ 的點對,所以 $ y $ 在選 $ x $ 的前提下是越多越好的(要不就不選);

證畢;

然後上個 $ BIT $ 維護一下前面有多少個比它大的即可,時間複雜度 $ \Theta(n \log V) $,其中 $ V $ 為值域;

點選檢視程式碼
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
long long n, k;
struct sss{
	long long a;
	int id;
	bool operator <(const sss &A) const {
		return a < A.a;
	}
}e[1000005];
vector<int> v;
namespace BIT{
	inline int lowbit(int x) {
		return x & (-x);
	}
	int tr[3000005];
	inline void add(int pos, int d) {
		for (int i = pos; i <= 3000000; i += lowbit(i)) tr[i] += d;
	}
	inline int ask(int d) {
		int ans = 0;
		for (int i = d; i; i -= lowbit(i)) ans += tr[i];
		return ans;
	}
}
int main() {
	freopen("flandre.in", "r", stdin);
	freopen("flandre.out", "w", stdout);
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin >> n >> k;
	for (int i = 1; i <= n; i++) {
		cin >> e[i].a;
		e[i].id = i;
	}
	sort(e + 1, e + 1 + n);
	long long ans = 0, sum = 0, now = 0;
	for (int i = n; i >= 1; i--) {
		now = now + k * (BIT::ask(3000000) - BIT::ask(e[i].a + 2000000)) + e[i].a;
		if (now > ans) {
			ans = now;
			sum = i;
		}
		BIT::add(e[i].a + 2000000, 1);
	}
	if (!sum) {
		cout << 0 << ' ' << 0;
		return 0;
	}
	cout << ans << ' ' << n - sum + 1 << '\n';
	for (int i = sum; i <= n; i++) cout << e[i].id << ' ';
	return 0;
}

meirin 100pts

發現一個式子有四個 $ \sum $,然後直接做不太好做,所以考慮貢獻

這個題的 $ a $ 陣列始終不變,所以我們可以從這裡入手;

發現我們最終求的是 $ (\sum_{i = l}^{r} a_i) \times (\sum_{i = l}^{r} b_i) $,而 $ a $ 始終不變,所以我們可以嘗試維護每個 $ b_i $ 的係數

設 $ val_i $ 表示 $ b_i $ 的係數;

先不考慮修改,那麼考慮一個 $ b_i $ 的係數 $ val_i $ 是:

\[\sum_{r = i}^{n} \sum_{j = i}^{r} a_j + val_{i - 1} - \sum_{l = 1}^{i - 1} \sum_{j = l}^{i - 1} a_j \]

簡單來講就是上一個的係數減去不包含這個點的係數(右端點是 $ i - 1 $ )再加上這個點新的係數(左端點是 $ i $ );

發現前面兩個 $ \sum $ 可以倒著掃一遍 $ \Theta(n) $ 處理,後面兩個 $ \sum $ 可以正著掃一遍 $ \Theta(n) $ 處理,這樣我們就得到了每個點的係數;

修改直接用這個區間的係數和乘要加的值然後和原來的答案相加即可,這個可以字首和處理;

時間複雜度:$ \Theta(n + m) $;

點選檢視程式碼
#include <iostream>
#include <cstdio>
using namespace std;
const long long mod = 1000000007;
int n, q;
long long a[500005], b[500005], sum[500005], p[500005], val[500005], su[500005];
int main() {
	freopen("meirin.in", "r", stdin);
	freopen("meirin.out", "w", stdout);
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin >> n >> q;
	for (int i = 1; i <= n; i++) {
		cin >> a[i];
	}
	for (int i = 1; i <= n; i++) {
		cin >> b[i];
	}
	sum[n] = a[n];
	sum[n] = (sum[n] + mod) % mod;
	for (int i = n - 1; i >= 1; i--) {
		sum[i] = (sum[i + 1] + (a[i] * (n - (i + 1) + 1) % mod + a[i]) % mod) % mod;
		sum[i] = (sum[i] + mod) % mod;
	}
	p[1] = a[1];
	p[1] = (p[1] + mod) % mod;
	for (int i = 2; i <= n; i++) {
		p[i] = (p[i - 1] + ((a[i] * (i - 1)) % mod + a[i]) % mod) % mod;
		p[i] = (p[i] + mod) % mod;
	}
	val[1] = sum[1];
	long long ans = ((val[1] * b[1]) % mod + mod) % mod;
	for (int i = 2; i <= n; i++) {
		val[i] = ((sum[i] + val[i - 1]) % mod - p[i - 1]) % mod;
		val[i] = (val[i] + mod) % mod;
		ans = (ans + (val[i] * b[i]) % mod + mod) % mod;
	}
	for (int i = 1; i <= n; i++) {
		su[i] = (su[i - 1] + val[i]) % mod;
	}
	int l, r;
	long long k;
	for (int i = 1; i <= q; i++) {
		cin >> l >> r;
		cin >> k;
		long long now = ((su[r] - su[l - 1] + mod) % mod + mod) % mod;
		ans = (ans + (now * k) % mod + mod) % mod;
		cout << ans << '\n';
	}
	return 0;
}

sakuya 15pts

這種題有個套路:統計邊被經過的次數

欽定 $ 1 $ 為根;

設 $ f_{x} $ 表示 $ x $ 子樹內重要點的個數,這個容易求出;

考慮一條邊 $ (x, u), fa_u = x $,它的經過次數為 $ f_u \times (f_1 - f_u) \times 2 \times (m - 1)! $;

其中 $ \times 2 $ 是因為原序列有序, $ \times (m - 1)! $ 可以理解為將這兩個點綁一起做個全排;

然後就得到了每條邊被經過的次數,這個可以 $ DFS $ 求出,順便也求出了初始答案;

考慮修改,發現修改只是將與一個點相連的所有邊的邊權 $ +k $,所以我們再開一個陣列 $ a_x $ 表示與 $ x $ 這個點相連的邊的經過次數之和,然後答案直接加 $ a_x \times k $ 即可;

最後不要忘了除以總方案數 $ m! $,因為求的是期望;

時間複雜度:$ \Theta(n) $;

點選檢視程式碼
#include <iostream>
#include <cstdio>
using namespace std;
#define int long long
const long long mod = 998244353;
int n, m, q;
struct sss{
	int t, ne;
	long long w;
}e[1000005];
int h[1000005], cnt;
void add(int u, int v, long long ww) {
	e[++cnt].t = v;
	e[cnt].ne = h[u];
	h[u] = cnt;
	e[cnt].w = ww;
}
int f[500005];
long long a[500005], ans, inv, val;
bool vis[500005];
void afs(int x, int fa) {
	if (vis[x]) f[x] = 1;
	for (int i = h[x]; i; i = e[i].ne) {
		int u = e[i].t;
		if (u == fa) continue;
		afs(u, x);
		f[x] += f[u];
	}
}
void dfs(int x, int fa) {
	for (int i = h[x]; i; i = e[i].ne) {
		int u = e[i].t;
		if (u == fa) continue;
		dfs(u, x);
		a[x] = (a[x] + f[u] * (f[1] - f[u]) % mod * 2 % mod * val % mod) % mod;
		a[u] = (a[u] + f[u] * (f[1] - f[u]) % mod * 2 % mod * val % mod) % mod;
		ans = (ans + f[u] * (f[1] - f[u]) % mod * 2 % mod * val % mod * e[i].w % mod) % mod;
	}
}
long long ksm(long long a, long long b) {
	long long ans = 1;
	while(b) {
		if (b & 1) ans = ans * a % mod;
		a = a * a % mod;
		b >>= 1;
	}
	return ans;
}
signed main() {
	freopen("sakuya.in", "r", stdin);
	freopen("sakuya.out", "w", stdout);
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin >> n >> m;
	int x, y;
	long long w;
	for (int i = 1; i <= n - 1; i++) {
		cin >> x >> y;
		cin >> w;
		add(x, y, w);
		add(y, x, w);
	}
	for (int i = 1; i <= m; i++) {
		cin >> x;
		vis[x] = true;
	}
	val = 1;
	for (int i = 2; i <= m - 1; i++) val = val * i % mod;
	inv = val * m % mod;
	inv = ksm(inv, mod - 2);
	afs(1, 0);
	dfs(1, 0);
	cin >> q;
	long long k;
	for (int i = 1; i <= q; i++) {
		cin >> x >> k;
		ans = (ans + a[x] * k % mod) % mod;
		cout << ans * inv % mod << '\n';
	}
	return 0;
}

紅樓 ~ Eastern Dream 65pts

賽時沒寫正解,一是不會,二是寫出來常數也大,800ms很難卡過去

看見部分分有 $ x $ 小和 $ x $ 大的情況,於是考慮根號分治

對於 $ x \leq \sqrt n $ 的情況,我們開一個陣列 $ f_{i, j} $ 表示除以 $ i $ 餘數 $ \leq j $ 的增加量,修改直接 $ \Theta(\sqrt n) $ 暴改,查詢考慮 $ \Theta(\sqrt n) $ 遍歷每個 $ i $,然後統計這個區間有多少整塊(長度為 $ i $ ),然後兩邊的散塊 $ \Theta(1) $ 求即可;

對於 $ x > \sqrt n $ 的情況,我們發現加的區間一共不超過 $ \Theta(\sqrt n) $ 個,所以我們考慮 $ \Theta(\sqrt n) $ 的區間加和查詢;

這個不能套線段樹做,對於區間加,我們可以想到差分,考慮樹狀陣列區間修改區間查詢的做法,我們也可以維護一個差分陣列 $ c_i $;

然後類似樹狀陣列,我們維護 $ c_i, c_i \times i $,然後對於一個到 $ r $ 的字首和我們直接 $ (r + 1) \times \sum_{i = 1}^{r} c_i - \sum_{i = 1}^{r} c_i \times i $ 即可;

對於 $ (r + 1) \times \sum_{i = 1}^{r} c_i - \sum_{i = 1}^{r} c_i \times i $ 這個式子,考慮我們求的是 $ \sum_{i = 1}^{r} \sum_{j = 1}^{i} c_j $,然後考慮每個 $ c_j $ 只會加 $ r - j + 1 $ 次,提出 $ r + 1 $ 即可得到這個式子;

然後我們修改時直接暴力 $ \Theta(\sqrt n) $ 遍歷所有加的區間進行上述差分陣列的維護,查詢時直接 $ \Theta(\sqrt n) $ 遍歷到 $ l $ 的以及到 $ r $ 的塊和散塊即可解決這個問題,

時間複雜度:$ \Theta(n \sqrt n) $,常數較大,用了CuFeO4的快讀快寫;

點選檢視程式碼
#include <iostream>
#include <cstdio>
#include <cmath>
#include<bits/stdc++.h>
using namespace std;
namespace IO{
	struct IO{
		char buf[1<<16],*p1,*p2;
		char pbuf[1<<16], *pp = pbuf;
		#define gc() ((p1==p2&&(p2=((p1=buf)+fread_unlocked(buf,1,1<<16,stdin)),p1==p2))?EOF:*p1++)
		#define pc putchar_unlocked
		template<class T>
		inline void read(T &x){
			x = 0;bool flag = false;char s = gc();
			for(;s < '0' || '9' < s;s = gc());
			for(;'0' <= s && s <= '9';s = gc()) x = (x<<1)+(x<<3)+(s^48);
		}
		template<class T,class ...Args>
		inline void read(T &x,Args&... argc){read(x);read(argc...);}
		template <class T>
		inline void write(T x) {
			static int sta[30],top = 0;
			do{sta[top++] = x % 10, x /= 10;}while(x);
			while(top) pc(sta[--top] + '0');
		}
		inline void write(char x){pc(x);}
		template <class T,class... Args>
		inline void write(T x,Args... argc) {write(x);write(argc...);}
	}io;
	#define read io.read
	#define write io.write
}using namespace IO;
int n, m;
long long f[455][455], a[500005], sum[500005];
int sq;
int st[455], ed[455], belog[500005];
long long c[500005], ci[500005], sc[455], sci[455];
inline long long w(long long l, long long r) {
	long long ans = sum[r];
	if (l > 0) ans -= sum[l - 1];
	for (int i = 1; i <= sq; i++) {
		long long L = l + (i - 1 - (l % i) + 1);
		if (l % i == 0) L = l;
		long long R = r - (r % i) - 1;
		if (r % i == i - 1) R = r;
		if (L > R) {
			if ((r % i) - (l % i) == r - l) {
				ans += f[i][r % i];
				if ((l % i) > 0) ans -= f[i][(l % i) - 1];
			} else {
				ans += f[i][r % i];
				ans += f[i][i - 1];
				if ((l % i) > 0) ans -= f[i][(l % i) - 1];
			}
			continue;
		}
		ans += (R - L + 1) / i * f[i][i - 1];
		if (R != r) ans += f[i][r % i];
		if (L != l) ans += f[i][i - 1];
		if ((l % i) > 0 && L != l) ans -= f[i][(l % i) - 1];
	}
	return ans;
}
int main() {
	freopen("scarlet.in", "r", stdin);
	freopen("scarlet.out", "w", stdout);
	read(n, m);
	for (int i = 0; i < n; i++) {
		read(a[i]);
		sum[i] = sum[max(0, i - 1)] + a[i];
	}
	sq = sqrt(n);
	for (int i = 1; i <= sq; i++) {
		st[i] = (i - 1) * sq + 1;
		ed[i] = i * sq;
	}
	ed[sq] = n;
	for (int i = 1; i <= sq; i++) {
		for (int j = st[i]; j <= ed[i]; j++) {
			belog[j] = i;
		}
	}
	long long s, x, y, k;
	for (int i = 1; i <= m; i++) {
		read(s, x, y);
		if (s == 1) {
			read(k);
			if (x <= sq) {
				y = min(y, x - 1);
				for (int j = 0; j <= y; j++) {
					f[x][j] += (j + 1) * k;
				}
				for (int j = y + 1; j <= x - 1; j++) {
					f[x][j] += (y + 1) * k;
				}
			} else {
				y = min(y, x - 1);
				for (int j = 0; j <= n - 1; j += x) {
					c[j + 1] += k;
					if (j + 1 + y + 1 <= n) c[j + 1 + y + 1] -= k;
					ci[j + 1] += (j + 1) * k;
					if (j + 1 + y + 1 <= n) ci[j + 1 + y + 1] -= (j + 1 + y + 1) * k;
					sc[belog[j + 1]] += k;
					sci[belog[j + 1]] += (j + 1) * k;
					if (j + 1 + y + 1 <= n) sc[belog[j + 1 + y + 1]] -= k;
					if (j + 1 + y + 1 <= n) sci[belog[j + 1 + y + 1]] -= (j + 1 + y + 1) * k;
				}
			}
		} else {
			long long ans = w(x - 1, y - 1);
			long long sumc = 0, sumci = 0;
			for (int j = 1; j <= belog[x] - 1; j++) {
				sumc += sc[j];
				sumci += sci[j];
			}
			for (int j = st[belog[x]]; j <= x - 1; j++) {
				sumc += c[j];
				sumci += ci[j];
			}
			ans -= (1ll * x * sumc - sumci);
			for (int j = x; j <= min(y, 1ll * ed[belog[x]]); j++) {
				sumc += c[j];
				sumci += ci[j];
			}
			for (int j = belog[x] + 1; j <= belog[y] - 1; j++) {
				sumc += sc[j];
				sumci += sci[j];
			}
			if (belog[x] != belog[y]) {
				for (int j = st[belog[y]]; j <= y; j++) {
					sumc += c[j];
					sumci += ci[j];
				}
			}
			ans += (1ll * (y + 1) * sumc - sumci);
			write(ans,'\n');
		}
	}
	return 0;
}

相關文章