[CCPC2022 廣東] XOR Sum

Fire_Raku發表於2024-07-21

數位 dp

看到這樣求和價值的計算,考慮可不可以交換求和符號或者改變計算方式。

這題中的位運算使我們考慮按位計算貢獻,價值可以寫成:

\[f(A)=\sum_{i=0}2^i\times c_i\times (k-c_i) \]

其中 \(c_i\) 表示第 \(i\) 位上為 \(1\)\(a_i\) 數量。

題目第二個要求即 \(f(A)=n\)。考慮從高位到低位計算貢獻,類似數位 dp 計算方案數。於是序列中的元素就分為兩種:卡了上界和沒卡上界的。並且計算到當前位時需要知道低位留下來的餘數,使該位最終與 \(n\) 上這一位相同。

\(dp(i,j,k)\) 表示考慮完從高到低前 \(i\) 位,此時低位留下的餘數為 \(j\),卡了上界的數的數量為 \(k\) 的方案數。

轉移看 \(m\) 上第 \(i\) 位上是 \(0\) 還是 \(1\)

如果是 \(0\),那麼卡上界的數只能繼續卡上界,列舉沒卡上界的數中 \(1\) 的個數。

如果是 \(1\),分別列舉卡上界和不卡上界的 \(1\) 的個數。

\(1\) 的位置不固定,所以需要預處理組合數。

記憶化搜尋即可。

分析當前位上餘數最多是多少,如果餘數 \(cnt\) 滿足 \((cnt-81)\times 2\ge cnt\),那麼低位上不存在一種方案使得出現這樣的餘數。

複雜度 \(O(50\times162\times18\times18\times 18)\),遠小於此數。

#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back

using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 50, M = 170, K = 19, mod = 1e9 + 7;
i64 n, m, k;
i64 f[N][M][K], a[N], c[K][K];
i64 dfs(int dep, int left, int cnt) {
	if(left >= 162) return 0;
	if(dep == -1) return left == 0;
	if(f[dep][left][cnt] != -1) return f[dep][left][cnt];
	int dig = (!dep ? 0 : ((n >> (dep - 1)) & 1));
	i64 ans = 0;
	if(!a[dep]) {
		for(int i = 0; i <= k - cnt; i++) {
			int cur = left - 1LL * i * (k - i);
			if(cur < 0) continue;
			ans = (ans + c[k - cnt][i] * dfs(dep - 1, (cur << 1) | dig, cnt) % mod) % mod;
		}
	} else {
		for(int i = 0; i <= cnt; i++) {
			for(int j = 0; j <= k - cnt; j++) {
				int cur = left - 1LL * (i + j) * (k - j - i);
				if(cur < 0) continue;
				ans = (ans + c[cnt][i] * c[k - cnt][j] % mod * dfs(dep - 1, (cur << 1) | dig, i) % mod) % mod;
			}
		}
	}
	f[dep][left][cnt] = ans;
	return ans;
}
int solve() {
	if(k == 1) return !n;		
	if(!m) return !n;
	memset(f, -1, sizeof(f));
	i64 l = 0, left = 0;
	while(m) {
		a[l++] = m % 2;
		m >>= 1;
	}
	for(int i = l; i <= 50; i++) {
		if((n >> i) & 1) {
			left += (1LL << (i - l));
		}
	}
	if(left >= 162) return 0;
	return dfs(l, left, k);
}
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
	std::cin >> n >> m >> k;

	c[0][0] = 1;
	for(int i = 1; i <= k; i++) {
		c[i][0] = 1;
		for(int j = 1; j <= i; j++) {
			c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
		}
	}
	std::cout << solve() << "\n";

	return 0;
}