[Hackerrank University Codesprint 5] Sword profit (李超線段樹)

Fire_Raku發表於2024-07-03

[Hackerrank University Codesprint 5] Sword profit

李超線段樹

考慮大力推式子。寫出在第 \(i\) 所商店的第 \(k\) 把劍在第 \(j\) 所商店賣掉的價格。

\[\text{profit}=\max(0,q_i-(j-i)\cdot d_i-r_j)-(a_i+k\cdot b_i) \]

顯然利益一定要是正的才有價值,所以 \(\max\) 可以改到:

\[\text{profit}=\max(0,q_i-(j-i)\cdot d_i-r_j-(a_i+k\cdot b_i)) \]

小於 \(0\) 可以特判掉,先去掉 \(\max\),然後整理一下式子。

\[\text{profit}=q_i+i\cdot d_i-a_i-k\cdot b_i-(j\cdot b_i+r_i) \]

前面的部分是這把劍的固有貢獻,我們只需要將後面的部分最小化。容易看出後面的部分是一條 \(k=j\)\(b=r_i\) 的線段,定義域為 \([1,\max(b_i)]\)。而我們要求的就是 \(x=b_i\) 時所有線段的最小值。這是一個經典問題,可以用李超線段樹解決。

關於 \(k\),就是能獲得利益的最多的劍數。

從大到小列舉商店,每次先加入線段,再查詢即可。

複雜度 \(O(n\log^2n)\)

#include <bits/stdc++.h> 
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back

using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const i64 N = 3e5 + 10, mod = 1e9 + 7, inv = (mod + 1) / 2;
i64 n, ans, m, cnt;
i64 q[N], a[N], b[N], r[N], d[N];
struct line {
	i64 k, b;
} f[N];
int t[N << 2];
i64 calc(int id, i64 x) {
	if(!id) return linf;
	return 1LL * f[id].k * x + f[id].b;
}
void mdf(int u, int l, int r, int x) {
	int mid = (l + r) >> 1;
	bool bmid = (calc(x, mid) <= calc(t[u], mid));
	if(bmid) std::swap(t[u], x);
	bool bl = (calc(x, l) < calc(t[u], l)), br = (calc(x, r) < calc(t[u], r));
	if(bl) mdf(u << 1, l, mid, x);
	if(br) mdf(u << 1 | 1, mid + 1, r, x);
}
void upd(int u, int l, int r, int L, int R, int x) {
	if(L <= l && r <= R) {
		mdf(u, l, r, x);
		return;
	}
	int mid = (l + r) >> 1;
	if(L <= mid) upd(u << 1, l, mid, L, R, x);
	if(R > mid) upd(u << 1 | 1, mid + 1, r, L, R, x);
}
int mn(int x, int y, int z) {
	bool ret = (calc(x, z) <= calc(y, z));
	if(ret) return x;
	return y;
}
int qry(int u, int l, int r, int x) {
	if(l == r) return t[u];
	int mid = (l + r) >> 1;
	if(x <= mid) return mn(t[u], qry(u << 1, l, mid, x), x);
	else return mn(t[u], qry(u << 1 | 1, mid + 1, r, x), x);
}
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
	std::cin >> n;
	for(int i = 1; i <= n; i++) {
		std::cin >> q[i] >> a[i] >> b[i] >> r[i] >> d[i];
		m = std::max(m, d[i]);
	}

	for(i64 i = n; i >= 1; i--) {
		f[++cnt].k = i, f[cnt].b = r[i];
		upd(1, 1, m, 1, m, cnt);
		i64 p = qry(1, 1, m, d[i]);
		if(p) {
			i64 tot = q[i] + i * d[i] - a[i] - calc(p, d[i]), k = std::max(0LL, tot / b[i]) % mod; //特判
			if(k) ans = (ans + tot % mod * k % mod - 1LL * (1 + k) % mod * k % mod * inv % mod * b[i] % mod + mod) % mod;
		}
	}
	std::cout << ans << "\n";

	return 0;
}