YC303C [ 20240617 CQYC省選模擬賽 T3 ] Generals(generals)

cxqghzj發表於2024-06-19

題意

給定一張 \(n \times m\) 的地圖。

對於第 \(0\) 列,第 \(m + 1\) 列,第 \(0\) 行,第 \(n + 1\) 行,有 \(2n + 2m\) 個人,每個人面朝地圖中心。

每個人走到別人染過色的位置,或走出地圖,將走過的地方染色。

你需要求出共有多少種本質不同的染色方案。

\(n, m \le 10 ^ 6\)

Sol

直接做似乎很不好做,考慮一些特殊情況。

一個人都沒有輸出 \(1\)

若只有兩個方向,且兩個方向對立,顯然答案為 \(2\) 的次冪。

若只有兩個方向的人且兩個方向相鄰,顯然答案為 \(\dbinom{x + y}{x}\)

考慮有三個方向的時候。

假設目前的三個方向分別為:向右,向下,向上,分別設人數為 \(x, y, z\)

若當前有一列被貫通,那麼右邊部分變為兩個方向且對立的情況。

考慮列舉第一列被貫通的位置。

那麼顯然對於左邊的部分,一定有至少有一行被染滿。

列舉當前最後一個染滿的行 \(i\),則又變為兩個子問題。

對於上方的是無限制,顯然答案為 \(\dbinom{i + y - 1}{i - 1}\)

對於下方不能被染滿行,考慮這個東西的組合意義,那麼很顯然就是不能到達最後一列。

所以答案為 \(\dbinom{x - i + z}{x - i - 1}\)

合起來:

\[\begin{aligned} & \sum_{i = 1} ^ {x} \dbinom{i + y - 1}{i - 1} \dbinom{x - i + z}{x - i - 1} \\ & = \sum_{i = 0} ^ {x - 1} \dbinom{i + y}{i} \dbinom{x - i + z - 1}{x - i - 1} \\ &= \dbinom{x + y + z - 1}{x - 1} \end{aligned} \]

這樣就搞完了。

考慮四個方向的,不難發現必定有一列或一行貫穿,可以只考慮一列的情況,而一行可以翻轉得到。

列舉最後一列被染滿的,右邊部分就是標準的三方向問題,直接組合數搞完了,左邊部分可以染滿一列,設 \(f_i\) 表示前 \(i\) 列的方案數。

若當前一列染滿,直接就是 \(f_{i - 1}\),而沒染滿就是說明前面沒有任何一列染滿,也是三方向問題。

最後考慮一下上下是否都有 \(1\),若都有 \(1\),當前答案與 \(f_i\) 都要 \(\times 2\)

Code

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#define int long long
using namespace std;
#ifdef ONLINE_JUDGE

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;

#endif
int read() {
    int p = 0, flg = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') flg = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        p = p * 10 + c - '0';
        c = getchar();
    }
    return p * flg;
}
string read_() {
    string ans;
    char c = getchar();
    while (c != '0' && c != '1')
        c = getchar();
    while (c == '0' || c == '1')
        ans += c, c = getchar();
    return ans;
}
void write(int x) {
    if (x < 0) {
        x = -x;
        putchar('-');
    }
    if (x > 9) {
        write(x / 10);
    }
    putchar(x % 10 + '0');
}
bool _stmer;

const int N = 1e6 + 5, M = 4e6 + 5, mod = 998244353;

int pow_(int x, int k, int p) {
    int ans = 1;
    while (k) {
        if (k & 1) ans = ans * x % p;
        x = x * x % p;
        k >>= 1;
    }
    return ans;
}

array <int, M> fac, inv;

void init(int n) {
    fac[0] = 1;
    for (int i = 1; i <= n; i++)
        fac[i] = fac[i - 1] * i % mod;
    inv[n] = pow_(fac[n], mod - 2, mod);
    for (int i = n; i; i--)
        inv[i - 1] = inv[i] * i % mod;
}

int C(int n, int m) {
    if (n < m || n < 0 || m < 0) return 0;
    return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

int Y(int x, int y, int z) {
    if (!x) return (!y && !z);
    return C(x + y + z - 1, x - 1);
}

void Mod(int &x) {
    if (x >= mod) x -= mod;
    if (x < 0) x += mod;
}

bool _edmer;
signed main() {
    cerr << (&_stmer - &_edmer) / 1024.0 / 1024.0 << "MB\n";
    init(4e6);
    /* while (1) { */
        /* int x = read(), y = read(), z = read(); */
        /* write(Y(x, y, z)), puts(""); */
    /* } */
    int n = read(), m = read();
    string sL =  " " + read_(), sR = " " + read_(), sU = " " + read_(), sD = " " + read_();
    int tp1 = 0, tp2 = 0, tp3 = 0, tp4 = 0;
    for (int i = 1; i <= n; i++) tp1 += sL[i] == '1';
    for (int i = 1; i <= n; i++) tp2 += sR[i] == '1';
    for (int i = 1; i <= m; i++) tp3 += sU[i] == '1';
    for (int i = 1; i <= m; i++) tp4 += sD[i] == '1';
    int len = (tp1 && 1) + (tp2 && 1) + (tp3 && 1) + (tp4 && 1);
    if (len <= 1) return puts("1"), 0;
    if (len == 2) {
        if ((tp1 + tp2) && (tp3 + tp4))
            return write(C(tp1 + tp2 + tp3 + tp4, tp1 + tp2)), puts(""), 0;
        int res = 1;
        for (int i = 1; i <= n; i++)
            if (sL[i] == '1' && sR[i] == '1')
                res = res * 2ll % mod;
        for (int i = 1; i <= m; i++)
            if (sU[i] == '1' && sD[i] == '1')
                res = res * 2ll % mod;
        return write(res), puts(""), 0;
    }
    auto solve = [&](string tL, string tU, string tD, int _tp2, int _tp3, int _tp4) -> int {
        int ans = 0, sum = 0;
        int l1 = 0, l2 = 0, l3 = 0; //Left Up Down
        for (int i = 1; i <= n; i++) l1 += tL[i] == '1';
        for (int i = 1; i <= m; i++) {
            if (tU[i] == '0' && tD[i] == '0') continue;
            int tp0 = (tU[i] == '1' && tD[i] == '1') ? 2 : 1;
            sum += Y(l1, l2, l3), Mod(sum);
            sum = sum * tp0 % mod;
            if (tU[i] == '1') l2++;
            if (tD[i] == '1') l3++;
            ans += sum * Y(_tp2, _tp3 - l2, _tp4 - l3) % mod, Mod(ans);
        }
        return ans;
    };
    /* cerr << solve() << "@@" << endl, exit(0); */
    int ans = 0;
    ans += solve(sL, sU, sD, tp2, tp3, tp4), Mod(ans);
    swap(n, m);
    ans += solve(sU, sR, sL, tp4, tp2, tp1), Mod(ans);
    write(ans), puts("");
    return 0;
}

相關文章