中國剩餘定理(個人筆記)

ocean__ocean發表於2020-11-06

前序

已知a和x0,且x % a ≡ x0, 則x = (x0 % a + a) % a


1. 擴充套件歐幾里德演算法

求出a與b的最大公因數d的同時,還能解出ax+by=gcd(a,b)中(x,y)的一組特解

int exgcd(int a, int b, int &x, int &y){
    if(b == 0) {x = 1, y = 0; return a; }
    int d = exgcd(b, a % b, x, y);
    int tmp = x;
    x = y, y = tmp - a / b * y;
    return d;
}

這裡如果將x與y交換去呼叫exgcd函式,就可以得到以下程式碼

int exgcd(int a, int b, int &x, int &y){
    if(b == 0) { x = 1, y = 0; return a; }
    int d = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return d;
}

解得的x和y為它們的一組特解(x0, y0),它們的通解為x = x0 + b / gcd(a,b), y = y0 - a / gcd(a,b)


2.同餘方程

給出a,b和m,求滿足a ∗ x ≡ b (mod m)的x,且x為最小正整數解

變形 => a * x = k * m + b, 即 a * x - k * m = b

只要滿足b % gcd(a, m) == 0, 那麼就一定存在解x

直接帶入exgcd函式即可求出一組特解(x0, k0)滿足a * x - k * m = gcd(a, m),所以x1 = x0 * b / gcd(a, m)的值就滿足a * x - k * m = b

為了得到最小正整數解,只需令x2 = (x1 % m + m) % m,那麼x2就是最終的結果


3.中國剩餘定理

給出n組式子x ≡ mi (mod ai)中的mi,ai,求出一個最小正整數x滿足這n組式子

取其中兩個式子,可以得到

x = k1 * a1 + m1
x = k2 * a2 + m2

**那麼整理得到 **

a1 * k1 - a2 * k2 = m2 - m1

(此時需要判斷是否存在解!!!也就是說,(m2 - m1) % gcd(a1, a2) == 0必須成立才能有解)

*這樣,通過呼叫exgcd函式就可以得到一個特解k1 = (m2 - m1) / gcd(a1, a2),通解為k1 + k * a2 / gcd(a1, a2), 同理k2的通解為 k2 + k * a1 / gcd(a1,a2)(因為上邊的式子為減號,所以這裡的兩個通解全為加k倍,而不是一加一減),那麼

t = a2 / gcd(a1, a2)
x = [k1 + k * t] * a1 + m1
  = a1 * t * k + a1 * k1 + m1

為了讓x變得更小,這裡我們可以讓k1 + k * t = (k1 % t + t) % t

此時我們可以將a1 * t當作新的a將a1 * k1 + m1當作新的m,即m1 = a1 * k1 + m1, a1 = a1 * t,也滿足x = k1 * a1 + m1的格式,就將x = k1 * a1 + m1 和 x = k2 * a2 + m2合併成了一個式子

同理,最後可以將n個式子合併成一個式子,只剩下更改後的a1和m1,那麼答案x = (m1 % a1 + a1) % a1

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;

ll exgcd(ll a, ll b, ll& x, ll& y) {
	if (b == 0) { x = 1, y = 0; return a; }
	ll d = exgcd(b, a % b, y, x);
	y -= a / b * x;
	return d;
}

int main(void)
{
	bool flag = true;
	int n; cin >> n;
	ll a1, m1; cin >> a1 >> m1;
	for (int i = 1; i < n; ++i) {
		ll a2, m2; cin >> a2 >> m2;

		ll k1, k2;
		ll d = exgcd(a1, -a2, k1, k2);

		if ((m2 - m1) % d != 0) { flag = false; break; }

		k1 *= (m2 - m1) / d;
		ll t = a2 / d;
		k1 = (k1 % t + t) % t;

		m1 = a1 * k1 + m1;
		a1 = abs(a1 * t);
	}
	if (flag) cout << (m1 % a1 + a1) % a1 << endl;
	else cout << "-1\n";
	return 0;
}

END

相關文章