DMOJ

Yaosicheng124發表於2024-09-29

B. Infinity Card Decks

題目描述

\(N\) 張牌,第 \(i\) 張牌打出需要 \(A_i\) 能量,獲得 \(B_i\) 能量。一開始你有 \(M\) 的能量。

如果一些牌,無論怎麼無限的按照隨機順序打出,都不會缺少能量,則我們稱這是一個無限牌組

求有多少個子區間是無限牌組。

思路

很容易想到,一個無限牌組必須滿足以下條件。

  1. \(\sum A_i \le \sum B_i\),因為如果不滿足該條件,那麼每經過一輪能量就會減少,所以最終一定會不夠。
  2. 在第一輪打出時不可能缺少能量。因為滿足條件 1,所以能量每過一輪都是單調不降的,所以第一輪的能量是最少的。

我們先來看條件 2:

  • 我們要列舉一張牌 \(i\),使得在出這張牌的時候能量不足。在這之前,肯定會把其他 \(A_j>B_j且i\ne j\)\(j\) 出掉。
    • 如果 \(A_i\le B_i\),那麼此時要滿足 \(M\ge A_i+\sum \max(0,A_j-B_j)\)
    • 否則如果 \(A_i>B_i\),那麼此時要滿足 \(M\ge A_i+(\sum \max(0,A_j-B_j)-(A_i-B_j))=B_i+\sum \max(0,A_j-B_j)\)
  • 上面兩式合起來就是 \(M\ge \min (A_i,B_i)+\sum \max(0,A_j-B_j)\)

所以一個區間 \([l,r]\) 要滿足 \(M\ge \max\limits_{i=l}^r\{\min (A_i,B_i)\}+\sum \limits_{i=l}^r \max(0,A_i-B_i)\)。這個使用雙指標求解。

接著我們考慮滿足條件 1,我們可以做一個 \(A_i-B_i\) 的字首和,離散化後用樹狀陣列統計數量即可。

空間複雜度 \(O(N)\),時間複雜度 \(O(N\log N)\)

程式碼

#include<bits/stdc++.h>
using namespace std;
using ll = long long;

const int MAXN = 1000005;

int n, m, a[MAXN], b[MAXN], log_2[MAXN];
ll ans, st[21][MAXN], pre[MAXN], tr[MAXN];
vector<ll> X;

int Getmax(int l, int r) {
  return max(st[log_2[r - l + 1]][l], st[log_2[r - l + 1]][r - (1 << log_2[r - l + 1]) + 1]);
}

int lowbit(int x) {
  return x & -x;
}

void update(int p, ll x) {
  for(; p <= n + 1; tr[p] += x, p += lowbit(p)) {
  }
}

ll Getsum(int p) {
  ll sum = 0;
  for(; p; sum += tr[p], p -= lowbit(p)) {
  }
  return sum;
}

int main() {
  ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
  cin >> n >> m;
  for(int i = 1; i <= n; ++i) {
    cin >> a[i];
  }
  for(int i = 1; i <= n; ++i) {
    cin >> b[i];
    pre[i] = pre[i - 1] + a[i] - b[i];
    X.emplace_back(pre[i]);
    st[0][i] = min(a[i], b[i]);
  }
  X.emplace_back(0);
  sort(X.begin(), X.end()), X.erase(unique(X.begin(), X.end()), X.end());
  for(int i = 0; i <= n; ++i) {
    pre[i] = lower_bound(X.begin(), X.end(), pre[i]) - X.begin() + 1;
  }
  for(int i = 1; i <= 20; ++i) {
    for(int j = 1; j <= n; ++j) {
      if(j + (1 << i) - 1 <= n) {
        st[i][j] = max(st[i - 1][j], st[i - 1][j + (1 << (i - 1))]);
      }
    }
  }
  for(int i = 2; i <= n; ++i) {
    log_2[i] = log_2[i / 2] + 1;
  }
  ll sum = 0;
  for(int i = 1, j = 1; i <= n; sum -= max(0, a[i] - b[i]), update(pre[i], -1), ++i) {
    for(; j <= n && Getmax(i, j) + sum + max(0, a[j] - b[j]) <= m; sum += max(0, a[j] - b[j]), update(pre[j], 1), ++j) {
    }
    ans += Getsum(pre[i - 1]);
    if(j == i) {
      sum += max(0, a[j] - b[j]), update(pre[j], 1), ++j;
    }
  }
  cout << ans;
  return 0;
}