根號分治
PS:本篇部落格題目分析及內容(除程式碼)均來自於paulzrm
根號分治,是暴力美學的集大成體現。與其說是一種演算法,我們不如稱它為一個常用的trick。
首先,我們引入一道入門題目 CF1207F Remainder Problem:
給你一個長度為 $5\times10^5$ 的序列,初值為 $0$ ,你要完成 $q$ 次操作,操作有如下兩種:
1 x y
: 將下標為 $x$ 的位置的值加上 $y$。2 x y
: 詢問所有下標模 $x$ 的結果為 $y$ 的位置的值之和。
考慮這題的暴力是什麼。
首先有一種暴力就是按照題目所說的去做,開一個 $5\times10^5$ 大小的陣列 $a$ 去存,$1$ 操作就對 $a_x$ 加上 $y$,$2$ 操作就列舉所有下標模 $x$ 的結果為 $y$ 的位置,統計他們的和。
對於這種暴力,$1$ 操作的時間複雜度為 $O(1)$,$2$ 操作的時間複雜度為 $O(n)$,所以在最壞情況下總時間複雜度可達 $O(nq)$。
經過思考,我們可以發現另外一種暴力:新開一個大小為 $n\times n$ 的二維陣列 $b$,$b_{i,j}$ 當前所有下標模 $i$ 的結果為 $j$ 的數的和是什麼。對於每個 $1$ 操作,動態的去維護這個 $b$ 陣列,在每次詢問的時候直接輸出答案即可。
對於這種暴力,$1$ 操作的時間複雜度是列舉模數的 $O(n)$ ,$2$ 操作的時間複雜度為 $O(1)$,總的時間複雜度為 $O(nq)$。
現在我們發現,這兩種暴力對應了兩種極端:一個是 $1$ 操作的時間複雜度為 $O(1)$,$2$ 操作的時間複雜度為 $O(n)$;另一個是 $1$ 操作的時間複雜度是列舉模數的 $O(n)$,$2$ 操作的時間複雜度為 $O(1)$。那麼,有沒有辦法讓這兩種暴力融合一下,均攤時間複雜度,達到一個平衡呢?
其實是有的。我們設定一個閾值 $b$。
對於所有 $\le b$ 的數,我們動態的維護暴力 $2$ 的 $b$ 陣列。每次 $1$ 操作只需要列舉 $b$ 個模數即可,故單次操作 $1$ 的時間複雜度降為 $O(b)$。
對於所有 $>b$ 的數,我們就不在操作 $1$ 中維護 $b$,直接再詢問答案時暴力列舉下標即可。顯然,這 $n$ 個下標中最多有 $\lceil \frac{n}{b}\rceil$ 個下標對 $x$ 取模餘 $y$ 找到第一個 $y$ 後每次跳 $x$,即可做到單次操作 $2$ 時間複雜度為 $O(\frac{n}{b})$。
所以,總時間複雜度就成為了 $O(q\times(b+\frac{n}{b}))$。由基本不等式可得,$b+\frac{n}{b} \geq 2\sqrt{b\times\frac{n}{b}}=2\sqrt{n}$,當 $b=\sqrt{n}$ 時取等。所以我們只需要讓 $b=\sqrt{n}$,就可以做到時間和空間複雜度均為 $O(q\sqrt{n})$ 的優秀演算法了,可以透過此題。
#include<bits/stdc++.h>
#define rint register int
#define endl '\n'
const int N = 8e2 + 5;
const int M = 5e5 + 5;
using namespace std;
int s[N][N], a[M];
signed main()
{
int q;
cin >> q;
int n = sqrt(500000);
while(q--)
{
int op, x , y;
cin >> op >> x >> y;
if (op == 1)
{
for(rint i = 1 ; i < n; i++)
{
s[i][x % i] += y;
}
a[x] += y;
}
if (op == 2)
{
if(x < n)
{
cout << s[x][y] << endl;
}
else
{
int res = 0;
for(rint i = y; i <= 500000; i += x)
{
res += a[i];
}
cout << res << endl;
}
}
}
return 0;
}
CF710D Two Arithmetic Progressions
題目大意:
現在有兩個等差數列,形如 $a_1k+b_1$ 和 $a_2k+b_2$,其中 $k$ 要滿足是自然數。現在再給你兩個正整數 $l,r$,求出 $[l,r]$ 間有多少個數同時出現在兩個等差數列中。資料滿足 $0<a_1,b_1\le2\times10^9,-2\times10*9\le b_1,b_2,l,r\le 2\times10^9,l\le r$。
題解:
正解要用到 exgcd 等數論知識,且細節較多比較麻煩。現在我們考慮如何用根號分治解決該數論問題。
現在欽定 $a_1\geq a_2$,再令 $t=\sqrt{2\times 10^9}$。
$a_1\le t$。此時 $a_2$ 也 $\le t$。由於每隔 $lcm(a_1,b_1)$ 就是一個迴圈節,且每個迴圈節只會有 $1$ 的貢獻,我們只需要找到第一個重合的數(或報告不存在),然後計算出迴圈節的個數就可以了。找到第一個重合的數,可以直接對著第一個等差數列從前往後跳,如果跳了 $a_2$ 次還是沒有出現,可以證明一定不存在了。
$a_1>t$。那麼有 $\frac{2\times10^9}{a_1}\le t$。也就是說,在 $[l,r]$ 這段區間內,屬於等差數列 $1$ 的數不會超過根號個。我們只需要列舉這個根號個數,依次判斷其是否在等差數列 $2$ 中即可。
#include <bits/stdc++.h>
#define rint register int
#define int long long
#define endl '\n'
using namespace std;
const int N = 2e9;
int a1, a2, b1, b2;
int l, r;
int n;
signed main()
{
n = sqrt(N);
cin >> a1 >> a2 >> b1 >> b2;
cin >> l >> r;
int m = max(a1, b1);
if (m <= n)
{
for (rint i = -m * 2; i <= m * 2; i++)
{
int p = i * a1 + a2;
if ((abs(b2 - p) % b1 == 0))
{
int k = __gcd(a1, b1);
int lcm = a1 * b1 / k;
int begin = max(max(a2, b2), l);
if (p > begin)
{
p = begin + (p - begin) % lcm;
}
else
{
p += ((begin - p) / lcm + 1) * lcm;
p = begin + (p - begin) % lcm;
}
if (r < p)
{
continue;
}
cout << (r - p) / lcm + 1 << endl;
return 0;
}
}
cout << 0 << endl;
return 0;
}
else
{
if (a1 < b1)
{
swap(a1, b1);
swap(a2, b2);
}
int cnt = 0;
for (rint i = a2; i <= r; i += a1)
{
if (i >= l && i >= b2)
{
if ((i - b2) % b1 == 0)
{
cnt++;
}
}
}
cout << cnt << endl;;
}
return 0;
}
[ARC052D] 9
題目大意:
給定兩個正整數 $K,M (1\le K,M \le 10^{10})$,你需要求出有多少個正整數 $N$ 滿足 $1 \le N \le M$ 且 $N \equiv S_N (\mod K) $,其中 $S_N$ 是 $N$ 的各位數字之和。
題解:
這個 $10^{10}$ 的資料範圍並不常見,但是可以發現大概是根號的複雜度。
顯然無法分塊,考慮怎麼做到根號分治。我們先對 $K$ 設定一個閾值 $T$,其中 $T$ 是 $\sqrt{M}$ 級別。
- $K \ge T$
當 $1 \le N \le 10^{10}$ 時,最大的 $S_N$ 不過 $9\times 10=90$,所以我們可以去先列舉數字和 $S$,然後就可以發現,$\mod K= S$ 的 $N$ 的個數不會超過 $\lfloor\frac{M}{K}\rfloor +1$ 個。直接列舉這些數就可以了。複雜度 $O(90\times \frac{M}{K})$。
- $K \le T$
我們可以考慮把 $K$ 做為一維壓到數位 dp 裡了。令 $dp_{i,j,sm,0/1}$ 表示考慮到從高到低第 $i$ 位,此時的數 $\mod K = j$,數字和為 $sm$,是否已經小於 $m$ 的數的個數。這樣就可以 dp 了,複雜度 $O(10\times90\times K\times 10)=O(9000K)$。
#include <bits/stdc++.h>
#define rint register int
#define int long long
#define endl '\n'
using namespace std;
const int N = 1e4 + 5;
const int M = 1e8;
const int K = 1e2 - 10;
const int W = 1e1 + 2;
int k, m, n;
int len;
int a[W];
int f[W][N][K][2];
int ans;
signed main()
{
cin >> k >> m;
len = sqrt(M);
int tool = m;
for (rint i = 1; ;i++)
{
a[i] = tool % 10;
tool /= 10;
if (tool == 0)
{
n = i;
break;
}
}
reverse(a + 1, a + n + 1);
if (k >= len)
{
for (rint i = 0; i <= K; i++)
{
for (rint j = i; j <= m; j += k)
{
int t = j;
int cnt = 0;
while (t)
{
cnt += t % 10;
t /= 10;
}
if (cnt % k == i)
{
ans++;
}
}
}
cout << ans - 1 << endl;
return 0;
}
f[1][0][0][0] = 1;
for (rint i = 1; i <= n; i++)
{
for (rint j = 0; j < k; j++)
{
for (rint o = 0; o <= K; o++)
{
for (rint t = 0; t < a[i]; t++)
{
if (o + t <= K)
{
f[i + 1][(j * 10 + t) % k][o + t][1] += f[i][j][o][0];
}
}
if (o + a[i] <= K)
{
f[i + 1][(j * 10 + a[i]) % k][o + a[i]][0] += f[i][j][o][0];
}
for (rint t = 0; t < 10; t++)
{
if (o + t <= K)
{
f[i + 1][(j * 10 + t) % k][o + t][1] += f[i][j][o][1];
}
}
}
}
}
for (rint i = 0; i <= K; i++)
{
ans += f[n + 1][i % k][i][0];
ans += f[n + 1][i % k][i][1];
}
cout << ans - 1 << endl;
return 0;
}