CF1442D Sum 分治 揹包dp

forever_shi發表於2020-11-18

題目連結

題意:
給你 n n n個單調不降序列,要從中取 k k k個數,每次可以取一個非空序列的第一個元素,並將這個元素從這個序列中刪去,問你取的 k k k個數得到的最大權值和是多少。

n , k < = 3000 , n,k<=3000, n,k<=3000,元素總數 < = 1 e 6 <=1e6 <=1e6

題解:
對現在的我來說還是一個挺神仙的題。我覺得我分治水平挺差的,每次見到分治題都會大呼神仙。

暴力做法是揹包dp併合並揹包,然後顯然是 O ( n k 2 ) O(nk^2) O(nk2)的。然後就沒有然後了,應該沒法優化了。

首先顯然一個序列 k k k個以後的元素都沒用了,我們就不管了。

這個題首先要發現這個單調不降的性質的用法:對於所有從中選過元素的序列,最終最多有一個序列沒有選完。對於任意兩個序列,如果在最終方案中他們都沒有選完,不妨設這兩個序列為 a [ ] , b [ ] a[],b[] a[],b[],分別選到了 a [ i ] a[i] a[i] b [ j ] b[j] b[j],那麼一定有 a [ i ] > = b [ j ] a[i]>=b[j] a[i]>=b[j]或者 a [ i ] < = b [ j ] a[i]<=b[j] a[i]<=b[j],那麼我們選較大的那個的後面還沒選的若干個來替換小的那個的最後選的這幾個,一定會讓答案更優,這樣直到權值小的那個一個都不選了或者權值大的那個序列全部被選。

有了這個性質之後,我們的想法就是處理出每個序列的權值之和,然後列舉哪個是沒全選的序列。

但是很不幸,我們比暴力多發現了一個性質,寫出來的東西複雜度卻是一樣的。

但是這東西好在可以優化。當然就是用我們提到的分治了。

應用典型的思想,我們考慮一個類似線段樹的樹形分治結構。我們從層數小的向層數大的處理,直到遇到葉節點。
我們遞迴到一半,然後考慮另一半對這一半的影響。我們的做法是,以先遞迴到左區間為例,那麼我們就先把右區間的每一個序列暴力揹包合併,然後遞迴到左區間,直到區間長度為 1 1 1,這時候我們就得到了其他每個序列要麼全加入,要麼全不加入的答案,於是我們當前的這個序列就是可以有剩餘的那個。我們列舉這個序列加進去多少個,然後邊列舉邊維護一個字首和,然後嘗試更新答案即可。對於向右遞迴,我們也是先暴力將左區間的每一個序列更新揹包的dp值即可。

關鍵在於,分析這麼做的複雜度。我們可以發現,對於每個序列,你考慮從根走到這個序列的過程中,每次走向不是它所在的分枝時它會被合併一次,換句話就是每往下一層會用來合併一次,所以用來合併的次數是 l o g n logn logn的。

注意每次合併之前先清空下一層的dp值。

這樣就做完了,複雜度是 O ( n k l o g n ) O(nklogn) O(nklogn)

好訊息是,這個題有程式碼。
壞訊息是,有個我覺得等價的寫法,不知道為什麼過不了。我覺得等價的寫法是我把每個序列的長度截斷為它的長度和 k k k m i n min min,但是好像錯的時候比答案大了,但是我不知道為什麼,要是有知道的大佬非常期待您能回答一下(WA test 9)

程式碼:

#include <bits/stdc++.h>
using namespace std;

int n,k,len[3010];
vector<long long> v[3010];
long long s[3010],dp[60][3010],ans;
void solve(int dep,int l,int r)
{
	if(l==r)
	{
		long long res=0;
		ans=max(ans,dp[dep][k]);
		for(int i=1;i<=min(k,len[l]);++i)
		{
			res+=v[l][i-1];
			ans=max(ans,res+dp[dep][k-i]);
		}
		return;
	}
	int mid=(l+r)>>1;
	for(int i=0;i<=k;++i)
	dp[dep+1][i]=dp[dep][i];
	for(int i=l;i<=mid;++i)
	{
		for(int j=k;j>=len[i];--j)
		dp[dep+1][j]=max(dp[dep+1][j-len[i]]+s[i],dp[dep+1][j]);
	}
	solve(dep+1,mid+1,r);
	
	for(int i=0;i<=k;++i)
	dp[dep+1][i]=dp[dep][i];
	for(int i=mid+1;i<=r;++i)
	{
		for(int j=k;j>=len[i];--j)
		dp[dep+1][j]=max(dp[dep+1][j-len[i]]+s[i],dp[dep+1][j]);
	}
	solve(dep+1,l,mid);	
}
int main()
{
	scanf("%d%d",&n,&k);
	for(int i=1;i<=n;++i)
	{
		scanf("%d",&len[i]);
		for(int j=1;j<=len[i];++j)
		{
			long long x;
			scanf("%lld",&x);
			if(j<=k)
			{
				v[i].push_back(x);
				s[i]+=x;
			}
			//len[i]=min(len[i],k); //加上這個就會出錯
		}
	}
	solve(1,1,n);
	printf("%lld\n",ans);
	return 0;
}

相關文章