D2. Reverse Card (Hard Version)
The two versions are different problems. You may want to read both versions. You can make hacks only if both versions are solved.
You are given two positive integers $n$, $m$.
Calculate the number of ordered pairs $(a, b)$ satisfying the following conditions:
- $1\le a\le n$, $1\le b\le m$;
- $b \cdot \gcd(a,b)$ is a multiple of $a+b$.
Input
Each test contains multiple test cases. The first line contains the number of test cases $t$ ($1\le t\le 10^4$). The description of the test cases follows.
The first line of each test case contains two integers $n$, $m$ ($1\le n,m\le 2 \cdot 10^6$).
It is guaranteed that neither the sum of $n$ nor the sum of $m$ over all test cases exceeds $2 \cdot 10^6$.
Output
For each test case, print a single integer: the number of valid pairs.
Example
input
6
1 1
2 3
3 5
10 8
100 1233
1000000 1145141
output
0
1
1
6
423
5933961
Note
In the first test case, no pair satisfies the conditions.
In the fourth test case, $(2,2),(3,6),(4,4),(6,3),(6,6),(8,8)$ satisfy the conditions.
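A direct reference count is useful for checking faster solutions on small inputs. The sketch below is not part of the original writeup. It tests every $(a,b)$ against the definition in $O(nm)$ time, so it is far too slow for the real limits. The function name `count_brute` is hypothetical.

```cpp
#include <cassert>
#include <numeric>

// Hypothetical helper: O(n * m) count taken straight from the definition.
// (a, b) is valid when a + b divides b * gcd(a, b). Small n, m only.
long long count_brute(int n, int m) {
    long long cnt = 0;
    for (int a = 1; a <= n; a++) {
        for (int b = 1; b <= m; b++) {
            long long lhs = 1LL * b * std::gcd(a, b);  // std::gcd needs C++17
            if (lhs % (a + b) == 0) cnt++;
        }
    }
    return cnt;
}
```

For example, `count_brute(10, 8)` should agree with the fourth sample, which lists six pairs.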
Solution
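Let $d = \gcd(a, b)$. Then $a = p \cdot d$ and $b = q \cdot d$ with $\gcd(p,q) = 1$. Since $a + b = (p+q) \cdot d$ and $b \cdot d = q \cdot d^2$, the condition $(a+b) \mid b \cdot d$ becomes $(p+q) \cdot d \mid q \cdot d^2$, which is equivalent to $(p+q) \mid q \cdot d$. Because $\gcd(q, p + q) = \gcd(q, p) = 1$, this in turn gives $(p+q) \mid d$.
Therefore $(a+b) \mid b \cdot \gcd(a, b)$ is equivalent to $(p+q) \mid d$.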
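The chain of equivalences can be sanity-checked numerically. This sketch is an illustration, not a proof. It verifies, for every $1 \le a, b \le$ `limit`, that $(a+b) \mid b \cdot d$, $(p+q) \mid q \cdot d$ and $(p+q) \mid d$ all agree. The function name `equivalence_holds` and the chosen limit are hypothetical.

```cpp
#include <cassert>
#include <numeric>

// Hypothetical checker: returns true if, for all 1 <= a, b <= limit,
//   (a + b) | b * d,  (p + q) | q * d,  and  (p + q) | d
// are either all true or all false, where d = gcd(a, b), p = a / d, q = b / d.
bool equivalence_holds(int limit) {
    for (int a = 1; a <= limit; a++) {
        for (int b = 1; b <= limit; b++) {
            long long d = std::gcd(a, b);
            long long p = a / d, q = b / d;
            bool original  = (1LL * b * d) % (a + b) == 0;
            bool reduced   = (q * d) % (p + q) == 0;
            bool divides_d = d % (p + q) == 0;
            if (original != reduced || reduced != divides_d) return false;
        }
    }
    return true;
}
```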
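Next, consider enumerating the pairs $(p,q)$ directly. Because $(p+q) \mid d$, we have $p < p + q \le d$. Together with $a = p \cdot d \le n$, this gives $p^2 < p \cdot d \le n$, i.e. $p^2 < n$. Likewise $q^2 < m$. So there are at most $O(\sqrt{n} \sqrt{m})$ candidate pairs, and we can brute-force over all of them, keeping those with $\gcd(p,q) = 1$. For such a pair, $d$ must be a multiple of $p+q$ with $d \le \min\left\{ \left\lfloor\frac{n}{p}\right\rfloor, \left\lfloor\frac{m}{q}\right\rfloor \right\}$. The number of valid $(a,b)$ it contributes is therefore $\left\lfloor\frac{\min\left\{ \left\lfloor\frac{n}{p}\right\rfloor, \left\lfloor\frac{m}{q}\right\rfloor \right\}}{p+q}\right\rfloor$. Because $d = \gcd(a,b)$ is determined by $(a,b)$, each valid pair is counted exactly once.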
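The enumeration can be made concrete by generating the actual pairs instead of only counting them. The sketch below assumes the derivation above: for each coprime $(p,q)$ with $p^2 < n$ and $q^2 < m$, it walks $d$ over the multiples of $p+q$ and emits $(p \cdot d, q \cdot d)$. The function name `list_pairs` is hypothetical. For $n = 10$, $m = 8$ it should reproduce the six pairs from the note.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper: lists every valid (a, b) via the (p, q, d) enumeration.
// d runs over p + q, 2(p + q), ... while d <= min(n / p, m / q).
// The output is sorted so it can be compared with a hand-written list.
std::vector<std::pair<int, int>> list_pairs(int n, int m) {
    std::vector<std::pair<int, int>> out;
    for (int p = 1; p * p < n; p++) {
        for (int q = 1; q * q < m; q++) {
            if (std::gcd(p, q) != 1) continue;
            int limit = std::min(n / p, m / q);  // largest allowed d
            for (int d = p + q; d <= limit; d += p + q) {
                out.emplace_back(p * d, q * d);
            }
        }
    }
    std::sort(out.begin(), out.end());
    return out;
}
```

For $(n, m) = (10, 8)$, the pair $(p,q) = (1,1)$ yields $d \in \{2,4,6,8\}$, $(1,2)$ yields $d = 3$, and $(2,1)$ yields $d = 3$. This gives $(2,2),(4,4),(6,6),(8,8),(3,6),(6,3)$.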
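The accepted code is below. Its time complexity is $O\left(\sqrt{n} \sqrt{m} \, \log{\max\{n,m\}}\right)$, where the logarithmic factor comes from the $\gcd$ computation. Since $\sqrt{n}\sqrt{m} \le \frac{n+m}{2}$, the total work over all test cases is bounded in terms of the sums of $n$ and $m$, each of which is at most $2 \cdot 10^6$.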
```cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

void solve() {
    int n, m;
    scanf("%d %d", &n, &m);
    LL ret = 0;
    // Enumerate (p, q) with p^2 < n and q^2 < m.
    for (int p = 1; p * p < n; p++) {
        for (int q = 1; q * q < m; q++) {
            // Keep only coprime (p, q); d ranges over multiples of p + q
            // up to min(n / p, m / q). std::gcd requires C++17.
            if (std::gcd(p, q) == 1) ret += min(n / p, m / q) / (p + q);
        }
    }
    printf("%lld\n", ret);
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        solve();
    }
    return 0;
}
```
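To gain extra confidence in the fast loop, it can be compared against a brute-force table on all small inputs. In this sketch, `count_fast` repeats the accepted solution's loops as a function. `cross_check` builds the brute-force answer for every $(n, m)$ with $n, m \le$ `limit` using 2D prefix sums, so all $\text{limit}^2$ queries share one $O(\text{limit}^2)$ table. Both function names and the limit are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Same loops as the accepted solution, wrapped as a function (hypothetical name).
long long count_fast(int n, int m) {
    long long ret = 0;
    for (int p = 1; p * p < n; p++)
        for (int q = 1; q * q < m; q++)
            if (std::gcd(p, q) == 1) ret += std::min(n / p, m / q) / (p + q);
    return ret;
}

// Hypothetical cross-check: cnt[a][b] is the number of valid pairs with
// first element <= a and second element <= b, built by 2D prefix sums.
// Each entry is compared with count_fast(a, b).
bool cross_check(int limit) {
    std::vector<std::vector<long long>> cnt(limit + 1, std::vector<long long>(limit + 1, 0));
    for (int a = 1; a <= limit; a++) {
        for (int b = 1; b <= limit; b++) {
            bool valid = (1LL * b * std::gcd(a, b)) % (a + b) == 0;
            cnt[a][b] = cnt[a - 1][b] + cnt[a][b - 1] - cnt[a - 1][b - 1] + (valid ? 1 : 0);
            if (cnt[a][b] != count_fast(a, b)) return false;
        }
    }
    return true;
}
```

The prefix-sum table avoids rerunning an $O(nm)$ brute force for every query, which keeps a check up to `limit = 200` fast.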
References
Codeforces Round 942 (Div. 1, Div. 2) Editorial: https://codeforces.com/blog/entry/129027