P2672 NOIP2015 普及組 推銷員

blind5883發表於2024-10-20

P2672 [NOIP2015 普及組] 推銷員 - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)

我還是相信,大部分人是想不出貪心的。

時間複雜度 \(O(n\log n)\) 但是常數極大,運用線段樹,這題資料過水,甚至我一個寫錯了的線段樹都能拿滿(除了#3hack)。非貪心。

首先按距離大小分類,並在每一類裡進行排序,這裡使用vector實現,為了方便彈出我們從小到大排,這樣我們只需要統計每一類的末尾值就能統計出最大值。

對於選 i 個人,就相當於我先選 i - 1 個人,再選一個人,我們要求選 1 ~ n 個人的結果,就可以利用上一個的結果,只需要選出現在最大的這個人即可。

因為有路程的問題,這我們也算在疲勞值中,那在選完一個人 u 之後就會出現幾種情況,設 r 為 u 所在的距離一類,last 為當前最大距離的類;

  1. 如果所剩距離疲勞值 \(last \ge r\) ,那麼選了 r 將不影響任何類的路程疲勞值;
  2. 如果所剩距離疲勞值 \(r > last\),那麼對於已經被處理過路程疲勞的類即 1~last,不用再處理,對於 last + 1 ~ r - 1 則直接把所剩路程疲勞值清零,對於 r ~ n 我們直接減去當前 r 所剩的距離疲勞值。
    每次詢問句詢問區間最大值,儲存這個最大值的來源 r,然後讓這個類 r,彈出最後一個,讓新的加入區間,最大值加上之前的 sum,就是當前答案。

按上面進行,需要實現區間覆蓋,區間加,求區間最值。實際上區間覆蓋可以用區間加代替。

而實際上我實現思路不同,我們假設第一個選了 r,如果把所有數距離疲勞都減去 \(S_r\),那麼對於 1 ~ r - 1 還需要加上 \(S[r] - S[j]\),因為它沒有那麼多距離。如果只看相對大小的話,那麼就相當於 1~r - 1 加上了 \(S[r] - S[j]\),照這個思路很快就能想出上面的分類。而有所不同:

  1. 如果所剩距離疲勞值 \(last \ge r\) ,那麼選了 r 將不影響任何類的路程疲勞值;
  2. 如果所剩距離疲勞值 \(r > last\),對於已經被處理過路程疲勞的類即 1~last,因為不需要繼續減,就相當於加上 \(S[r]\)。對於 last + 1 ~ r - 1 則相當於加上 \(S[r]- S[j]\),上面有解釋。我們只看相對大小,而為了方便直接算出答案,還是把實際影響加進去即:全部減去 \(S[r]\)
    對於每一類新加入值,就相當於把線段樹對應值減去當前最大maxv(即這個點的值),再加上新點值即可,別忘了 pop_back()。

按照上面進行線段樹,即可。

感覺還是寫麻煩了。


#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <queue>

using namespace std;

const int N = 100010, M = N * 8;

int n, m, cnt;
vector<int> s[N];
int id[N];
int f[N];
int S[N];

struct Nodes
{
    int l, r;
    int id, maxv;
    int add;
}tr[M];

struct Node 
{
    int s, a;
    bool operator<(const Node &W)const
    {
        return s < W.s;
    }
}g[N];

void print(int A)
{
    Nodes u = tr[A];
    printf("%d %d %d %d %d %d %d\n", A, u.l, u.r, u.id, u.maxv, u.add);
    puts("");
}

void pushup(Nodes &u, Nodes &l, Nodes &r)
{
    u.l = l.l;
    u.r = r.r;
    if (l.maxv > r.maxv) 
    {
        u.maxv = l.maxv;
        u.id = l.id;
    }
    else
    {
        u.maxv = r.maxv;
        u.id = r.id;
    }
}

void pushup(int u)
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void pushdown(Nodes &u, Nodes &l, Nodes &r)
{
    l.add += u.add;
    r.add += u.add;
    l.maxv += u.add;
    r.maxv += u.add;
    u.add = 0;
}

void pushdown(int u)
{
    pushdown(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r)
{
    if (l == r) tr[u] = {l, r, l, s[l].back() + S[l], 0};
    else 
    {
        tr[u] = {l, r};
        int mid = l + r >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
    // print(u);
}

void modify(int u, int l, int r, int add)
{
    if (l <= tr[u].l && tr[u].r <= r) 
    {
        tr[u].add += add, tr[u].maxv += add;
    }
    else 
    {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (l <= mid) modify(u << 1, l, r, add);
        if (r > mid) modify(u << 1 | 1, l, r, add);
        pushup(u);
    }
    
}


Nodes query(int u, int l, int r)
{
    if (l <= tr[u].l && tr[u].r <= r) return tr[u];
    else 
    {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (r <= mid) return query(u << 1, l, r);
        else if (l > mid) return query(u << 1 | 1, l, r);
        else 
        {
            Nodes A, B, res;
            A = query(u << 1, l, r);
            B = query(u << 1 | 1, l, r);
            pushup(res, A, B);
            return res;
        }
    }
}


int main()
{
    cin >> n;
    for (int i = 1; i <= n; i ++ ) scanf("%d", &g[i].s);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &g[i].a);
    
    int last = -1;
    for (int i = 1; i <= n; i ++ ) 
    {
        int a = g[i].a;
        if (last != g[i].s) 
        {
            cnt ++ ;
            last = g[i].s;
            S[cnt] = 2 * g[i].s;
        }
        s[cnt].push_back(a);
    }
    for (int j = 1; j <= cnt; j ++ ) sort(s[j].begin(), s[j].end());
    
    build(1, 1, cnt);
    
    last = 0;
    int sum = 0;
    for (int i = 1; i <= n; i ++ )
    {
        auto t = query(1, 1, cnt);
        int r = t.id;
        // cout << last << ' ' << query(1, 4, 4).maxv << endl;
        
        if (last < r)
        {
            for (int j = last + 1; j < r; j ++ )
            {
                modify(1, j, j, S[r] - S[j]);
            }
            
            if (last) modify(1, 1, last, S[r]);
            modify(1, 1, cnt, -S[r]);
            //實際上這兩句可以合併為
            //modify(1, last + 1, cnt, -S[r]);
        }
        last = max(last, r);
        
        modify(1, r, r, -s[r].back()); 
        s[r].pop_back();
        if (s[r].size()) modify(1, r, r, s[r].back());
        else modify(1, r, r, -0x3f3f3f3f);
        sum += t.maxv;
        
        cout << sum << '\n';
    }
    return 0;
    
}

相關文章