Atcoder ARC090F Number of Digits

rizynvu發表於2024-07-03

\(n\) 為題面的 \(S\)

能發現對於 \(f(l) = 8\),共有 \(9\times 10^7\) 個數。
此時就已經有 \(8\times 9\times 10^7 > 10^8 = n_{\max}\) 了,就說明不存在 \(f \ge 8\) 的情況,還滿足這部分對應的數能全被選滿。

所以可以知道對於 \(f(l)\ge 8\) 的情況,只存在 \(f(r) - f(l) = 0 \operatorname{or} 1\) 的情況。

對於 \(f(l)\le 7\) 的情況。
能發現 \(r\) 的上界是 \(10^7 + \frac{10^8}{8}\),所以這部分直接雙指標就行了。

接下來考慮 \(f(l)\ge 8, f(r) - f(l) = 1\) 的情況。
考慮令 \(f(l)\) 選了 \(x\) 個,\(f(r) = f(l) + 1\) 選了 \(y\) 個。
需要先宣告的是,在這個位置先不考慮 \(x, y > 0\) 的限制,而認為 \(y\) 可以為 \(0\),關於這部分將會在後面提到。
那麼就能得到 \((x + y)f(l) + y = n\)
然後這裡以 \(x + y\) 為主元,能發現對於固定的 \(x + y\),其對應的 \(f(l)\)\(x, y\) 都只有 \(1\) 個,就是 \(\begin{cases}y = n\bmod (x + y)\\ x = (x + y) - y\\ f(l) = \frac{n - y}{x + y}\end{cases}\)

於是轉而去考慮 \(x + y\) 能有多少種可能。
這是好算的,因為 \(x + y\le \frac{n}{f(l)}\),而因為 \(f(l)\ge 8\),所以 \(x + y\le \frac{n}{8}\),對應的就有 \(\lfloor \frac{n}{8} \rfloor\) 種選法。

最後來考慮 \(f(l)\ge 8, f(r) - f(l) = 0\) 的情況。
這時候就已經確定了 \(f(l)\) 了,就相當於是已經知道需要的個數 \(x\) 了。
那麼顯然答案就為 \(9\times 10^{f(l) - 1} - x + 1\),但注意到,在前文的時候提到了誤給 \(y = 0\) 時加了 \(1\) 的貢獻,而其對應的應該就是這種 \(f(r) - f(l) = 0\) 的情況,所以實際算貢獻的時候應該算為 \(9\times 10^{f(l) - 1} - x\)
這部分可以 \(\mathcal{O}(\sqrt{n} + d(n)\log n)\)

時間複雜度 \(\mathcal{O}(\frac{n}{b} + \sqrt{n} + d(n)\log n)\),其中 \(B = 8\)\(d(n)\) 表示 \(n\) 的因子個數。

#include<bits/stdc++.h>
using ll = long long;
constexpr ll mod = 1e9 + 7;
inline ll qpow(ll a, ll b, ll v = 1) {
   while (b)
      b & 1 && (v = v * a % mod), b >>= 1, a = a * a % mod;
   return v;
}
ll ans;
const int limn = (int)1e7 + (int)1e8 / 8, R = (int)1e7;
int f[limn + 10];
int main() {
   for (int i = 1, l = 1, r = 10; i <= 8; i++, r = std::min(l * 10, limn))
      while (l < r) f[l++] = i;
   int n; scanf("%d", &n);
   for (int i = 1, j = 0, sum = 0; f[i] <= n && i < R; i++) {
      while (sum + f[j + 1] <= n) sum += f[++j];
      if (sum == n) (++ans) %= mod;
      sum -= f[i];
   }
   (ans += n / 8) %= mod;
   auto calc = [&](int fx) {
      int l = n / fx;
      ll len = 9ll * qpow(10, fx - 1) % mod;
      (ans += mod + len - l) %= mod;
   };
   for (int i = 1; i * i <= n; i++) if (n % i == 0) {
      if (i >= 8) calc(i);
      if (n / i >= 8 && i != n / i) calc(n / i);
   }
   printf("%lld\n", ans);
   return 0;
}

相關文章