P10958 啟示錄 解題報告

Brilliant11001發表於2024-09-01

更好的閱讀體驗

用記憶化搜尋寫數位 dp 真的很好寫!

題目傳送門

題目大意:

\(T\) 組資料,每次詢問第 \(x\) 個含有至少 \(3\) 個連續 \(6\) 的數是什麼。

思路:

考慮數位 dp。

一般數位 dp 問題有兩種常見形式:

  1. 詢問 \([l, r]\) 內有多少個符合條件的數;
  2. 詢問滿足條件的第 \(k\) 大(小)的數是什麼。

很顯然這道題是第二種形式。

首先問題 \(1\) 很簡單,那我們考慮將第二個問題轉化成第一個問題來做。

因為答案具有單調性,於是可以二分判定。

每次二分到一個值 \(mid\),計算 \([1, mid]\) 的魔鬼數個數,若大於等於 \(x\),則說明所求在 \(mid\) 左側,否則在 \(mid\) 右側。

接著考慮問題 \(1\),這裡採用記憶化搜尋的方式,註釋在程式碼中。

//pos 記錄當前填到了哪一位,cnt 記錄當前末尾有幾個連續的 6,flag 記錄當前數是否滿足條件
//limit 記錄當前有沒有頂上界
//因為這道題有沒有前導零無影響,遂不記錄
int dfs(int pos, int cnt, bool flag, bool limit) {
    //邊界,若填完了就檢查一下是否符合條件
    if(pos < 0) return flag;
    //若不頂上界就記憶化,因為頂上界是特殊情況,滿足條件的數可能和普通情況不同
    if(!limit && f[pos][cnt][flag] != -1) return f[pos][cnt][flag];
    //看一下當前這位需不需要頂上界,若前面填的數都是貼著上界的,這一位最多隻能填到 num[pos],否則不受限
    int mx = (limit ? num[pos] : 9);
    int res = 0;
    //列舉第 pos 位填什麼
    for(int i = 0; i <= mx; i++) {
        //處理連續的 6
        int ncnt;
        if(i == 6) ncnt = cnt + 1;
        else ncnt = 0;
        res += dfs(pos - 1, ncnt, flag || (ncnt >= 3), limit && (i == num[pos]));
    }
    //若不頂上界就記憶化
    if(!limit) f[pos][cnt][flag] = res;
    return res;
}

這裡我直接把二分值域拉滿了,但是實測發現第 \(50000000\) 個魔鬼數只有 \(6668056399\)

時間複雜度為:\(O(N^2MT\log V)\),這裡 \(N\) 表示數字位數,\(V\) 表示二分值域,\(M\) 表示每次列舉填的數的個數,可看作 \(10\)

\(\texttt{Code:}\)

#include <vector>
#include <cstring>
#include <iostream>

using namespace std;
typedef long long ll;

const int N = 20;

int T;
int x;
ll f[N][N][2];
vector<int> num;

ll dfs(int pos, int cnt, bool flag, bool limit) {
    if(pos < 0) return flag;
    if(!limit && f[pos][cnt][flag] != -1) return f[pos][cnt][flag];
    int mx = (limit ? num[pos] : 9);
    ll res = 0;
    for(int i = 0; i <= mx; i++) {
        int ncnt;
        if(i == 6) ncnt = cnt + 1;
        else ncnt = 0;
        res += dfs(pos - 1, ncnt, flag || (ncnt >= 3), limit && (i == num[pos]));
    }
    if(!limit) f[pos][cnt][flag] = res;
    return res;
}

ll calc(ll x) {
    num.clear();
    ll tmp = x;
    while(tmp) {
        num.push_back(tmp % 10);
        tmp /= 10;
    }
    return dfs(num.size() - 1, 0, 0, 1);
}

void solve() {
    scanf("%d", &x);
    ll l = 1, r = 5e18;
    while(l < r) {
        ll mid = l + r >> 1;
        if(calc(mid) >= x) r = mid;
        else l = mid + 1;
    }
    printf("%lld\n", l);
}

int main() {
    scanf("%d", &T);
    memset(f, -1, sizeof f);
    while(T--) {
        solve();
    }
    return 0;
}

相關文章