gym103687D / QOJ3998 The Profiteer

dcytrl發表於2024-10-06

題意

\(n\) 個物品,和一個揹包容量上限 \(m\)。每個物品有價值 \(v_i\) 和體積 \(a_i\)

你需要選擇一段區間 \([l,r]\),將這個區間內的體積變為 \(b_i\),剩下的不變。然後你對這 \(n\) 個物品做揹包,設揹包容量結果為 \(f(i)\),需要求出有多少段區間使得 \(\dfrac{\sum_{i=1}^m f(i)}{m}\le E\)

\(n,k\le 2\times10^5,nk\le 10^7\)

分析

\(p_i\) 為最小的滿足 \([i,p_i]\) 合法的數。那麼答案就是 \(\sum_i n-p_i+1\)

首先,需要注意到 \(p_i\) 單調不降。暴力的話直接雙指標揹包即可,\(O(n^2k)\),飛了。

由於 \(p_i\) 滿足決策單調性那樣的性質,考慮套路性地分治,考慮設 \(\operatorname{solve}(l,r,L,R)\) 表示計算 \([l,r]\) 這段區間中的 \(p\)\(p\) 的取值範圍落在 \([L,R]\),且不在 \([l,r]\cup[L,R]\) 的物品已經被加入揹包。令 \(M=\lfloor\frac{L+R}{2}\rfloor\),考慮二分答案找到最大的滿足 \(p_i\le M\) 的下標,記作 \(m\),然後我們就把問題劃分成了 \(\operatorname{solve}(l,m,L,M),\operatorname{solve}(m+1,r,M+1,R)\) 兩個子問題,分別遞迴求解即可。

分析時間複雜度:若每次二分都暴力將 \([l,r]\cup[L,R]\) 內的物品加入,每一層中物品都要加入 \(O(n\log n)\) 次,分治一共 \(O(\log n)\) 層,每次加入物品的複雜度顯然 \(O(k)\),故複雜度 \(O(nk\log^2n)\),飛了。

考慮最佳化,發現實際上很多情況下物品都被重複加入了。考慮在二分前 \([L,R]\) 的取值(取 \(a_i\)\(b_i\))就已經確定了,提前將這些物品不在 \([l,r]\) 中的部分加入。考慮二分 \(mid\) 時實際上就是把 \([mid,r]\) 中的物品歸為 \(b_i\)\([l,mid)\) 歸為 \(a_i\),所以考慮在二分指標右移時(即 \(l=mid+1\))時 \([l,mid]\) 中的物品就永遠是 \(a_i\) 類的了,直接把這些物品加入揹包即可。二分指標左移同理。這樣物品加入次數就降為了 \(O(n)\),複雜度就是 \(O(nk\log n)\),看上去還是飛了但就是能過。

小細節:舉個例子,比如往左遞迴時需要將在 \([l,r]\cup[L,R]\) 但不在 \([l,m]\cup[L,M]\) 的物品加入,根據推導我們應該將 \([M+1,R]\) 劃給 \(a_i\),將 \([m+1,r]\) 劃給 \(b_i\),但這兩段區間可能會有交集,需要分類討論取哪一個。自行畫圖不難理解。

小細節 2:注意特殊處理一下 \(p_i>n\) 的情況,即 \([i,n]\) 不合法。

點選檢視程式碼
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<map>
#include<vector>
#include<queue>
#include<stack>
#include<bitset>
#include<set>
#include<ctime>
#include<random>
#include<cassert>
#define IOS ios::sync_with_stdio(false)
#define PY puts("Yes")
#define PN puts("No")
#define PW puts("-1")
#define P0 puts("0")
#define P__ puts("")
#define PU puts("--------------------")
#define mp make_pair
#define fi first
#define se second
#define pc putchar
#define pb emplace_back
#define un using namespace
#define popc __builtin_popcountll
#define all(x) x.begin(),x.end()
#define rep(a,b,c) for(int a=(b);a<=(c);++a)
#define per(a,b,c) for(int a=(b);a>=(c);--a)
#define reprange(a,b,c,d) for(int a=(b);a<=(c);a+=(d))
#define perrange(a,b,c,d) for(int a=(b);a>=(c);a-=(d))
#define graph(i,j,k,l) for(int i=k[j];i;i=l[i].nxt)
#define lowbit(x) (x&-x)
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
#define mem(x,y) memset(x,y,sizeof x)
//#define double long double
//#define int long long
//#define int __int128
using namespace std;
using i64=long long;
using u64=unsigned long long;
using pii=pair<int,int>;
inline int rd(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-48;ch=getchar();}return x*f;
}
template<typename T>
inline void write(T x,char ch='\0'){
	if(x<0){x=-x;putchar('-');}
	int y=0;char z[40];
	while(x||!y){z[y++]=x%10+48;x/=10;}
	while(y--)putchar(z[y]);if(ch!='\0')putchar(ch);
}
bool Mbg;
const int maxn=2e5+5,maxm=4e5+5,inf=0x3f3f3f3f;
const long long llinf=0x3f3f3f3f3f3f3f3f;
int n;
i64 m,E;
int val[maxn],wa[maxn],wb[maxn];
i64 ans;
vector<i64>f;
stack<vector<i64> >sta;
inline void add(int x,int typ){
	int w=typ?wb[x]:wa[x];
	per(i,m,w)f[i]=max(f[i],f[i-w]+val[x]);
}
inline bool ck(int l,int r,int mid){
	sta.emplace(f);
	rep(i,l,mid-1)add(i,0);
	rep(i,mid,r)add(i,1);
	i64 sum=0;
	rep(i,1,m)sum+=f[i];
	bool ok=sum<=m*E;
//	write(mid,32),write(r,32),write(sum,10);
	f=sta.top();sta.pop();
	if(ok){
		rep(i,l,mid)add(i,0);
	}else{
		rep(i,mid,r)add(i,1);
	}
	return ok;
}
inline void solve(int l,int r,int ll,int rr,bool valid=false){
//	write(l,32),write(r,32),write(ll,32),write(rr,10);
	if(l>r||ll>rr)return;
	if(ll==rr&&valid){
		ans+=1ll*(r-l+1)*(n-ll+1);
		return;
	}
	const int mm=(ll+rr)>>1;
	sta.emplace(f);
	int L=l,R=min(mm,r),res=L-1;
	rep(i,ll,mm)if(!(L<=i&&i<=R))add(i,1);
	rep(i,mm+1,rr)if(!(L<=i&&i<=R))add(i,0);
	while(L<=R){
		int mid=(L+R)>>1;
		if(ck(L,R,mid))res=mid,L=mid+1;
		else R=mid-1;
	}
	f=sta.top();
	rep(i,mm+1,rr)if(!(l<=i&&i<=res))add(i,0);
	rep(i,res+1,r)if(!(ll<=i&&i<=rr))add(i,1);
	solve(l,res,ll,mm,1);
	f=sta.top();sta.pop();
	rep(i,l,res)if(!(mm<i&&i<=rr))add(i,0);
	rep(i,ll,mm)if(!(l<=l&&i<=r))add(i,1);
	solve(res+1,r,mm+1,rr,0);
}
inline void solve_the_problem(){
	n=rd(),m=rd(),E=rd();
	rep(i,1,n)val[i]=rd(),wa[i]=rd(),wb[i]=rd();
	f.resize(m+1,0);
	solve(1,n,1,n);
	write(ans);
}
bool Med;
signed main(){
//	freopen(".in","r",stdin);freopen(".out","w",stdout);
//	fprintf(stderr,"%.3lfMB\n",(&Mbg-&Med)/1048576.0);
	int _=1;
	while(_--)solve_the_problem();
}
/*

*/