[暴力 Trick] 根號分治

PassName發表於2024-06-22

根號分治

PS:本篇部落格題目分析及內容(除程式碼)均來自於paulzrm

根號分治,是暴力美學的集大成體現。與其說是一種演算法,我們不如稱它為一個常用的trick。

首先,我們引入一道入門題目 CF1207F Remainder Problem

給你一個長度為 $5\times10^5$ 的序列,初值為 $0$ ,你要完成 $q$ 次操作,操作有如下兩種:

  1. 1 x y: 將下標為 $x$ 的位置的值加上 $y$。
  2. 2 x y: 詢問所有下標模 $x$ 的結果為 $y$ 的位置的值之和。

考慮這題的暴力是什麼。

首先有一種暴力就是按照題目所說的去做,開一個 $5\times10^5$ 大小的陣列 $a$ 去存,$1$ 操作就對 $a_x$ 加上 $y$,$2$ 操作就列舉所有下標模 $x$ 的結果為 $y$ 的位置,統計他們的和。

對於這種暴力,$1$ 操作的時間複雜度為 $O(1)$,$2$ 操作的時間複雜度為 $O(n)$,所以在最壞情況下總時間複雜度可達 $O(nq)$。

經過思考,我們可以發現另外一種暴力:新開一個大小為 $n\times n$ 的二維陣列 $b$,$b_{i,j}$ 當前所有下標模 $i$ 的結果為 $j$ 的數的和是什麼。對於每個 $1$ 操作,動態的去維護這個 $b$ 陣列,在每次詢問的時候直接輸出答案即可。

對於這種暴力,$1$ 操作的時間複雜度是列舉模數的 $O(n)$ ,$2$ 操作的時間複雜度為 $O(1)$,總的時間複雜度為 $O(nq)$。

現在我們發現,這兩種暴力對應了兩種極端:一個是 $1$ 操作的時間複雜度為 $O(1)$,$2$ 操作的時間複雜度為 $O(n)$;另一個是 $1$ 操作的時間複雜度是列舉模數的 $O(n)$,$2$ 操作的時間複雜度為 $O(1)$。那麼,有沒有辦法讓這兩種暴力融合一下,均攤時間複雜度,達到一個平衡呢?

其實是有的。我們設定一個閾值 $b$。

對於所有 $\le b$ 的數,我們動態的維護暴力 $2$ 的 $b$ 陣列。每次 $1$ 操作只需要列舉 $b$ 個模數即可,故單次操作 $1$ 的時間複雜度降為 $O(b)$。

對於所有 $>b$ 的數,我們就不在操作 $1$ 中維護 $b$,直接再詢問答案時暴力列舉下標即可。顯然,這 $n$ 個下標中最多有 $\lceil \frac{n}{b}\rceil$ 個下標對 $x$ 取模餘 $y$ 找到第一個 $y$ 後每次跳 $x$,即可做到單次操作 $2$ 時間複雜度為 $O(\frac{n}{b})$。

所以,總時間複雜度就成為了 $O(q\times(b+\frac{n}{b}))$。由基本不等式可得,$b+\frac{n}{b} \geq 2\sqrt{b\times\frac{n}{b}}=2\sqrt{n}$,當 $b=\sqrt{n}$ 時取等。所以我們只需要讓 $b=\sqrt{n}$,就可以做到時間和空間複雜度均為 $O(q\sqrt{n})$ 的優秀演算法了,可以透過此題。

#include<bits/stdc++.h>

#define rint register int
#define endl '\n'

const int N = 8e2 + 5;
const int M = 5e5 + 5;

using namespace std;

int s[N][N], a[M];  

signed main()
{
    int q;
	cin >> q;
	
	int n = sqrt(500000);
	
    while(q--)
	{
        int op, x , y;
		cin >> op >> x >> y;
        if (op == 1)
		{  
            for(rint i = 1 ; i < n; i++)
			{
				s[i][x % i] += y; 
			}  
            a[x] += y;
        }
		if (op == 2)
		{
            if(x < n)
			{
                cout << s[x][y] << endl;
            }
			else
			{
                int res = 0;
                for(rint i = y; i <= 500000; i += x)
				{
				    res += a[i];	
				} 
                cout << res << endl;
            }
        }
    }
    
    return 0;
}

CF710D Two Arithmetic Progressions

題目大意:

現在有兩個等差數列,形如 $a_1k+b_1$ 和 $a_2k+b_2$,其中 $k$ 要滿足是自然數。現在再給你兩個正整數 $l,r$,求出 $[l,r]$ 間有多少個數同時出現在兩個等差數列中。資料滿足 $0<a_1,b_1\le2\times10^9,-2\times10*9\le b_1,b_2,l,r\le 2\times10^9,l\le r$。

題解:

正解要用到 exgcd 等數論知識,且細節較多比較麻煩。現在我們考慮如何用根號分治解決該數論問題。

現在欽定 $a_1\geq a_2$,再令 $t=\sqrt{2\times 10^9}$。

$a_1\le t$。此時 $a_2$ 也 $\le t$。由於每隔 $lcm(a_1,b_1)$ 就是一個迴圈節,且每個迴圈節只會有 $1$ 的貢獻,我們只需要找到第一個重合的數(或報告不存在),然後計算出迴圈節的個數就可以了。找到第一個重合的數,可以直接對著第一個等差數列從前往後跳,如果跳了 $a_2$ 次還是沒有出現,可以證明一定不存在了。

$a_1>t$。那麼有 $\frac{2\times10^9}{a_1}\le t$。也就是說,在 $[l,r]$ 這段區間內,屬於等差數列 $1$ 的數不會超過根號個。我們只需要列舉這個根號個數,依次判斷其是否在等差數列 $2$ 中即可。

#include <bits/stdc++.h>

#define rint register int
#define int long long
#define endl '\n'

using namespace std;

const int N = 2e9;

int a1, a2, b1, b2;
int l, r;
int n;

signed main() 
{
    n = sqrt(N);
	
	cin >> a1 >> a2 >> b1 >> b2;
	cin >> l >> r;
    
    int m = max(a1, b1);

    if (m <= n) 
	{ 
        for (rint i = -m * 2; i <= m * 2; i++) 
		{
            int p = i * a1 + a2;
            if ((abs(b2 - p) % b1 == 0)) 
			{
                int k = __gcd(a1, b1);
                int lcm = a1 * b1 / k;
                int begin = max(max(a2, b2), l);
                if (p > begin) 
				{
                    p = begin + (p - begin) % lcm;
                } 
				else 
				{
                    p += ((begin - p) / lcm + 1) * lcm;
                    p = begin + (p - begin) % lcm;
                }
                if (r < p)
                {
					continue;
				}
                cout << (r - p) / lcm + 1 << endl;
                return 0;
            }
        }
        cout << 0 << endl;
        return 0;
    } 
    
	else 
	{
        if (a1 < b1)
        {
            swap(a1, b1);
			swap(a2, b2);			
		}
        int cnt = 0;
        for (rint i = a2; i <= r; i += a1) 
		{ 
            if (i >= l && i >= b2) 
			{
                if ((i - b2) % b1 == 0)
                {
					cnt++;
				}
            }
        }

        cout << cnt << endl;;
    }
    
    return 0;
}

[ARC052D] 9

題目大意:

給定兩個正整數 $K,M (1\le K,M \le 10^{10})$,你需要求出有多少個正整數 $N$ 滿足 $1 \le N \le M$ 且 $N \equiv S_N (\mod K) $,其中 $S_N$ 是 $N$ 的各位數字之和。

題解:

這個 $10^{10}$ 的資料範圍並不常見,但是可以發現大概是根號的複雜度。

顯然無法分塊,考慮怎麼做到根號分治。我們先對 $K$ 設定一個閾值 $T$,其中 $T$ 是 $\sqrt{M}$ 級別。

  • $K \ge T$

當 $1 \le N \le 10^{10}$ 時,最大的 $S_N$ 不過 $9\times 10=90$,所以我們可以去先列舉數字和 $S$,然後就可以發現,$\mod K= S$ 的 $N$ 的個數不會超過 $\lfloor\frac{M}{K}\rfloor +1$ 個。直接列舉這些數就可以了。複雜度 $O(90\times \frac{M}{K})$。

  • $K \le T$

我們可以考慮把 $K$ 做為一維壓到數位 dp 裡了。令 $dp_{i,j,sm,0/1}$ 表示考慮到從高到低第 $i$ 位,此時的數 $\mod K = j$,數字和為 $sm$,是否已經小於 $m$ 的數的個數。這樣就可以 dp 了,複雜度 $O(10\times90\times K\times 10)=O(9000K)$。

#include <bits/stdc++.h>

#define rint register int
#define int long long
#define endl '\n'

using namespace std;

const int N = 1e4 + 5;
const int M = 1e8;
const int K = 1e2 - 10;
const int W = 1e1 + 2;

int k, m, n;
int len;
int a[W];
int f[W][N][K][2];
int ans;

signed main() 
{
    cin >> k >> m;
    
    len = sqrt(M);
    
    int tool = m;

    for (rint i = 1; ;i++) 
	{
        a[i] = tool % 10;
        tool /= 10;
        if (tool == 0) 
		{
            n = i;
            break;
        }
    }

    reverse(a + 1, a + n + 1);

    if (k >= len) 
	{
        for (rint i = 0; i <= K; i++) 
		{
            for (rint j = i; j <= m; j += k) 
			{
                int t = j;
				int cnt = 0;
                while (t) 
				{
                    cnt += t % 10;
                    t /= 10;
                }
                if (cnt % k == i)
                {
					ans++;
				}
            }
        }

        cout << ans - 1 << endl;
        return 0;
    }

    f[1][0][0][0] = 1;

    for (rint i = 1; i <= n; i++) 
	{
        for (rint j = 0; j < k; j++) 
		{
            for (rint o = 0; o <= K; o++) 
			{
                for (rint t = 0; t < a[i]; t++) 
				{
                    if (o + t <= K) 
					{
                        f[i + 1][(j * 10 + t) % k][o + t][1] += f[i][j][o][0];
                    }
                }
				if (o + a[i] <= K)
				{
                    f[i + 1][(j * 10 + a[i]) % k][o + a[i]][0] += f[i][j][o][0];					
				}            
                for (rint t = 0; t < 10; t++)
                {
					if (o + t <= K)
					{
                        f[i + 1][(j * 10 + t) % k][o + t][1] += f[i][j][o][1];						
					}
				}

            }
        }
    }

    for (rint i = 0; i <= K; i++)
    {
        ans += f[n + 1][i % k][i][0]; 
		ans += f[n + 1][i % k][i][1];		
	}

    cout << ans - 1 << endl;
    
    return 0;
}