數位 dp
看到這樣求和價值的計算,考慮可不可以交換求和符號或者改變計算方式。
這題中的位運算使我們考慮按位計算貢獻,價值可以寫成:
\[f(A)=\sum_{i=0}2^i\times c_i\times (k-c_i)
\]
其中 \(c_i\) 表示第 \(i\) 位上為 \(1\) 的 \(a_i\) 數量。
題目第二個要求即 \(f(A)=n\)。考慮從高位到低位計算貢獻,類似數位 dp 計算方案數。於是序列中的元素就分為兩種:卡了上界和沒卡上界的。並且計算到當前位時需要知道低位留下來的餘數,使該位最終與 \(n\) 上這一位相同。
設 \(dp(i,j,k)\) 表示考慮完從高到低前 \(i\) 位,此時低位留下的餘數為 \(j\),卡了上界的數的數量為 \(k\) 的方案數。
轉移看 \(m\) 上第 \(i\) 位上是 \(0\) 還是 \(1\):
如果是 \(0\),那麼卡上界的數只能繼續卡上界,列舉沒卡上界的數中 \(1\) 的個數。
如果是 \(1\),分別列舉卡上界和不卡上界的 \(1\) 的個數。
\(1\) 的位置不固定,所以需要預處理組合數。
記憶化搜尋即可。
分析當前位上餘數最多是多少,如果餘數 \(cnt\) 滿足 \((cnt-81)\times 2\ge cnt\),那麼低位上不存在一種方案使得出現這樣的餘數。
複雜度 \(O(50\times162\times18\times18\times 18)\),遠小於此數。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back
using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 50, M = 170, K = 19, mod = 1e9 + 7;
i64 n, m, k;
i64 f[N][M][K], a[N], c[K][K];
i64 dfs(int dep, int left, int cnt) {
if(left >= 162) return 0;
if(dep == -1) return left == 0;
if(f[dep][left][cnt] != -1) return f[dep][left][cnt];
int dig = (!dep ? 0 : ((n >> (dep - 1)) & 1));
i64 ans = 0;
if(!a[dep]) {
for(int i = 0; i <= k - cnt; i++) {
int cur = left - 1LL * i * (k - i);
if(cur < 0) continue;
ans = (ans + c[k - cnt][i] * dfs(dep - 1, (cur << 1) | dig, cnt) % mod) % mod;
}
} else {
for(int i = 0; i <= cnt; i++) {
for(int j = 0; j <= k - cnt; j++) {
int cur = left - 1LL * (i + j) * (k - j - i);
if(cur < 0) continue;
ans = (ans + c[cnt][i] * c[k - cnt][j] % mod * dfs(dep - 1, (cur << 1) | dig, i) % mod) % mod;
}
}
}
f[dep][left][cnt] = ans;
return ans;
}
int solve() {
if(k == 1) return !n;
if(!m) return !n;
memset(f, -1, sizeof(f));
i64 l = 0, left = 0;
while(m) {
a[l++] = m % 2;
m >>= 1;
}
for(int i = l; i <= 50; i++) {
if((n >> i) & 1) {
left += (1LL << (i - l));
}
}
if(left >= 162) return 0;
return dfs(l, left, k);
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cin >> n >> m >> k;
c[0][0] = 1;
for(int i = 1; i <= k; i++) {
c[i][0] = 1;
for(int j = 1; j <= i; j++) {
c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
}
}
std::cout << solve() << "\n";
return 0;
}