BZOJ2956: 模積和(數論分塊)

自為風月馬前卒發表於2019-02-07

題意

求 $\sum_{i=1}^{n}\sum_{j=1}^{m} (n \bmod i)(m \bmod j)$,其中 $i \neq j$,答案對 $19940417$ 取模。

題目連結

Sol

啊啊這題好惡心啊,推的時候一堆細節qwq

$a \bmod i = a - \left\lfloor \frac{a}{i} \right\rfloor \cdot i$

把所有的都展開,直接分塊。關鍵是 $i \neq j$ 這個條件:$i = j$ 的那些項需要單獨減掉……

然後就慢慢寫就好了

#include<bits/stdc++.h>
// Contest-template shorthand macros (Pair, MP, fi, se are not referenced by the code below).
#define Pair pair<int, int>
#define MP(x, y) make_pair(x, y)
#define fi first
#define se second
// From here on every `int` is textually replaced by `long long`, so all "int" variables,
// parameters and return types below are 64-bit. This is also why the entry point is
// declared as `signed main()` instead of `int main()`.
#define int long long
#define LL long long
// File-redirection helpers (not used in this file).
#define Fin(x) {freopen(#x".in","r",stdin);}
#define Fout(x) {freopen(#x".out","w",stdout);}
using namespace std;
// `mod` is the modulus all arithmetic helpers reduce by; MAXN, INF and eps are
// template leftovers not referenced by the code below.
const int MAXN = 1e6 + 10, mod = 19940417, INF = 1e9 + 10;
const double eps = 1e-9;
// Modular addition with a single correction step: returns x + y shifted by at most one
// multiple of `mod`. The result lies in [0, mod) only when x + y is in [-mod, 2 * mod),
// e.g. when both operands are already in (-mod, mod).
template <typename A, typename B> inline LL add(A x, B y) {
    const auto s = x + y;
    if (s >= mod) return s - mod;
    if (s < 0) return s + mod;
    return s;
}
// In-place modular addition: x becomes x + y corrected by at most one multiple of `mod`.
// Same range precondition as add(): x + y should lie in [-mod, 2 * mod).
template <typename A, typename B> inline void add2(A &x, B y) {
    const auto s = x + y;
    if (s < 0) x = s + mod;
    else if (s >= mod) x = s - mod;
    else x = s;
}
// Modular multiplication: returns (x * y) mod `mod` in [0, mod).
// Each operand is first reduced into [0, mod). Because C++ `%` keeps the sign of the
// dividend, a plain `(x + mod) % mod` stays negative for x < -mod (and overflows for x near
// LLONG_MAX), so the remainder is taken first and then shifted into range.
// Once reduced, both factors are below mod, so the 64-bit product cannot overflow.
template <typename A, typename B> inline LL mul(A x, B y) {
    const LL xr = (static_cast<LL>(x) % mod + mod) % mod;
    const LL yr = (static_cast<LL>(y) % mod + mod) % mod;
    return xr * yr % mod;
}
// Returns x * x with the first factor widened to long long; no modular reduction is applied.
template <typename A> inline LL sqr(A x) {
    const auto widened = 1ll * x;
    return widened * x;
}
// Global state: the two bounds read in main (main swaps them so that N <= M), and
// a = calc(N), b = calc(M) as assigned in main. calc2 reads N and M directly.
int N, M, a, b;
// Returns (l + (l + 1) + ... + r) mod `mod`, always reduced into [0, mod).
// Uses l * n + n * (n - 1) / 2 with n = r - l + 1. Exactly one of n and n - 1 is even,
// and that factor is halved before reducing, so no modular inverse of 2 is needed.
// The former `l == r` shortcut returned l unreduced; the general formula already handles
// n == 1 (it yields l mod `mod`), so the shortcut is gone. An empty range (r < l) yields 0.
int sum(int l, int r) {
    if (r < l) return 0;
    const int n = r - l + 1;
    const int pairs = (n % 2 == 0) ? mul(n / 2, n - 1) : mul(n, (n - 1) / 2);
    return add(mul(l, n), pairs);
}
// Returns sum_{k=1}^{n} floor(n / k) * k modulo `mod`.
// The k are grouped into maximal runs [lo, hi] sharing the quotient q = floor(n / k),
// where hi = floor(n / q). Each run contributes q * (lo + ... + hi).
// Returns 0 when n <= 0.
int calc(int n) {
    int total = 0;
    int lo = 1;
    while (lo <= n) {
        const int q = n / lo;
        const int hi = n / q;  // last k with floor(n / k) == q
        add2(total, mul(q, sum(lo, hi)));
        lo = hi + 1;
    }
    return total;
}
// Returns (1^2 + 2^2 + ... + x^2) mod `mod`, i.e. x * (2x + 1) * (x + 1) / 6.
// The division by 6 is done exactly on the integer factors before reducing. 2 divides x or
// x + 1 (2x + 1 is always odd), and 3 divides one of x, 2x + 1, x + 1. For each divisor,
// the first factor (in that order) that it divides is divided once.
int get(int x) {
    int factor[3] = {x, 2 * x + 1, x + 1};
    for (int d : {2, 3}) {
        for (int &v : factor) {
            if (v % d == 0) {
                v /= d;
                break;
            }
        }
    }
    return mul(mul(factor[0], factor[1]), factor[2]);
}
// Returns (i^2 + (i + 1)^2 + ... + j^2) mod `mod` as the difference of two prefix sums of squares.
int fuck2(int i, int j) {
    const int upToJ = get(j);
    const int belowI = get(i - 1);
    return add(upToJ, -belowI);
}
int calc2() {
    int ret = 0;
    for(int i = 1, j; i <= N; i = j + 1) {
        j = min(M / (M / i), N / (N / i));
        int a = M / i, b = N / i;
        add2(ret, add(add(mul(N, mul(a, sum(i, j))), mul(M, mul(b, sum(i, j)))), -mul(mul(a, b), fuck2(i, j))));
    }
    return ret;
}
// Reads N and M and prints sum_{i<=N} sum_{j<=M, j != i} (N mod i)(M mod j) modulo `mod`.
// Without the i != j restriction the double sum factorises into
// (sum_i N mod i) * (sum_j M mod j). The diagonal i == j (only i <= min(N, M)) equals
// N*N*M - calc2() and is subtracted afterwards.
signed main() {
    cin >> N >> M;
    if (M < N) swap(N, M);  // keep N as the smaller bound; calc2 iterates k up to N
    a = calc(N);
    b = calc(M);
    // sum_{i=1}^{n} (n mod i) = n*n - sum_{i=1}^{n} floor(n/i)*i
    const int remN = add(mul(N, N), -a);
    const int remM = add(mul(M, M), -b);
    const int allPairs = mul(remN, remM);
    const int ans = add(add(allPairs, -mul(N, mul(N, M))), calc2());
    cout << ans;
    return 0;
}

相關文章