P5664 [CSP-S2019] Emiya 家今天的飯

dcytrl發表於2024-05-08

題意簡述

\(n\) 種方法和 \(m\) 種食材,第 \(i\) 種方法第 \(j\) 種食材做出來的菜有 \(a_{i,j}\) 種。

有以下限制:

  • 至少做一盤菜。
  • 每種方法做出來的菜品數至多為 \(1\)
  • 所有以第 \(i\) 種食材做出來的菜品數不超過菜品種數的一半。

求方案數。

\(n\le 100,m\le 2\times10^3\)

分析

條件一、二都很好滿足,問題在於條件三有點棘手。

發現這個“一半”的限制很特殊,考慮從這裡入手,進一步顯然發現不合法的食材只有至多一種,而且總數(即不考慮第三條限制的方案數)可以容易的透過 dp \(O(n^2)\) 求出(後文會有),這樣的話我們可以考慮正難則反了。

列舉這個不合法的食材 \(c\),設計 dp 狀態 \(f_{i,j,k}\) 表示前 \(i\) 種方法,不合法的食材選了 \(j\) 次,其他的食材選了 \(k\) 次的方案數。轉移考慮第 \(i\) 種方法下是不做 \(f_{i-1,j,k}\rightarrow f_{i,j,k}\) 或者做不合法食材 \(f_{i-1,j-1,k}\cdot a_{i,c}\rightarrow f_{i,j,k}\) 或者做其他食材 \(f_{i-1,j,k-1}\cdot (s_i-a_{i,c})\rightarrow f_{i,j,k}\),這裡 \(s_i\)\(a_{i,*}\) 的字首和,即第 \(i\) 種方法可以做出的菜品個數。我們需要的答案是 \(\sum_{i>j} f_{n,i,j}\)

考慮求出總方案。設 \(g_{i,j}\) 表示前 \(i\) 種方法做了 \(j\) 盤菜,轉移和 \(f\) 類似。

時間複雜度 \(O(mn^3)\),瓶頸在於 \(f\) 的轉移。

考慮最佳化。發現我們並不需要知道 \(j,k\) 具體的值,我們只需要讓 \(j>k\)。重新設 \(f_{i,j}\) 表示前 \(i\) 種方法,不合法食材與其他食材的數量差為 \(j\)。由於 \(j\) 可能為負,所以加上一個整體偏移量。

時間複雜度 \(O(mn^2)\),可以透過。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<map>
#include<unordered_map>
#include<vector>
#include<queue>
#include<bitset>
#include<set>
#include<ctime>
#include<random>
#include<cassert>
#define x1 xx1
#define y1 yy1
#define IOS ios::sync_with_stdio(false)
#define ITIE cin.tie(0);
#define OTIE cout.tie(0);
#define PY puts("Yes")
#define PN puts("No")
#define PW puts("-1")
#define P0 puts("0")
#define P__ puts("")
#define PU puts("--------------------")
#define popc __builtin_popcount
#define mp make_pair
#define fi first
#define se second
#define gc getchar
#define pc putchar
#define pb emplace_back
#define rep(a,b,c) for(int a=(b);a<=(c);++a)
#define per(a,b,c) for(int a=(b);a>=(c);--a)
#define reprange(a,b,c,d) for(int a=(b);a<=(c);a+=(d))
#define perrange(a,b,c,d) for(int a=(b);a>=(c);a-=(d))
#define graph(i,j,k,l) for(int i=k[j];i;i=l[i].nxt)
#define lowbit(x) (x&-x)
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
#define mem(x,y) memset(x,y,sizeof x)
//#define double long double
#define int long long
//#define int __int128
using namespace std;
typedef long long i64;
using pii=pair<int,int>;
bool greating(int x,int y){return x>y;}
bool greatingll(long long x,long long y){return x>y;}
inline int rd(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}return x*f;
}
inline void write(int x,char ch='\0'){
	if(x<0){x=-x;putchar('-');}
	int y=0;char z[40];
	while(x||!y){z[y++]=x%10+48;x/=10;}
	while(y--)putchar(z[y]);if(ch!='\0')putchar(ch);
}
bool Mbg;
const int maxn=105,maxm=2e3+5,inf=0x3f3f3f3f,mod=998244353,delta=101;
const long long llinf=0x3f3f3f3f3f3f3f3f;
int n,m,a[maxn][maxm];
int f[maxn][maxn<<1],s[maxn];
int g[maxn][maxn];
void solve_the_problem(){
	n=rd(),m=rd();rep(i,1,n)rep(j,1,m)a[i][j]=rd(),s[i]=(s[i]+a[i][j])%mod;
//	rep(i,1,n)write(s[i],32);P__;
	g[0][0]=1;
	rep(i,1,n){
		g[i][0]=g[i-1][0];
		rep(j,1,n)g[i][j]=(g[i-1][j]+g[i-1][j-1]*s[i]%mod)%mod;
	}
	int ans=0;
	rep(i,1,n)ans=(ans+g[n][i])%mod;
//	write(tot,32);
	rep(c,1,m){
		int res=0;
		mem(f,0);
		f[0][delta]=1;
		rep(i,1,n)rep(j,1,2*delta){
			int rem=(s[i]-a[i][c]+mod)%mod;
			f[i][j]=f[i-1][j];
			f[i][j]=(f[i][j]+f[i-1][j-1]*a[i][c]%mod)%mod;
			f[i][j]=(f[i][j]+f[i-1][j+1]*rem%mod)%mod;
		}
		rep(i,delta+1,2*delta)res=(res+f[n][i])%mod;
//		write(res,32);
		ans=(ans-res+mod)%mod;
	}
	write(ans);
}
bool Med;
signed main(){
//	freopen(".in","r",stdin);freopen(".out","w",stdout);
//	fprintf(stderr,"%.3lfMB\n",(&Mbg-&Med)/1048576.0);
	int _=1;while(_--)solve_the_problem();
}
/*

*/

相關文章