UOJ 241. 【UR #16】破壞發射臺 [矩陣乘法]

Candy?發表於2017-05-14

UOJ 241. 【UR #16】破壞發射臺

題意:長度為 n 的環,每個點染色,有 m 種顏色,要求相鄰相對不能同色,求方案數。(定義兩個點相對為去掉這兩個點後環能被分成相同大小的兩段)


只想到一個奇怪的線性遞推,無法寫成矩乘的形式...

正解用狀態記錄了顏色是否相同

奇環,只考慮相鄰,確定第一個的顏色,\(f[i][0/1]\)表示i個與第一個不同/同色的方案數

偶環,再考慮相對,分成兩段,同時遞推\(i,\frac{n}{2}+i\),\(f[i][0..6]\)來表示

構造矩陣討論好煩啊

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long ll;
const int mo = 998244353;
inline int read() {
    char c=getchar(); int x=0,f=1;
    while(c<'0'||c>'9') {if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9') {x=x*10+c-'0';c=getchar();}
    return x*f;
}

int n, m, r, g[7][7], f[7][7];
void mul(int a[7][7], int b[7][7]) {
    static int c[7][7];
    memset(c, 0, sizeof(c));
    for(int i=0; i<r; i++)
        for(int k=0; k<r; k++) if(a[i][k])
            for(int j=0; j<r; j++) if(b[k][j])
                c[i][j] = (c[i][j] + (ll) a[i][k] * b[k][j]) %mo;
    memcpy(a, c, sizeof(c));
}
void Pow(int a[7][7], int b) {
    static int c[7][7];
    memset(c, 0, sizeof(c));
    for(int i=0; i<r; i++) c[i][i] = 1;
    for(; b; b >>= 1, mul(a, a)) if(b & 1) mul(c, a);
    memcpy(a, c, sizeof(c));
}
void print(int a[7][7]) {
    for(int i=0; i<r; i++) for(int j=0; j<r; j++) printf("%d%c", a[i][j], j==r-1 ? '\n' : ' ');
    puts("");
}
namespace odd {
    void solve() {
        r = 2;
        g[0][0] = m-2; g[0][1] = m-1;
        g[1][0] = 1;   g[1][1] = 0;
        Pow(g, n-2);
        //print(g);
        f[0][0] = m-1; f[1][0] = 0;
        mul(g, f);
        //print(g);
        int ans = (ll) g[0][0] * m %mo;
        printf("%d\n", ans);
    }
}
namespace even {
    int id[5][5];
    inline ll cal(int a, int c) {
        if(a == 0) return c==0 ? m-3 : m-2;
        else return 1;
    }
    void solve() {
        r = 7;
        memset(id, -1, sizeof(id));
        id[0][0] = 0; id[0][1] = 1; id[0][2] = 2;
        id[1][0] = 3; id[1][2] = 4;
        id[2][0] = 5; id[2][1] = 6;
        for(int a=0; a<3; a++) for(int b=0; b<3; b++) if(~id[a][b])
            for(int c=0; c<3; c++) for(int d=0; d<3; d++) if(~id[c][d]) {
                int i = id[a][b], j = id[c][d];
                if((a && a==c) || (b && b==d)) continue;
                if(a == 0 && b == 0) { //printf("hi\n");
                    if(c && d) g[i][j] = (ll) (m-2) * max(0, m-3) %mo;
                    else if(!c && !d) g[i][j] = ((ll) max(0, m-4) * max(0, m-4) + max(0, m-3)) %mo;
                    else if(c || d) g[i][j] = ((ll) max(0, m-3) * max(0, m-3)) %mo;
                    g[i][j] = max(0, g[i][j]);
                } else g[i][j] = cal(a, c) * cal(b, d) %mo;
            }
        //print(g);
        f[id[0][0]][0] = max(0LL, (ll)(m-2) * (m-3)) %mo;
        f[id[0][1]][0] = m-2;
        f[id[2][0]][0] = m-2;
        f[id[2][1]][0] = 1;
        n = n/2 - 1;
        Pow(g, n-1);
        //print(g);
        mul(g, f);
        //print(g);
        int ans = (ll) ((ll) g[0][0] + g[id[1][0]][0] + g[id[0][2]][0] + g[id[1][2]][0]) %mo * m %mo * (m-1) %mo;
        printf("%d\n", ans);
    }
}
int main() {
    freopen("in", "r", stdin);
    n = read(); m = read();
    if(n == 1) printf("%d\n", m);
    else if(n == 2) printf("%lld\n", (ll) m * (m-1) %mo);
    if(n & 1) odd::solve();
    else even::solve();
}

相關文章