BZOJ3518 點組計數

lalaouye發表於2024-06-24

莫比烏斯反演做法。

將橫線豎線斜線分開考慮。這題有個很好的特點,就是以四個角落為線段端點所涵蓋的的線的種類是最多的。那麼我們先將列舉的線的一個端點固定在左上角,其座標為 \((0,0)\),然後再列舉另外一個端點 \((i,j)\),那麼這種線在本題中的貢獻為 \((\gcd(i,j)-1)\times(n-i+1)\times(m-j+1)\),為什麼呢?因為這種線段在影像中有 \((n-i+1)(m-j+1)\) 個位置可以存放,並且中間的點有 \((\gcd(i,j)-1)\) 個位置可以選擇,所以我們將本題轉化為計算 \(\sum_{i=1}^n\sum_{j=1}^m(\gcd(i,j)-1)(n-i+1)(m-j+1)\)

我們將式子拆開,分別計算出

\[a=\sum_{i=1}^n\sum_{j=1}^m(\gcd(i,j)-1)(n+1)(m+1) \]

\[b=\sum_{i=1}^n\sum_{j=1}^mi(\gcd(i,j)-1)(m+1) \]

\[c=\sum_{i=1}^n\sum_{j=1}^mj(\gcd(i,j)-1)(n+1) \]

\[d=\sum_{i=1}^n\sum_{j=1}^m(\gcd(i,j)-1)ij \]

\[ans=a-b-c+d \]

即可。

計算可以利用莫比烏斯反演,時間複雜度 \(\mathcal{O}(n)\)

注意,因為我左上角座標為 \((0,0)\),故開始 \(n,m\) 都要減一。最後再加回去來計算橫豎線產生的貢獻。

程式碼:

#include <bits/stdc++.h>
#define rep(i, l, r) for (int i = l; i <= r; ++ i)
#define rrp(i, l, r) for (int i = r; i >= l; -- i)
#define eb emplace_back
#define int long long
using namespace std;
constexpr int N = 5e4 + 5, P = 1e9 + 7;
inline int rd ()
{
  int x = 0, f = 1;
  char ch = getchar ();
  while (! isdigit (ch))
  {
    if (ch == '-') f = - 1;
    ch = getchar ();
  }
  while (isdigit (ch))
  {
    x = (x << 1) + (x << 3) + (ch ^ 48);
    ch = getchar ();
  }
  return x * f;
}
int qpow (int x, int y)
{
  int ret = 1;
  for (; y; y >>= 1, x = x * x % P) if (y & 1) ret = ret * x % P;
  return ret;
}
int T;
int mu[N], pri[N], cnt;
int sum[N], s1[N], s2[N], s3[N], s4[N], s5[N];
bool ip[N];
int func (int a, int b, int t)
{
  int n = min (a, b);
  int ret = 0;
  int l1, l2, r1, r2;
  l1 = l2 = 1;
  while (l1 <= n && l2 <= n)
  {
    int k1 = a / l1;
    r1 = min (a / k1, n);
    int k2 = b / l2;
    r2 = min (b / k2, n);
    if (t == 1)
    (ret += (sum[min (r1, r2)] - sum[max (l1, l2) - 1]) * k1 * k2) %= P;
    if (t == 2)
    (ret += (s4[min (r1, r2)] - s4[max (l1, l2) - 1]) * (k1 * (k1 + 1) / 2 % P) % P * k2) %= P;
    if (t == 3)
    (ret += (s4[min (r1, r2)] - s4[max (l1, l2) - 1]) * (k2 * (k2 + 1) / 2 % P) % P * k1) %= P;
    if (t == 4)
    (ret += (s5[min (r1, r2)] - s5[max (l1, l2) - 1]) * (k1 * (k1 + 1) / 2 % P) % P * (k2 * (k2 + 1) / 2 % P)) %= P;
    if (r1 < r2) l1 = r1 + 1; else l2 = r2 + 1;
  }
  return ret;
}
signed main ()
{
  mu[1] = 1;
  rep (i, 2, N - 1)
  {
    if (! ip[i])
    {
      pri[++ cnt] = i;
      mu[i] = -1;
    }
    for (int j = 1; j <= cnt && i * pri[j] <= N - 1; ++ j)
    {
      ip[i * pri[j]] = 1;
      if (i % pri[j] == 0)
      {
        mu[i * pri[j]] = 0;
        break;
      }
      mu[i * pri[j]] = - mu[i];
    }
  }
  rep (i, 0, N - 1) sum[i] = sum[i - 1] + mu[i];
  rep (i, 1, N - 1)
  	s1[i] = (s1[i - 1] + max (0ll, i - 1)) % P,
  	s2[i] = (s2[i - 1] + i * max (0ll, i - 1)) % P,
	s3[i] = (s3[i - 1] + i * i * max (0ll, i - 1)) % P,
	s4[i] = (s4[i - 1] + mu[i] * i) % P,
	s5[i] = (s5[i - 1] + mu[i] * i * i) % P;
  int n = rd () - 1, m = rd () - 1;
  if (n > m) swap (n, m);
  int l1 = 1, l2 = 1, ret = 0;
  while (l1 <= n && l2 <= n)
  {
    int r1 = n / (n / l1);
    int r2 = m / (m / l2);
    int l = max (l1, l2), r = min (r1, r2);
    (ret +=
	func (n / l, m / r, 1) * (m + 1) % P * (n + 1) % P * (s1[r] - s1[l - 1]) -
	func (n / l, m / r, 2) * (m + 1) % P * (s2[r] - s2[l - 1]) - 
	func (n / l, m / r, 3) * (n + 1) % P * (s2[r] - s2[l - 1]) + 
	func (n / l, m / r, 4) * (s3[r] - s3[l - 1])
	);
    if (r1 < r2) l1 = r1 + 1; else l2 = r2 + 1;
  }
  ret <<= 1;
  ++ n, ++ m;
  rep (i, 3, n) (ret += (i - 2) * m * (n - i + 1)) %= P;
  rep (i, 3, m) (ret += (i - 2) * n * (m - i + 1)) %= P;
  printf ("%lld\n", (ret + P) % P);
}
 

相關文章