題目連結:Becoder or Luogu
首先我們可以先把點給縮一縮,把連續的正數點和連續的負數點分別縮成一個點,比如 1 2 3 -1 -1 1 2
這個東西我們就可以將其縮成 6 -2 3
我們可以發現,求前者的值等於求後者的值,我們就將原序列變為了正負交替的序列。
然後我們就可以開始反悔貪心,將所有數的點全部丟進小根堆裡,小根堆的權值是這個點的絕對值,將所有正數的點暫時全部加進答案 \(ans\) 並用一個數記錄有多少個了,如果這個數大於了 \(m\) 我們就開始反悔貪心(這裡可能有些人不明白,不過到後面就會明白了,先記著就行。
當一個點是負數,而且它在角落,即左右兩邊有沒有的點,那麼我們就可以將其拋棄掉不管它了,反之,那麼我們可以分類討論一下:
-
這個點是負數:那麼它和它左右兩邊的點可以是這個形式“正負正”那我們就將這一組合並起來,那正數的點的個數就會減去 \(1\) 答案還得加上這個負點,然後我們在把這個新的點丟進堆裡面去。
-
這個點是正點:那麼它和它左右兩邊的點可以是這個形式“負正負”那我們就將這一組合並起來,那正數的點的個數就會減去 \(1\) 答案還得減去上這個正點,然後我們在把這個新的點丟進堆裡面去。
-
如果這個點是正點,且在最左邊或者最右邊:那麼可以成這個形式“正負”“負正”,很明顯我們還是要將其合併起來,然後減去它,再將新點扔進堆裡面去,目前選中的點也減去 \(1\)。
整理一下我們就可以變為,如果這個點滿足:是一個負數,左右兩邊點不全,我們就刪除它。如果不滿足:答案減去它的絕對值,個數減 \(1\) 再合併成一個新點,最後丟進堆裡面。
如果還不明白,就結合著程式碼吧。
#include <queue>
#include <cstring>
#include <algorithm>
#include <iostream>
#define x first
#define y second
using namespace std;
namespace oi{
using ll = long long;
using ull = unsigned long long;
using pii = pair<int, int>;
using db = double;
using pll = pair<ll, ll>;
#define endl '\n'
inline ll read() {
char ch = getchar(); ll fu = 0, s = 0;
while(!isdigit(ch)) fu |= (ch == '-'), ch = getchar();
while(isdigit(ch)) s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
return fu ? -s : s;
}
template <typename T>
inline void write(T x, char ch) {
if(x < 0) putchar('-'), x = -x;
static int stk[30];
int tt = 0;
do stk[++tt] = x % 10, x /= 10; while(x);
while(tt) putchar(stk[tt--] ^ 48);
putchar(ch);
}
template <typename T>
inline void write(T x) {
if(x < 0) putchar('-'), x = -x;
static int stk[30];
int tt = 0;
do stk[++tt] = x % 10, x /= 10; while(x);
while(tt) putchar(stk[tt--] ^ 48);
}
inline void write(char x) {putchar(x);}
};
using namespace oi;
const int MAXN = 1e5 + 10;
int n, a[MAXN], l[MAXN], r[MAXN], m;
int x, cnt;
ll ans;
bool st[MAXN];
void del(int x) {
st[x] = true;
r[l[x]] = r[x];
l[r[x]] = l[x];
}
void solve() {
n = read(), m = read();
for (int i = 1; i <= n; i++) {
x = read(); if (!x) continue;
if (1ll * x * a[cnt] > 0) a[cnt] += x;
else a[++cnt] = x;
}
n = cnt;
cnt = 0;
priority_queue<pii, vector<pii>, greater<pii>> q;
for (int i = 1; i <= n; i++) {
if (a[i] > 0) cnt++, ans += a[i];
l[i] = i - 1, r[i] = i + 1;
q.push({abs(a[i]), i});
}
while (cnt > m) {
while (st[q.top().y]) q.pop();
pii t = q.top(); q.pop();
int x = t.y;
if (l[x] != 0 && r[x] != n + 1 || a[x] > 0) {
cnt--; ans -= abs(a[x]);
a[x] += a[l[x]] + a[r[x]];
del(l[x]), del(r[x]);
q.push({abs(a[x]), x});
} else {
del(x);
}
}
write(ans, '\n');
}
signed main() {
// freopen("test.in", "r", stdin);
// freopen("test.out", "w", stdout);
int T = 1;
// T = read();
while(T--) solve();
return 0;
}