【題目全解】ACGO巔峰賽#15

Macw發表於2024-12-03

ACGO 巔峰賽#15 - 題目解析

間隔四個月再戰 ACGO Rated,鑑於最近學業繁忙,比賽打地都不是很頻繁。雖然這次沒有 AK 排位賽(我可以說是因為週末太忙,沒有充足的時間思考題目…(好吧,其實也許是因為我把 T5 給想複雜了))。

本文依舊提供每道題的完整解析(因為我在賽後把題目做出來了)。

T1 - 高塔

題目連結跳轉:點選跳轉

插一句題外話,這道題的題目編號挺有趣的。

沒有什麼特別難的點,迴圈讀入每一個數字,讀入後跟第一個輸入的數字比較大小,如果讀入的數字比第一個讀入的數字要大(即 \(a_i > a_1\)),直接輸出 \(i\) 並結束主程式即可。

本題的 C++ 程式碼如下:

#include <iostream>
using namespace std;

int n, arr[105];

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n;
    for (int i=1; i<=n; i++){
        cin >> arr[i];
        if (arr[i] > arr[1]){
            cout << i << endl;
            return 0;
        }
    }
    cout << -1 << endl;
    return 0;
}

本題的 Python 程式碼如下:

n = int(input())
arr = list(map(int, input().split()))

for i in range(1, n + 1):
    if arr[i - 1] > arr[0]:
        print(i)
        break
else:
    print(-1)

T2 - 營養均衡

題目連結跳轉:點選跳轉

也是一道入門題目,沒有什麼比較難的地方,重點是把題目讀清楚了。

我們設定一個陣列 \(\tt{arr}\),其中 \(\tt{arr_i}\) 表示種營養元素還需要的攝入量。那麼,如果 \(\tt{arr_i} \le 0\) 的話,就表示該種營養元素的攝入量已經達到了 “健康飲食” 的所需標準了。按照題意模擬一下即可,最後遍歷一整個陣列判斷是否有無法滿足的元素。換句話說,只要有任意的 \(\forall i\),滿足 \(\tt{arr_i} > 0\) 就需要輸出 No

本題的 C++ 程式碼如下:

#include <iostream>
using namespace std;

int n, m;
long long arr[1005];

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> m;
    for (int i=1; i<=m; i++) 
        cin >> arr[i];
    for (int i=1; i<=n; i++){
        for (int j=1; j<=m; j++){
            int t; cin >> t;
            arr[j] -= t;
        }
    }
    for (int i=1; i<=m; i++){
        if (arr[i] > 0){
            cout << "No" << endl;
            return 0;
        }
    }
    cout << "Yes" << endl;
    return 0;
}

本題的 Python 程式碼如下:

n, m = map(int, input().split())

arr = list(map(int, input().split()))

for _ in range(n):
    t = list(map(int, input().split()))
    for j in range(m):
        arr[j] -= t[j]

if any(x > 0 for x in arr):
    print("No")
else:
    print("Yes")

T3 - ^_^ 還是 😦

題目連結跳轉:點選跳轉

一道簡單的思維題目,難度定在【普及-】還算是合理的。不過 USACO 的 Bronze 組別特別喜歡考這種類似的思維題目。

普通演算法

考慮採用貪心的思路,先把序列按照從大到小的原則排序。暴力列舉一個節點 \(i\),判斷是否有可能滿足選擇前 \(i\) 個數字 \(-1\),剩下的數字都至少 \(+1\) 的情況下所有的數字都大於零。

那麼該如何快速的判斷是否所有的數字都大於零呢?首先可以肯定的是,後 \(n - i\) 個數字一定是大於零的,因為這些數字只會增加不會減少。所以我們把重點放在前 \(i\) 個數字上面。由於陣列已經是有序的,因此如果第 \(i\) 個數字是大於 \(1\) 的,那麼前 \(i\) 個數字在減去 \(1\) 之後也一定是正整數。

由於使用了排序演算法,本演算法的單次查詢時間複雜度在 \(O(N \log_2 N)\) 級別,總時間複雜度為 \(O(N^2 \log_2 N)\),可以在 \(\tt{1s}\) 內透過所有的測試點。

本題的 C++ 程式碼如下:

#include <iostream>
#include <unordered_map>
#include <algorithm>
#include <cmath>
using namespace std;

int n;
int arr[1005];

void solve(){
    cin >> n;
    long long sum = 0;
    for (int i=1, t; i<=n; i++){
        cin >> arr[i];
    }
    sort(arr+1, arr+1+n, greater<int>());
    if (n == 1) {
    	cout << ":-(" << endl;
        return ;
    }
    // 暴力列舉,選擇前 i 個數字 - 1,剩下的所有數字都至少 + 1。
    bool flag = 0;
    for (int i=1; i<=n; i++){
        sum += arr[i];
        if (arr[i] == 1) break;
        if (sum - (n - i) >= i) flag = 1;
    }
    cout << (flag ? "^_^" : ":-(") << endl;
    return ;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    int T; cin >> T;
    while(T--) solve();
    return 0;
}

本題的 Python 程式碼如下:

def solve():
    n = int(input())
    arr = list(map(int, input().split()))
    
    # 對陣列降序排序
    arr.sort(reverse=True)
    
    if n == 1:
        print(":-(")
        return

    # 暴力列舉前 i 個數字 - 1,剩下的數字 +1
    sum_ = 0
    flag = False
    for i in range(1, n + 1):
        sum_ += arr[i - 1]
        if arr[i - 1] == 1:
            break
        if sum_ - (n - i) >= i:
            flag = True
    
    print("^_^" if flag else ":-(")

def main():
    T = int(input())
    for _ in range(T):
        solve()

if __name__ == "__main__":
    main()

二分答案最佳化

注意到答案是單調的,因此可以使用二分答案的演算法來將演算法的單次查詢複雜度降低到 \(O(\log_2 N)\) 級別,因此該演算法的總時間複雜度為 \(O(N \log_2 N)\)

最佳化後的 C++ 程式碼如下:

#include <iostream>
#include <algorithm>
using namespace std;

int n;
int arr[1005];

void solve() {
    cin >> n;
    for (int i = 1; i <= n; i++) 
        cin >> arr[i];
    
    sort(arr + 1, arr + 1 + n, greater<int>());
    
    if (n == 1) {
        cout << ":-(" << endl;
        return;
    }

    int left = 1, right = n, res = -1;
    while (left <= right) {
        int mid = (left + right) / 2;
        if (arr[mid] > 1) {
            res = mid;
            right = mid - 1;
        } else {
            left = mid + 1;
        }
    }

    cout << (res != -1 ? "^_^" : ":-(") << endl;
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    int T; cin >> T;
    while (T--) solve();
    return 0;
}

最佳化後的 Python 演算法如下:

def solve():
    n = int(input())
    arr = list(map(int, input().split()))
    
    # 對陣列降序排序
    arr.sort(reverse=True)

    if n == 1:
        print(":-(")
        return

    left, right, res = 0, n - 1, -1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] > 1:
            res = mid
            right = mid - 1
        else:
            left = mid + 1

    print("^_^" if res != -1 else ":-(")

def main():
    T = int(input())
    for _ in range(T):
        solve()

if __name__ == "__main__":
    main()

T4 - Azusa的計劃

題目連結跳轉:點選跳轉

這道題的難度也不是很高,稍微思考一下即可。

任何事件時間 \(t\)\((a + b)\) 取模後,事件可以對映到一個固定的週期內。這樣,問題就轉化為一個固定長度的區間檢查問題。

因此,在讀入數字後,將所有的數字對 \((a + b)\) 取模並排序,如果數字分佈(序列的最大值和最小值的差值天數)在 \(a\) 範圍內即可滿足將所有的日程安排在休息日當中。但需要注意的是,兩個日期的差值天數不能單純地使用數字相減的方法求得。以正常 \(7\) 天為一週作為範例,週一和週日的日期差值為 \(1\) 天,而不是 \(7 - 1 = 6\) 天。這也是本題最難的部分。

如果做過 區間 DP 的使用者應該能非常快速地想到如果資料是一個 “環狀” 的情況下該如何解決問題(參考題目:石子合併(標準版))。我們可以使用 “剖環成鏈” 的方法,將環中的元素複製一遍並將每個數字增加 \((a + b)\),拼接在原陣列的末尾,這樣一個長度為 \(n\) 的環就被擴充套件為一個長度為 \(2n\) 的線性陣列。

最後只需要遍歷這個陣列內所有長度為 \(n\) 的區間 \([i, n + i - 1]\),判斷是否有任意一個區間的最大值和最小值的差在 \(a\) 以內即可判斷是否可以講所有的日程安排都分不在休息日中。

本題的時間複雜度為 \(O(N \log_2 N)\)

本題的 C++ 程式碼如下:

#include <iostream>
#include <algorithm>
using namespace std;

int n, a, b;
int arr[500005];
int maximum, minimum = 0x7f7f7f7f;

int main(){
	ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> a >> b;
    for (int i=1; i<=n; i++){
    	cin >> arr[i];
        arr[i] %= (a + b);
    }
    sort(arr+1, arr+1+n);
    for (int i=1; i<=n; i++){
    	arr[i+n] = arr[i] + (a + b);
    }
    bool flag = 0;
    for (int i=1; i+n-1<=2*n; i++) {
        if (arr[i+n-1] - arr[i] < a)
            flag = 1;
    }
    cout << (flag ? "Yes" : "No") << endl;
}

本題的 Python 程式碼如下:

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    n, a, b = map(int, data[:3])
    arr = list(map(int, data[3:]))
    
    mod_value = a + b
    arr = [x % mod_value for x in arr]
    
    arr.sort()
    
    arr += [x + mod_value for x in arr]

    flag = False
    for i in range(n):
        if arr[i + n - 1] - arr[i] < a:
            flag = True
            break

    print("Yes" if flag else "No")

if __name__ == "__main__":
    main()

T5 - 字首和問題

題目連結跳轉:點選跳轉

我個人認為這道題比最後一道題要難,也許是因為這類題目做的比較少的原因,看到題目後不知道從哪下手。

使用分類討論的方法,設定一個閾值 \(S\),考慮暴力列舉所有 \(b > S\) 的情況,並離線最佳化 \(b \le S\) 的情況。將 \(S\) 設定為 \(\sqrt{N}\),則有:

  1. 對於大步長 \(b > S\),任意一次查詢只需要最多遍歷 \(550\)(即 \(\sqrt{N}\))次就可以算出答案,因此暴力列舉這部分。
  2. 對於小步長 \(b \le S\),按 \(b\) 分組批次離線查詢。

對於大步長部分,每一次查詢的時間複雜度為 \(O(\sqrt{N})\),在最壞情況下總時間複雜度為 \(O(N \times \sqrt{N})\)。對於小步長的部分,每一次查詢的時間複雜度約為 \(O(n)\),在最壞情況下的時間複雜度為 \(O(N\times \sqrt{N})\),因此本題在最壞情況下的漸進時間複雜度為:

\[\large{O(N \times \sqrt{N})} \]

最後,本題的 C++ 程式碼如下:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

struct Query {
    int id;  
    int a, b;   
};

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    
    int n; cin >> n;
    
    vector<LL> a_arr(n + 1, 0);
    for(int i =1; i <=n; i++) cin >> a_arr[i];
    
    int q; cin >> q;
    
    vector<Query> queries(q);
    for(int i =0; i < q; i++){
        cin >> queries[i].a >> queries[i].b;
        queries[i].id = i;
    }
    
    int S = 550;
    
    // 分組查詢:小步長和大步長
    // 對於小步長 b <= S,按 b 分組
    // 對於大步長 b > S,單獨儲存
    vector<vector<pair<int, int>>> small_b_queries(S +1, vector<pair<int, int>>()); // small_b_queries[b]儲存 (a, id)
    vector<pair<int, int>> large_b_queries; // 儲存 (a, id) for b > S
    
    for(int i =0; i < q; i++) {
        if(queries[i].b <= S)
            small_b_queries[queries[i].b].emplace_back(queries[i].a, queries[i].id);
        else
            large_b_queries.emplace_back(make_pair(queries[i].a, queries[i].id));
    }
    
    vector<LL> res(q, 0);
    
    // 預處理小步長查詢
    // 對每個 b =1 to S
    for(int b =1; b <= S; b++){
        if(small_b_queries[b].empty()) continue;
        
        // 建立一個臨時陣列 s_arr,用於儲存當前步長 b 的累加和
        // 從 n downto 1
        // s_arr[a] = a_arr[a] + s_arr[a + b] (如果 a + b <=n)
        // 否則 s_arr[a] = a[a]
        vector<LL> s_arr(n + 5, 0);
        for(int a = n; a >=1; a--){
            if(a + b <= n){
                s_arr[a] = a_arr[a] + s_arr[a + b];
            }
            else{
                s_arr[a] = a_arr[a];
            }
        }
        
        // 回答所有步長為 b 的查詢
        for(auto &[a, id] : small_b_queries[b]){
            res[id] = s_arr[a];
        }
    }
    
    // 處理大步長查詢
    // 由於 b > S,且 S = 550,所以每個查詢最多需要 ~550 次操作
    for(auto &[a, id] : large_b_queries){
        LL sum = 0;
        int current = a;
        while(current <= n){
            sum += a_arr[current];
            current += queries[id].b;
        }
        res[id] = sum;
    }
    
    for(int i =0; i < q; i++) 
        cout << res[i] << "\n";
    
    return 0;
}

本題的 Python 程式碼如下(不保證可以透過所有的測試點):

class Query:
    def __init__(self, id, a, b):
        self.id = id
        self.a = a
        self.b = b

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    n = int(data[0])
    a_arr = [0] * (n + 1)
    for i in range(1, n + 1):
        a_arr[i] = int(data[i])
    
    q = int(data[n + 1])
    queries = []
    idx = n + 2
    for i in range(q):
        a, b = int(data[idx]), int(data[idx + 1])
        queries.append(Query(i, a, b))
        idx += 2
    
    S = 550
    small_b_queries = [[] for _ in range(S + 1)]
    large_b_queries = []
    
    for query in queries:
        if query.b <= S:
            small_b_queries[query.b].append((query.a, query.id))
        else:
            large_b_queries.append((query.a, query.id))
    
    res = [0] * q
    
    for b in range(1, S + 1):
        if not small_b_queries[b]:
            continue
        
        s_arr = [0] * (n + 5)
        for a in range(n, 0, -1):
            if a + b <= n:
                s_arr[a] = a_arr[a] + s_arr[a + b]
            else:
                s_arr[a] = a_arr[a]
        
        for a, id in small_b_queries[b]:
            res[id] = s_arr[a]
    
    for a, id in large_b_queries:
        sum_val = 0
        current = a
        b = queries[id].b
        while current <= n:
            sum_val += a_arr[current]
            current += b
        res[id] = sum_val
    
    sys.stdout.write("\n".join(map(str, res)) + "\n")

if __name__ == "__main__":
    main()

T6 - 劃分割槽間

題目連結跳轉:點選跳轉

一道線段樹最佳化動態規劃的題目,難度趨近於 CSP 提高組的題目和 USACO 鉑金組的中等題。一眼可以看出題目是一個典型的動態規劃問題,但奈何資料量太大了,\(O(N^2)\) 的複雜度肯定會 TLE。但無論如何都是 “車到山前必有路”,看到資料範圍不用怕,先打一個暴力的動態規劃再最佳化。

按照一位 OI 大神的說法:“所有的動態規劃最佳化都是在基礎的程式碼上等量代換”。

與打家劫舍等線性動態規劃類似,對於本題而言,設狀態的定義為 \(dp_i\) 表示對 \([1, i]\) 這個序列劃分後可得到的最大貢獻。透過暴力遍歷 \(j, (1 \le j < i)\),表示將 \((j, i]\) 歸位一組。另設 \(A(j, i)\) 為區間 \((j, i]\) 的貢獻值。根據以上資訊可以得到狀態轉移方程:

\[\large{dp_i = \max_{0 \le j<i}{(dp_j + A(j, i))}} \]

接下來就是關於 \(A(j, i)\) 的計算了。設字首和陣列 \(S_i\) 表示從區間 \([1, i]\) 的和,那麼 \((j, i]\) 區間的和可以被表示為 \(S[i] - S[j]\)。根據不同的 \(S[i] - S[j]\),則有以下三種情況:

  1. \(S[i] - S[j] > 0\) 時,證明該區間的和是正數,貢獻為 \(i - j\)
  2. \(S[i] - S[j] = 0\) 時,該區間的和為零,貢獻為 \(0\)
  3. \(S[i] - S[j] < 0\) 時,證明該區間的和是負數,貢獻為 \(- (i - j) = j - i\)

綜上所述,可以寫出一個暴力版本的動態規劃程式碼:

#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
using namespace std;

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    int n;
    cin >> n;
    vector<int> A(n + 1);
    vector<long long> S(n + 1, 0); 

    for (int i = 1; i <= n; i++) {
        cin >> A[i];
        S[i] = S[i - 1] + A[i];
    }

    vector<long long> dp(n + 1, LLONG_MIN);
    dp[0] = 0;

    for (int i = 1; i <= n; i++) {
        for (int j = 0; j < i; j++) {
            if (S[i] - S[j] > 0)
                dp[i] = max(dp[i], dp[j] + (i - j));
            if (S[i] - S[j] < 0)
                dp[i] = max(dp[i], dp[j] - (i - j));
            if (S[i] - S[j] == 0)
                dp[i] = max(dp[i], dp[j]);
        }
    }

    cout << dp[n] << endl;
    return 0;
}

接下來考慮最佳化這個動態規劃,注意到每一次尋找 \(\tt{max}\) 都非常耗時,每一次都需要遍歷一遍才能求出最大值。有沒有一種方法可以快速求出某一個區間的最大值呢?答案就是線段樹。線段樹是一個非常好的快速求解區間最值問題的資料結構。

更多有關區間最值問題的學習請參考:[# 淺入線段樹與區間最值問題](# 淺入線段樹與區間最值問題)

綜上,我們可以透過構建線段樹來快速求得答案。簡化三種情況可得:

if (S[i] - S[j] > 0)
    dp[i] = max(dp[i], dp[j] - j + i);
if (S[i] - S[j] < 0)
    dp[i] = max(dp[i], dp[j] + j - i));
if (S[i] - S[j] == 0)
    dp[i] = max(dp[i], dp[j]);

因此我們構造三棵線段樹,分別來維護這三個區間:

  1. \(\max_{0\le j < i} dp_j\)
  2. \(\max_{0\le j < i} (dp_j - j)\)
  3. \(\max_{0\le j < i} (dp_j + j)\)

然而我們的線段樹不能僅僅維護這個區間,因為這三個的最大值還被 \(A(j, i)\) 的三種狀態所限制著,因此,我們需要找的是滿足 \(S_i - S_j\) 在特定條件下的最大值。這樣就出現了另一個嚴重的問題,\(S_i\) 的值可能非常的大,因此我們需要對字首和陣列離散化一下(座標壓縮:類似於權值線段樹的寫法)才可以防止記憶體超限。

這樣子對於每次尋找最大值,都可以在 \(O(\log_2N)\) 的情況下找到。本演算法的總時間複雜度也控制在了 \(O(N \times \log_2N)\) 級別。

本題的 C++ 程式碼如下:

#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
#define int long long
using namespace std;

const int MAX = 500005;

struct SegmentTree {
    int size;
    vector<int> tree;

    SegmentTree(int n_) {
        size = 1;
        while (size < n_) size <<=1;
        tree.assign(2*size, LLONG_MIN);
    }

    void update(int pos, int value){
        pos += size -1;
        tree[pos] = max(tree[pos], value);
        while(pos >1){
            pos >>=1;
            tree[pos] = max(tree[2*pos], tree[2*pos+1]);
        }
    }

    int query(int l, int r){
        l += size -1; r += size -1;
        int res = LLONG_MIN;
        while(l <= r){
            if(l%2 ==1)
                res = max(res, tree[l++]);
            if(r%2 ==0)
                res = max(res, tree[r--]);
            l >>=1; r >>=1;
        }
        return res;
    }
};

int n;
int A[MAX], S[MAX], aintS_arr[MAX];

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    
    cin >> n;
    for(int i=1;i<=n;i++) cin >> A[i];
    
    S[0] = 0;
    for(int i=1;i<=n;i++) S[i] = S[i-1] + A[i];
    
    for(int i=0;i<=n;i++) aintS_arr[i] = S[i];
    sort(aintS_arr, aintS_arr + n +1);
    int m = unique(aintS_arr, aintS_arr + n +1) - aintS_arr;
    
    auto get_idx = [&](int x) -> int {
        return lower_bound(aintS_arr, aintS_arr + m, x) - aintS_arr +1;
    };
    
    SegmentTree BIT1(m); // max(dp[j]-j)
    SegmentTree BIT2(m); // max(dp[j])
    SegmentTree BIT3(m); // max(dp[j]+j)
    
    int idx_S0 = get_idx(S[0]);
    BIT1.update(idx_S0, 0);
    BIT2.update(idx_S0, 0);
    BIT3.update(idx_S0, 0);
    
    int dp_i = LLONG_MIN;
    for(int i=1;i<=n;i++){
        int Si = S[i];
        int idx_Si = get_idx(Si);
        
        int option1 = LLONG_MIN;
        if(idx_Si >1){
            int temp = BIT1.query(1, idx_Si -1);
            if(temp != LLONG_MIN){
                option1 = temp + i;
            }
        }
        
        int option2 = BIT2.query(idx_Si, idx_Si);
        
        int option3 = LLONG_MIN;
        if(idx_Si < m){
            int temp = BIT3.query(idx_Si +1, m);
            if(temp != LLONG_MIN){
                option3 = temp - i;
            }
        }
        
        dp_i = max(option1, max(option2, option3));
        
        BIT1.update(idx_Si, dp_i - i);
        BIT2.update(idx_Si, dp_i);
        BIT3.update(idx_Si, dp_i + i);
    }
    
    cout << dp_i;
}

本題的 Python 程式碼如下(由於 Python 常數過大,因此沒有辦法透過這道題所有的測試點,但是程式碼的正確性沒有問題):

class SegmentTree:
    def __init__(self, n):
        self.size = 1
        while self.size < n:
            self.size *= 2
        self.tree = [float('-inf')] * (2 * self.size)

    def update(self, pos, value):
        pos += self.size - 1
        self.tree[pos] = max(self.tree[pos], value)
        while pos > 1:
            pos //= 2
            self.tree[pos] = max(self.tree[2 * pos], self.tree[2 * pos + 1])

    def query(self, l, r):
        l += self.size - 1
        r += self.size - 1
        res = float('-inf')
        while l <= r:
            if l % 2 == 1:
                res = max(res, self.tree[l])
                l += 1
            if r % 2 == 0:
                res = max(res, self.tree[r])
                r -= 1
            l //= 2
            r //= 2
        return res


def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    n = int(data[0])
    A = list(map(int, data[1:n + 1]))
    
    S = [0] * (n + 1)
    for i in range(1, n + 1):
        S[i] = S[i - 1] + A[i - 1]
    
    aintS_arr = S[:]
    aintS_arr.sort()
    m = len(set(aintS_arr))
    aintS_arr = sorted(set(aintS_arr))
    
    def get_idx(x):
        # Return the index in the compressed array
        return aintS_arr.index(x) + 1
    
    BIT1 = SegmentTree(m)  # max(dp[j] - j)
    BIT2 = SegmentTree(m)  # max(dp[j])
    BIT3 = SegmentTree(m)  # max(dp[j] + j)
    
    idx_S0 = get_idx(S[0])
    BIT1.update(idx_S0, 0)
    BIT2.update(idx_S0, 0)
    BIT3.update(idx_S0, 0)
    
    dp_i = float('-inf')
    for i in range(1, n + 1):
        Si = S[i]
        idx_Si = get_idx(Si)
        
        option1 = float('-inf')
        if idx_Si > 1:
            temp = BIT1.query(1, idx_Si - 1)
            if temp != float('-inf'):
                option1 = temp + i
        
        option2 = BIT2.query(idx_Si, idx_Si)
        
        option3 = float('-inf')
        if idx_Si < m:
            temp = BIT3.query(idx_Si + 1, m)
            if temp != float('-inf'):
                option3 = temp - i
        
        dp_i = max(option1, option2, option3)
        
        BIT1.update(idx_Si, dp_i - i)
        BIT2.update(idx_Si, dp_i)
        BIT3.update(idx_Si, dp_i + i)
    
    print(dp_i)

if __name__ == "__main__":
    main()

ACGO 巔峰賽#15 - 題目解析

間隔四個月再戰 ACGO Rated,鑑於最近學業繁忙,比賽打地都不是很頻繁。雖然這次沒有 AK 排位賽(我可以說是因為週末太忙,沒有充足的時間思考題目…(好吧,其實也許是因為我把 T5 給想複雜了))。

本文依舊提供每道題的完整解析(因為我在賽後把題目做出來了)。

T1 - 高塔

題目連結跳轉:點選跳轉

插一句題外話,這道題的題目編號挺有趣的。

沒有什麼特別難的點,迴圈讀入每一個數字,讀入後跟第一個輸入的數字比較大小,如果讀入的數字比第一個讀入的數字要大(即 \(a_i > a_1\)),直接輸出 \(i\) 並結束主程式即可。

本題的 C++ 程式碼如下:

#include <iostream>
using namespace std;

int n, arr[105];

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n;
    for (int i=1; i<=n; i++){
        cin >> arr[i];
        if (arr[i] > arr[1]){
            cout << i << endl;
            return 0;
        }
    }
    cout << -1 << endl;
    return 0;
}

本題的 Python 程式碼如下:

n = int(input())
arr = list(map(int, input().split()))

for i in range(1, n + 1):
    if arr[i - 1] > arr[0]:
        print(i)
        break
else:
    print(-1)

T2 - 營養均衡

題目連結跳轉:點選跳轉

也是一道入門題目,沒有什麼比較難的地方,重點是把題目讀清楚了。

我們設定一個陣列 \(\tt{arr}\),其中 \(\tt{arr_i}\) 表示種營養元素還需要的攝入量。那麼,如果 \(\tt{arr_i} \le 0\) 的話,就表示該種營養元素的攝入量已經達到了 “健康飲食” 的所需標準了。按照題意模擬一下即可,最後遍歷一整個陣列判斷是否有無法滿足的元素。換句話說,只要有任意的 \(\forall i\),滿足 \(\tt{arr_i} > 0\) 就需要輸出 No

本題的 C++ 程式碼如下:

#include <iostream>
using namespace std;

int n, m;
long long arr[1005];

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> m;
    for (int i=1; i<=m; i++) 
        cin >> arr[i];
    for (int i=1; i<=n; i++){
        for (int j=1; j<=m; j++){
            int t; cin >> t;
            arr[j] -= t;
        }
    }
    for (int i=1; i<=m; i++){
        if (arr[i] > 0){
            cout << "No" << endl;
            return 0;
        }
    }
    cout << "Yes" << endl;
    return 0;
}

本題的 Python 程式碼如下:

n, m = map(int, input().split())

arr = list(map(int, input().split()))

for _ in range(n):
    t = list(map(int, input().split()))
    for j in range(m):
        arr[j] -= t[j]

if any(x > 0 for x in arr):
    print("No")
else:
    print("Yes")

T3 - ^_^ 還是 😦

題目連結跳轉:點選跳轉

一道簡單的思維題目,難度定在【普及-】還算是合理的。不過 USACO 的 Bronze 組別特別喜歡考這種類似的思維題目。

普通演算法

考慮採用貪心的思路,先把序列按照從大到小的原則排序。暴力列舉一個節點 \(i\),判斷是否有可能滿足選擇前 \(i\) 個數字 \(-1\),剩下的數字都至少 \(+1\) 的情況下所有的數字都大於零。

那麼該如何快速的判斷是否所有的數字都大於零呢?首先可以肯定的是,後 \(n - i\) 個數字一定是大於零的,因為這些數字只會增加不會減少。所以我們把重點放在前 \(i\) 個數字上面。由於陣列已經是有序的,因此如果第 \(i\) 個數字是大於 \(1\) 的,那麼前 \(i\) 個數字在減去 \(1\) 之後也一定是正整數。

由於使用了排序演算法,本演算法的單次查詢時間複雜度在 \(O(N \log_2 N)\) 級別,總時間複雜度為 \(O(N^2 \log_2 N)\),可以在 \(\tt{1s}\) 內透過所有的測試點。

本題的 C++ 程式碼如下:

#include <iostream>
#include <unordered_map>
#include <algorithm>
#include <cmath>
using namespace std;

int n;
int arr[1005];

void solve(){
    cin >> n;
    long long sum = 0;
    for (int i=1, t; i<=n; i++){
        cin >> arr[i];
    }
    sort(arr+1, arr+1+n, greater<int>());
    if (n == 1) {
    	cout << ":-(" << endl;
        return ;
    }
    // 暴力列舉,選擇前 i 個數字 - 1,剩下的所有數字都至少 + 1。
    bool flag = 0;
    for (int i=1; i<=n; i++){
        sum += arr[i];
        if (arr[i] == 1) break;
        if (sum - (n - i) >= i) flag = 1;
    }
    cout << (flag ? "^_^" : ":-(") << endl;
    return ;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    int T; cin >> T;
    while(T--) solve();
    return 0;
}

本題的 Python 程式碼如下:

def solve():
    n = int(input())
    arr = list(map(int, input().split()))
    
    # 對陣列降序排序
    arr.sort(reverse=True)
    
    if n == 1:
        print(":-(")
        return

    # 暴力列舉前 i 個數字 - 1,剩下的數字 +1
    sum_ = 0
    flag = False
    for i in range(1, n + 1):
        sum_ += arr[i - 1]
        if arr[i - 1] == 1:
            break
        if sum_ - (n - i) >= i:
            flag = True
    
    print("^_^" if flag else ":-(")

def main():
    T = int(input())
    for _ in range(T):
        solve()

if __name__ == "__main__":
    main()

二分答案最佳化

注意到答案是單調的,因此可以使用二分答案的演算法來將演算法的單次查詢複雜度降低到 \(O(\log_2 N)\) 級別,因此該演算法的總時間複雜度為 \(O(N \log_2 N)\)

最佳化後的 C++ 程式碼如下:

#include <iostream>
#include <algorithm>
using namespace std;

int n;
int arr[1005];

void solve() {
    cin >> n;
    for (int i = 1; i <= n; i++) 
        cin >> arr[i];
    
    sort(arr + 1, arr + 1 + n, greater<int>());
    
    if (n == 1) {
        cout << ":-(" << endl;
        return;
    }

    int left = 1, right = n, res = -1;
    while (left <= right) {
        int mid = (left + right) / 2;
        if (arr[mid] > 1) {
            res = mid;
            right = mid - 1;
        } else {
            left = mid + 1;
        }
    }

    cout << (res != -1 ? "^_^" : ":-(") << endl;
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    int T; cin >> T;
    while (T--) solve();
    return 0;
}

最佳化後的 Python 演算法如下:

def solve():
    n = int(input())
    arr = list(map(int, input().split()))
    
    # 對陣列降序排序
    arr.sort(reverse=True)

    if n == 1:
        print(":-(")
        return

    left, right, res = 0, n - 1, -1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] > 1:
            res = mid
            right = mid - 1
        else:
            left = mid + 1

    print("^_^" if res != -1 else ":-(")

def main():
    T = int(input())
    for _ in range(T):
        solve()

if __name__ == "__main__":
    main()

T4 - Azusa的計劃

題目連結跳轉:點選跳轉

這道題的難度也不是很高,稍微思考一下即可。

任何事件時間 \(t\)\((a + b)\) 取模後,事件可以對映到一個固定的週期內。這樣,問題就轉化為一個固定長度的區間檢查問題。

因此,在讀入數字後,將所有的數字對 \((a + b)\) 取模並排序,如果數字分佈(序列的最大值和最小值的差值天數)在 \(a\) 範圍內即可滿足將所有的日程安排在休息日當中。但需要注意的是,兩個日期的差值天數不能單純地使用數字相減的方法求得。以正常 \(7\) 天為一週作為範例,週一和週日的日期差值為 \(1\) 天,而不是 \(7 - 1 = 6\) 天。這也是本題最難的部分。

如果做過 區間 DP 的使用者應該能非常快速地想到如果資料是一個 “環狀” 的情況下該如何解決問題(參考題目:石子合併(標準版))。我們可以使用 “剖環成鏈” 的方法,將環中的元素複製一遍並將每個數字增加 \((a + b)\),拼接在原陣列的末尾,這樣一個長度為 \(n\) 的環就被擴充套件為一個長度為 \(2n\) 的線性陣列。

最後只需要遍歷這個陣列內所有長度為 \(n\) 的區間 \([i, n + i - 1]\),判斷是否有任意一個區間的最大值和最小值的差在 \(a\) 以內即可判斷是否可以講所有的日程安排都分不在休息日中。

本題的時間複雜度為 \(O(N \log_2 N)\)

本題的 C++ 程式碼如下:

#include <iostream>
#include <algorithm>
using namespace std;

int n, a, b;
int arr[500005];
int maximum, minimum = 0x7f7f7f7f;

int main(){
	ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> a >> b;
    for (int i=1; i<=n; i++){
    	cin >> arr[i];
        arr[i] %= (a + b);
    }
    sort(arr+1, arr+1+n);
    for (int i=1; i<=n; i++){
    	arr[i+n] = arr[i] + (a + b);
    }
    bool flag = 0;
    for (int i=1; i+n-1<=2*n; i++) {
        if (arr[i+n-1] - arr[i] < a)
            flag = 1;
    }
    cout << (flag ? "Yes" : "No") << endl;
}

本題的 Python 程式碼如下:

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    n, a, b = map(int, data[:3])
    arr = list(map(int, data[3:]))
    
    mod_value = a + b
    arr = [x % mod_value for x in arr]
    
    arr.sort()
    
    arr += [x + mod_value for x in arr]

    flag = False
    for i in range(n):
        if arr[i + n - 1] - arr[i] < a:
            flag = True
            break

    print("Yes" if flag else "No")

if __name__ == "__main__":
    main()

T5 - 字首和問題

題目連結跳轉:點選跳轉

我個人認為這道題比最後一道題要難,也許是因為這類題目做的比較少的原因,看到題目後不知道從哪下手。

使用分類討論的方法,設定一個閾值 \(S\),考慮暴力列舉所有 \(b > S\) 的情況,並離線最佳化 \(b \le S\) 的情況。將 \(S\) 設定為 \(\sqrt{N}\),則有:

  1. 對於大步長 \(b > S\),任意一次查詢只需要最多遍歷 \(550\)(即 \(\sqrt{N}\))次就可以算出答案,因此暴力列舉這部分。
  2. 對於小步長 \(b \le S\),按 \(b\) 分組批次離線查詢。

對於大步長部分,每一次查詢的時間複雜度為 \(O(\sqrt{N})\),在最壞情況下總時間複雜度為 \(O(N \times \sqrt{N})\)。對於小步長的部分,每一次查詢的時間複雜度約為 \(O(n)\),在最壞情況下的時間複雜度為 \(O(N\times \sqrt{N})\),因此本題在最壞情況下的漸進時間複雜度為:

\[\large{O(N \times \sqrt{N})} \]

最後,本題的 C++ 程式碼如下:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

struct Query {
    int id;  
    int a, b;   
};

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    
    int n; cin >> n;
    
    vector<LL> a_arr(n + 1, 0);
    for(int i =1; i <=n; i++) cin >> a_arr[i];
    
    int q; cin >> q;
    
    vector<Query> queries(q);
    for(int i =0; i < q; i++){
        cin >> queries[i].a >> queries[i].b;
        queries[i].id = i;
    }
    
    int S = 550;
    
    // 分組查詢:小步長和大步長
    // 對於小步長 b <= S,按 b 分組
    // 對於大步長 b > S,單獨儲存
    vector<vector<pair<int, int>>> small_b_queries(S +1, vector<pair<int, int>>()); // small_b_queries[b]儲存 (a, id)
    vector<pair<int, int>> large_b_queries; // 儲存 (a, id) for b > S
    
    for(int i =0; i < q; i++) {
        if(queries[i].b <= S)
            small_b_queries[queries[i].b].emplace_back(queries[i].a, queries[i].id);
        else
            large_b_queries.emplace_back(make_pair(queries[i].a, queries[i].id));
    }
    
    vector<LL> res(q, 0);
    
    // 預處理小步長查詢
    // 對每個 b =1 to S
    for(int b =1; b <= S; b++){
        if(small_b_queries[b].empty()) continue;
        
        // 建立一個臨時陣列 s_arr,用於儲存當前步長 b 的累加和
        // 從 n downto 1
        // s_arr[a] = a_arr[a] + s_arr[a + b] (如果 a + b <=n)
        // 否則 s_arr[a] = a[a]
        vector<LL> s_arr(n + 5, 0);
        for(int a = n; a >=1; a--){
            if(a + b <= n){
                s_arr[a] = a_arr[a] + s_arr[a + b];
            }
            else{
                s_arr[a] = a_arr[a];
            }
        }
        
        // 回答所有步長為 b 的查詢
        for(auto &[a, id] : small_b_queries[b]){
            res[id] = s_arr[a];
        }
    }
    
    // 處理大步長查詢
    // 由於 b > S,且 S = 550,所以每個查詢最多需要 ~550 次操作
    for(auto &[a, id] : large_b_queries){
        LL sum = 0;
        int current = a;
        while(current <= n){
            sum += a_arr[current];
            current += queries[id].b;
        }
        res[id] = sum;
    }
    
    for(int i =0; i < q; i++) 
        cout << res[i] << "\n";
    
    return 0;
}

本題的 Python 程式碼如下(不保證可以透過所有的測試點):

class Query:
    def __init__(self, id, a, b):
        self.id = id
        self.a = a
        self.b = b

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    n = int(data[0])
    a_arr = [0] * (n + 1)
    for i in range(1, n + 1):
        a_arr[i] = int(data[i])
    
    q = int(data[n + 1])
    queries = []
    idx = n + 2
    for i in range(q):
        a, b = int(data[idx]), int(data[idx + 1])
        queries.append(Query(i, a, b))
        idx += 2
    
    S = 550
    small_b_queries = [[] for _ in range(S + 1)]
    large_b_queries = []
    
    for query in queries:
        if query.b <= S:
            small_b_queries[query.b].append((query.a, query.id))
        else:
            large_b_queries.append((query.a, query.id))
    
    res = [0] * q
    
    for b in range(1, S + 1):
        if not small_b_queries[b]:
            continue
        
        s_arr = [0] * (n + 5)
        for a in range(n, 0, -1):
            if a + b <= n:
                s_arr[a] = a_arr[a] + s_arr[a + b]
            else:
                s_arr[a] = a_arr[a]
        
        for a, id in small_b_queries[b]:
            res[id] = s_arr[a]
    
    for a, id in large_b_queries:
        sum_val = 0
        current = a
        b = queries[id].b
        while current <= n:
            sum_val += a_arr[current]
            current += b
        res[id] = sum_val
    
    sys.stdout.write("\n".join(map(str, res)) + "\n")

if __name__ == "__main__":
    main()

T6 - 劃分割槽間

題目連結跳轉:點選跳轉

一道線段樹最佳化動態規劃的題目,難度趨近於 CSP 提高組的題目和 USACO 鉑金組的中等題。一眼可以看出題目是一個典型的動態規劃問題,但奈何資料量太大了,\(O(N^2)\) 的複雜度肯定會 TLE。但無論如何都是 “車到山前必有路”,看到資料範圍不用怕,先打一個暴力的動態規劃再最佳化。

按照一位 OI 大神的說法:“所有的動態規劃最佳化都是在基礎的程式碼上等量代換”。

與打家劫舍等線性動態規劃類似,對於本題而言,設狀態的定義為 \(dp_i\) 表示對 \([1, i]\) 這個序列劃分後可得到的最大貢獻。透過暴力遍歷 \(j, (1 \le j < i)\),表示將 \((j, i]\) 歸位一組。另設 \(A(j, i)\) 為區間 \((j, i]\) 的貢獻值。根據以上資訊可以得到狀態轉移方程:

\[\large{dp_i = \max_{0 \le j<i}{(dp_j + A(j, i))}} \]

接下來就是關於 \(A(j, i)\) 的計算了。設字首和陣列 \(S_i\) 表示從區間 \([1, i]\) 的和,那麼 \((j, i]\) 區間的和可以被表示為 \(S[i] - S[j]\)。根據不同的 \(S[i] - S[j]\),則有以下三種情況:

  1. \(S[i] - S[j] > 0\) 時,證明該區間的和是正數,貢獻為 \(i - j\)
  2. \(S[i] - S[j] = 0\) 時,該區間的和為零,貢獻為 \(0\)
  3. \(S[i] - S[j] < 0\) 時,證明該區間的和是負數,貢獻為 \(- (i - j) = j - i\)

綜上所述,可以寫出一個暴力版本的動態規劃程式碼:

#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
using namespace std;

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    int n;
    cin >> n;
    vector<int> A(n + 1);
    vector<long long> S(n + 1, 0); 

    for (int i = 1; i <= n; i++) {
        cin >> A[i];
        S[i] = S[i - 1] + A[i];
    }

    vector<long long> dp(n + 1, LLONG_MIN);
    dp[0] = 0;

    for (int i = 1; i <= n; i++) {
        for (int j = 0; j < i; j++) {
            if (S[i] - S[j] > 0)
                dp[i] = max(dp[i], dp[j] + (i - j));
            if (S[i] - S[j] < 0)
                dp[i] = max(dp[i], dp[j] - (i - j));
            if (S[i] - S[j] == 0)
                dp[i] = max(dp[i], dp[j]);
        }
    }

    cout << dp[n] << endl;
    return 0;
}

接下來考慮最佳化這個動態規劃,注意到每一次尋找 \(\tt{max}\) 都非常耗時,每一次都需要遍歷一遍才能求出最大值。有沒有一種方法可以快速求出某一個區間的最大值呢?答案就是線段樹。線段樹是一個非常好的快速求解區間最值問題的資料結構。

更多有關區間最值問題的學習請參考:[# 淺入線段樹與區間最值問題](# 淺入線段樹與區間最值問題)

綜上,我們可以透過構建線段樹來快速求得答案。簡化三種情況可得:

if (S[i] - S[j] > 0)
    dp[i] = max(dp[i], dp[j] - j + i);
if (S[i] - S[j] < 0)
    dp[i] = max(dp[i], dp[j] + j - i));
if (S[i] - S[j] == 0)
    dp[i] = max(dp[i], dp[j]);

因此我們構造三棵線段樹,分別來維護這三個區間:

  1. \(\max_{0\le j < i} dp_j\)
  2. \(\max_{0\le j < i} (dp_j - j)\)
  3. \(\max_{0\le j < i} (dp_j + j)\)

然而我們的線段樹不能僅僅維護這個區間,因為這三個的最大值還被 \(A(j, i)\) 的三種狀態所限制著,因此,我們需要找的是滿足 \(S_i - S_j\) 在特定條件下的最大值。這樣就出現了另一個嚴重的問題,\(S_i\) 的值可能非常的大,因此我們需要對字首和陣列離散化一下(座標壓縮:類似於權值線段樹的寫法)才可以防止記憶體超限。

這樣子對於每次尋找最大值,都可以在 \(O(\log_2N)\) 的情況下找到。本演算法的總時間複雜度也控制在了 \(O(N \times \log_2N)\) 級別。

本題的 C++ 程式碼如下:

#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
#define int long long
using namespace std;

const int MAX = 500005;

struct SegmentTree {
    int size;
    vector<int> tree;

    SegmentTree(int n_) {
        size = 1;
        while (size < n_) size <<=1;
        tree.assign(2*size, LLONG_MIN);
    }

    void update(int pos, int value){
        pos += size -1;
        tree[pos] = max(tree[pos], value);
        while(pos >1){
            pos >>=1;
            tree[pos] = max(tree[2*pos], tree[2*pos+1]);
        }
    }

    int query(int l, int r){
        l += size -1; r += size -1;
        int res = LLONG_MIN;
        while(l <= r){
            if(l%2 ==1)
                res = max(res, tree[l++]);
            if(r%2 ==0)
                res = max(res, tree[r--]);
            l >>=1; r >>=1;
        }
        return res;
    }
};

int n;
int A[MAX], S[MAX], aintS_arr[MAX];

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    
    cin >> n;
    for(int i=1;i<=n;i++) cin >> A[i];
    
    S[0] = 0;
    for(int i=1;i<=n;i++) S[i] = S[i-1] + A[i];
    
    for(int i=0;i<=n;i++) aintS_arr[i] = S[i];
    sort(aintS_arr, aintS_arr + n +1);
    int m = unique(aintS_arr, aintS_arr + n +1) - aintS_arr;
    
    auto get_idx = [&](int x) -> int {
        return lower_bound(aintS_arr, aintS_arr + m, x) - aintS_arr +1;
    };
    
    SegmentTree BIT1(m); // max(dp[j]-j)
    SegmentTree BIT2(m); // max(dp[j])
    SegmentTree BIT3(m); // max(dp[j]+j)
    
    int idx_S0 = get_idx(S[0]);
    BIT1.update(idx_S0, 0);
    BIT2.update(idx_S0, 0);
    BIT3.update(idx_S0, 0);
    
    int dp_i = LLONG_MIN;
    for(int i=1;i<=n;i++){
        int Si = S[i];
        int idx_Si = get_idx(Si);
        
        int option1 = LLONG_MIN;
        if(idx_Si >1){
            int temp = BIT1.query(1, idx_Si -1);
            if(temp != LLONG_MIN){
                option1 = temp + i;
            }
        }
        
        int option2 = BIT2.query(idx_Si, idx_Si);
        
        int option3 = LLONG_MIN;
        if(idx_Si < m){
            int temp = BIT3.query(idx_Si +1, m);
            if(temp != LLONG_MIN){
                option3 = temp - i;
            }
        }
        
        dp_i = max(option1, max(option2, option3));
        
        BIT1.update(idx_Si, dp_i - i);
        BIT2.update(idx_Si, dp_i);
        BIT3.update(idx_Si, dp_i + i);
    }
    
    cout << dp_i;
}

本題的 Python 程式碼如下(由於 Python 常數過大,因此沒有辦法透過這道題所有的測試點,但是程式碼的正確性沒有問題):

class SegmentTree:
    def __init__(self, n):
        self.size = 1
        while self.size < n:
            self.size *= 2
        self.tree = [float('-inf')] * (2 * self.size)

    def update(self, pos, value):
        pos += self.size - 1
        self.tree[pos] = max(self.tree[pos], value)
        while pos > 1:
            pos //= 2
            self.tree[pos] = max(self.tree[2 * pos], self.tree[2 * pos + 1])

    def query(self, l, r):
        l += self.size - 1
        r += self.size - 1
        res = float('-inf')
        while l <= r:
            if l % 2 == 1:
                res = max(res, self.tree[l])
                l += 1
            if r % 2 == 0:
                res = max(res, self.tree[r])
                r -= 1
            l //= 2
            r //= 2
        return res


def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    n = int(data[0])
    A = list(map(int, data[1:n + 1]))
    
    S = [0] * (n + 1)
    for i in range(1, n + 1):
        S[i] = S[i - 1] + A[i - 1]
    
    aintS_arr = S[:]
    aintS_arr.sort()
    m = len(set(aintS_arr))
    aintS_arr = sorted(set(aintS_arr))
    
    def get_idx(x):
        # Return the index in the compressed array
        return aintS_arr.index(x) + 1
    
    BIT1 = SegmentTree(m)  # max(dp[j] - j)
    BIT2 = SegmentTree(m)  # max(dp[j])
    BIT3 = SegmentTree(m)  # max(dp[j] + j)
    
    idx_S0 = get_idx(S[0])
    BIT1.update(idx_S0, 0)
    BIT2.update(idx_S0, 0)
    BIT3.update(idx_S0, 0)
    
    dp_i = float('-inf')
    for i in range(1, n + 1):
        Si = S[i]
        idx_Si = get_idx(Si)
        
        option1 = float('-inf')
        if idx_Si > 1:
            temp = BIT1.query(1, idx_Si - 1)
            if temp != float('-inf'):
                option1 = temp + i
        
        option2 = BIT2.query(idx_Si, idx_Si)
        
        option3 = float('-inf')
        if idx_Si < m:
            temp = BIT3.query(idx_Si + 1, m)
            if temp != float('-inf'):
                option3 = temp - i
        
        dp_i = max(option1, option2, option3)
        
        BIT1.update(idx_Si, dp_i - i)
        BIT2.update(idx_Si, dp_i)
        BIT3.update(idx_Si, dp_i + i)
    
    print(dp_i)

if __name__ == "__main__":
    main()

當然也可以用樹狀陣列來寫,速度可能會更快一點:

#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
#define int long long
using namespace std;

const int MAX = 500005;

struct FenwickTree {
    int size;
    vector<int> tree;

    FenwickTree(int n_) {
        size = n_;
        tree.assign(size + 1, LLONG_MIN);
    }

    void update(int pos, int value) {
        while (pos <= size) {
            tree[pos] = max(tree[pos], value);
            pos += pos & -pos;
        }
    }

    int query(int pos) {
        int res = LLONG_MIN;
        while (pos > 0) {
            res = max(res, tree[pos]);
            pos -= pos & -pos;
        }
        return res;
    }

    int query(int l, int r) {
        return max(query(r), (l > 1 ? query(l - 1) : LLONG_MIN));
    }
};

int n;
int A[MAX], S[MAX], aintS_arr[MAX];

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    cin >> n;
    for (int i = 1; i <= n; i++) cin >> A[i];

    S[0] = 0;
    for (int i = 1; i <= n; i++) S[i] = S[i - 1] + A[i];

    for (int i = 0; i <= n; i++) aintS_arr[i] = S[i];
    sort(aintS_arr, aintS_arr + n + 1);
    int m = unique(aintS_arr, aintS_arr + n + 1) - aintS_arr;

    auto get_idx = [&](int x) -> int {
        return lower_bound(aintS_arr, aintS_arr + m, x) - aintS_arr + 1;
    };

    FenwickTree BIT1(m); // max(dp[j] - j)
    FenwickTree BIT2(m); // max(dp[j])
    FenwickTree BIT3(m); // max(dp[j] + j)

    int idx_S0 = get_idx(S[0]);
    BIT1.update(idx_S0, 0);
    BIT2.update(idx_S0, 0);
    BIT3.update(idx_S0, 0);

    int dp_i = LLONG_MIN;
    for (int i = 1; i <= n; i++) {
        int Si = S[i];
        int idx_Si = get_idx(Si);

        int option1 = LLONG_MIN;
        if (idx_Si > 1) {
            option1 = BIT1.query(1, idx_Si - 1) + i;
        }

        int option2 = BIT2.query(idx_Si, idx_Si);

        int option3 = LLONG_MIN;
        if (idx_Si < m) {
            option3 = BIT3.query(idx_Si + 1, m) - i;
        }

        dp_i = max(option1, max(option2, option3));

        BIT1.update(idx_Si, dp_i - i);
        BIT2.update(idx_Si, dp_i);
        BIT3.update(idx_Si, dp_i + i);
    }

    cout << dp_i;
}

相關文章