JZOJ 6664. 【2020.05.28省選模擬】最最佳化

leiyuanze發表於2023-02-28

\(\text{Solution}\)

原題:\(\text{Honorable Mention}\)

一個費用流做法,\(S\)\(2i-1\) 連流量為 \(1\),費用為 \(0\) 的邊,\(2i\)\(T\) 連流量為 \(1\),費用為 \(0\) 的邊
\(2i-1\)\(2i\) 連流量為 \(1\),費用為 \(a_i\) 的邊。然後增廣 \(k\) 次即為答案

既然用了費用流模型那麼這個關於 \(k\) 的函式自然是凸函式
於是可以考慮一些最佳化
比如,多組詢問想到將區間拆成 \(O(\log n)\) 段線段樹上的區間,處理每個區間上的函式值,合併可以做到 \(O(n)\)
也就是要維護凸包,閔科夫斯基和,就可以 \(O(n\log n)\) 預處理凸包了
但這樣還是 \(O(nQ)\) 的,仍然是暴力
每個詢問有 \(O(\log n)\) 個凸包,合併代價很高
考慮 \(\text{WQS}\) 二分的威力,想想凸包合併時 \(f_{i+j}=f_i+f_j\),又點 \((i,f_i)\) 考慮成 \(f_i=ki+b_i,f_j=kj+b_j\)
那麼合併後的凸包 \((i+j,f_{i+j})\)\(f_{i+j}=k(i+j)+b_{i+j}\),也就是啟示我們 \(\text{WQS}\) 二分斜率後在每個凸包上找到對應斜率的值直接合並值,然後用 \(WQS\) 二分的方式算出答案
於是就做到 \(O(n\log n+Q\log V \log ^2n)\)

注意事項:

  1. \(\text{WQS}\) 一定要注意斜率變大或者變小會導致分的段數變多還是變少,同時二分寫法要和求最值的寫法相統一
    如本題寫了分的段數 \(\ge k\) 時更新答案,那麼求最值,值相等時優先取段數多的
  2. 傳參事項,傳 vector 時加個 & 就不會發生複製導致超時的問題了(因為這題某函式本身並不需要遍歷整個 vector,只要也只能 \(O(\log n)\) 做某些特定事情複雜度才對)

\(\text{Code}\)

#include <bits/stdc++.h> 
#define IN inline
#define eb emplace_back
#define LL long long
#define Vec vector<LL>
using namespace std;

template<typename Tp>
IN void read(Tp &x) {
    x = 0; char ch = getchar(); int f = 0;
    for(; !isdigit(ch); f |= (ch == '-'), ch = getchar());
    for(; isdigit(ch); x = (x<<3)+(x<<1)+(ch^48), ch = getchar());
    if (f) x = ~x + 1;
}

const int N = 35005;
const LL INF = 2e9;
int n, a[N];
LL pans[2], pcnt[2], tans[2], tcnt[2];

struct SegmentTree {
    #define ls (p << 1)
    #define rs (ls | 1)
    Vec tr[N << 2][2][2];
    
    IN Vec merge(Vec &a, Vec &b) {
    	if (a.empty() || b.empty()) return{};
    	Vec ret(a.size() + b.size() - 1, -INF);
    	int l = 0, r = 0; if (a[0] != -INF && b[0] != -INF) ret[0] = a[0] + b[0];
    	while (l < a.size() || r < b.size()) {
    		if (l >= a.size() - 1 && r >= b.size() - 1) break;
    		if (l == a.size() - 1) ++r; else if (r == b.size() - 1) ++l;
    		else if (a[l + 1] - a[l] > b[r + 1] - b[r]) ++l; else ++r;
    		if (a[l] != -INF && b[r] != -INF) ret[l + r] = a[l] + b[r];
        }
    	return ret;
    }
    IN void shift(Vec tmp, Vec &res) {
        for(int i = 1; i < tmp.size(); i++) res[i - 1] = max(res[i - 1], tmp[i]);
    }
    IN void pushup(int p) {
        for(int i = 0; i < 2; i++)
            for(int j = 0; j < 2; j++) {
                tr[p][i][j] = merge(tr[ls][i][0], tr[rs][0][j]);
                shift(merge(tr[ls][i][1], tr[rs][1][j]), tr[p][i][j]);
            }
    }
    void build(int p, int l, int r) {
        if (l == r) {
            tr[p][0][0] = {0, a[l]}, tr[p][0][1] = tr[p][1][0] = tr[p][1][1] = {-INF, a[l]};
            return;
        }
        int mid = l + r >> 1; build(ls, l, mid), build(rs, mid + 1, r), pushup(p);
    }
    
    IN void update(Vec &a, LL k, int x, int y) {
        int l = 1, r = a.size() - 1, mid = l + r >> 1, ret = 0;
        for(; l <= r; mid = l + r >> 1)
            if (a[mid] - a[mid - 1] >= k) ret = mid, l = mid + 1; else r = mid - 1;
        if (a[ret] == -INF) return;
        for(int i = 0; i < 2; i++) {
            LL w = tans[i] + a[ret] - k * ret;
            if (!x && pans[y] <= w) pans[y] = w, pcnt[y] = tcnt[i] + ret;
            w = tans[i] + a[ret] - k * (ret - i);
            if (x && (pans[y] < w || (pans[y] == w && tcnt[i] + ret - i > pcnt[y])))
                pans[y] = w, pcnt[y] = tcnt[i] + ret - i;
        }
    }
    IN void Merge(int p, LL k) {
        tcnt[0] = pcnt[0], tcnt[1] = pcnt[1], tans[0] = pans[0], tans[1] = pans[1];
        pans[1] = -INF, pans[0] = pcnt[0] = pcnt[1] = 0;
        for(int i = 0; i < 2; i++)
            for(int j = 0; j < 2; j++) update(tr[p][i][j], k, i, j);
    }
    
    void query(int p, int l, int r, int x, int y, LL k) {
        if (x <= l && r <= y) return Merge(p, k), void();
        int mid = l + r >> 1;
        if (x <= mid) query(ls, l, mid, x, y, k);
        if (y > mid) query(rs, mid + 1, r, x, y, k);
    }
}seg;

void Query(int L, int R, int k) {
    LL res = 0, l = -1e10, r = 1e10, mid = l + r >> 1;
    for(; l <= r; mid = l + r >> 1) {
        pans[1] = -INF, pans[0] = pcnt[0] = pcnt[1] = 0, seg.query(1, 1, n, L, R, mid);
        int z = (pans[0] > pans[1] ? 0 : 1);
        if (pcnt[z] >= k) res = pans[z] + mid * k, l = mid + 1; else r = mid - 1;
    }
    printf("%lld\n", res);
}

int main() {
    freopen("maximize.in", "r", stdin);
    freopen("maximize.out", "w", stdout);
    int q; read(n), read(q);
    for(int i = 1; i <= n; i++) read(a[i]);
    seg.build(1, 1, n);
    for(int i = 1, l, r, k; i <= q; i++) read(l), read(r), read(k), Query(l, r, k);
}

相關文章