【題目全解】ACGO排位賽#13

Macw發表於2024-10-09

ACGO排位賽#13 - 題目解析

感謝大家參加本次排位賽!

T1 - 紀元流星雨

題目連結跳轉:點選跳轉

也沒有特別大的難度,手動模擬一下就可以了。

解題步驟

  1. 先計算出這個人一生中第一次看到流星雨的日子:\((E + B) \mod 50\)​。
  2. 計算出剩餘一生中可以看到流星雨的年份 \(Y\)
  3. 答案就是 \(\dfrac{Y}{50} + 1\)

程式碼實現

本題的 C++ 程式碼如下:

#include <iostream>
using namespace std;

int solve(int B, int L, int E) {
    int age_at_first_shower = (E + B) % 50;
    if (age_at_first_shower > L) return 0;
    int years_from_first_shower = 
        L - age_at_first_shower;
    return years_from_first_shower / 50 + 1;
}

int main() {
    int T; cin >> T;
    for (int i = 0; i < T; i++) {
        int B, L, E;
        cin >> B >> L >> E;
        cout << solve(B, L, E) << '\n';
    }
    return 0;
}

本題的 Python 程式碼如下:

T = int(input())
for _ in range(T):
    B, L, E = map(int, input().split(' '))
    before = B + E
    after = L - before
    print(before//50 + after//50 + 1)

T2 - MARCOncatenate

題目連結跳轉:點選跳轉

根據題目模擬就可以了,沒有什麼難度。

程式碼實現

本題的 C++ 程式碼如下:

#include <iostream>
#include <string>

using namespace std;

string marco = "marco";
string capmarco = "MARCO";

string solve(string S) {
    int i = 0;
    int max = 0;
    for (int i = 1; i <= min(5, int(S.size())); ++i) {
        if (marco.substr(5 - i, i) == S.substr(0, i)) {
            max = i;
        }
    }
    if (max == 0) {
        return S;
    } else {
        return capmarco + S.substr(max, int(S.size()) - max);
    }
}

int main() {
    int T;
    cin >> T;
    for (int i = 0; i < T; i++) {
        string S;
        cin >> S;
        cout << solve(S) << '\n';
    }
    return 0;
}

本題的 Python 程式碼如下:

def solve(S: str) -> str:    
    marco = "MARCO"
    marco_lower = "marco"
    matching_count = 0

    for i in range(7):
        if S[0:i] == marco_lower[5 - i:]:
            matching_count = i
            
    return marco + S[matching_count:] if matching_count != 0 else S

def main():
    T = int(input())
    for _ in range(T):
        S = input()
        print(solve(S))


if __name__ == '__main__':
    main()

T3 - TNT接力

題目連結跳轉:點選跳轉

這道題是本次比賽的第三道題,但是許多參賽選手認為這道題是本場比賽最難的題目。

解題思路

  1. 滑動視窗:使用滑動視窗技術,先在前 \(K\) 個方塊中計算有多少個空氣方塊。接著,隨著視窗向右滑動,我們移除視窗左端的一個方塊,並加入視窗右端的一個方塊,更新空氣方塊的數量。

  2. 最大空氣方塊數量:我們維護一個變數 mx,表示在當前視窗中最多有多少個空氣方塊。

  3. 計算答案:每次計算答案時,使用 \(K - \text{最大空氣方塊數量}\) 來表示最小的 TNT 方塊數量,這表示需要多少步才能避免 TNT 的塌陷。如果 \(K \geq N\),表示玩家可以直接跳過整個橋,輸出 \(-1\)​。

時間複雜度

每次處理一個序列的時間複雜度是 \(O(N)\),其中 \(N\) 是方塊序列的長度。整體複雜度為 \(O(T \times N)\),其中 \(T\) 是測試用例的數量。關於本體,還可以用二分來進一步最佳化。本文不過多陳述。

程式碼實現

本題的 C++ 程式碼如下:

#include <iostream>
#include <string>

using namespace std;

int solve(int N, int K, string S) {
    if (K >= N) return -1;
    ++K;
    int mx = 0, cur = 0;
    for (int i = 0; i < N; ++i) {
        if (S[i] == '-') ++cur;
        if (i - K >= 0 && S[i - K] == '-') --cur;
        mx = max(mx, cur);
    }
    return K - mx;
}

int main() {
    int T;
    cin >> T;
    for (int i = 0; i < T; i++) {
        int N, K;
        cin >> N >> K;
        string S;
        cin >> S;
        cout << solve(N, K, S) << '\n';
    }
    return 0;
}

本題的 Python 程式碼如下:

def solve(N: int, K: int, S: str) -> int:
    if K >= N:
        return -1
    mx, cur = 0, 0
    K += 1
    for i in range(N):
        if S[i] == '-':
            cur += 1
        if i - K >= 0 and S[i - K] == '-':
            cur -= 1
        mx = max(mx, cur)
    return K - mx


def main():
    T = int(input())
    for _ in range(T):
        temp = input().split()
        N, K = int(temp[0]), int(temp[1])
        S = input()
        print(solve(N, K, S))


if __name__ == '__main__':
    main()

T4 - 小丑牌

題目連結跳轉:點選跳轉

解題思路

也是一道模擬題,但需要注意以下幾點:

  1. 同一個點數可能會出現五次,那麼此時應該輸出 High Card(如題意)。
  2. 如果有一個牌型符合上述多條描述,請輸出符合描述的牌型中在規則中最後描述的牌型。
  3. 牌的數量不侷限於傳統的撲克牌,每張牌可以題目中四種花色的任意之一。

程式碼實現

本題的 C++ 程式碼如下:

#include <iostream>
#include <algorithm>
using namespace std;

int t;
struct card{
    int rank;
    string suit;
} cards[10];

bool cmp(card a, card b){
    return a.rank < b.rank;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
	cin >> t;
    while(t--){
        for (int i=1; i<=5; i++){
            string rank; string suit;
            cin >> rank >> suit;
            int act;
            if (rank == "J") act = 11;
            else if (rank == "Q") act = 12;
            else if (rank == "K") act = 13;
            else if (rank == "A") act = 14;
            else if (rank == "10") act = 10;
            else act = rank[0] - '0';
            cards[i] = (card){act, suit};
        }
        sort(cards+1, cards+6, cmp);
        int cnt[20] = {};
        bool isSameSuit = true;
        bool isRanked = true;
        int pairs = 0; int greatest = 0;
        for (int i=1; i<=5; i++){
            cnt[cards[i].rank] += 1;
            if (i > 1 && cards[i].suit != cards[i-1].suit)
                isSameSuit = false;
            if (i > 1 && cards[i-1].rank + 1 != cards[i].rank)
				isRanked = false;
        }
        for (int i=1; i<=15; i++){
            if (cnt[i] == 2) pairs++;
            greatest = max(greatest, cnt[i]);
        }
        if (isRanked && isSameSuit){
            if (cards[5].rank == 14) cout << "Royal Flush" << endl;
            else cout << "Straight Flush" << endl;
        } else if (isRanked) cout << "Straight" << endl;
        else if (pairs == 1 && greatest == 3) cout << "Full House" << endl;
        else if (greatest == 4) cout << "Four of a Kind" << endl;
        else if (greatest == 3) cout << "Three of a Kind" << endl;
        else if (pairs == 2) cout << "Two Pairs" << endl;
        else if (pairs == 1) cout << "One Pair" << endl;
        else cout << "High Card" << endl;
    }
    return 0;
}

T5 - Vertex Verse

題目連結跳轉:點選跳轉

直接模擬情況就可以了,但是細節比較多需要注意一下。

時間複雜度

其中 work 函式會在每次迭代中被呼叫 \(4\) 次,每次的複雜度是 \(O(E)\)。因此,對於每個輸入的 \(q\) 對點,總的時間複雜度是 \(\Theta(4 \times q \times E)\),即\(O(q \times E)\)。但在最壞情況下,圖中的邊數 \(E\) 可以接近 \(O(n \times m)\),因此總時間複雜度是 \(O(q \times n \times m)\)​。

程式碼實現

本題的 C++ 程式碼如下:

#include <iostream>
#include <cstring>
#include <map>
#include <algorithm>
using namespace std;

int n, m, q, ei;
int a, b, c, d;
int macw, alex;
struct perEdge{
    int to;
    int next;
} edges[2000005];
int vertex[1000005], cnt;
bool vis[1000005], memo[1000005][5], track;

inline int calc_dir(int x, int y){
    if (x + 1 == y) return 1;
    if (x + m == y) return 2;
    if (x + m + 1 == y) return 3;
    return -1;
}

void add(int v1, int v2){
    ei += 1;
    edges[ei].to = v2;
    edges[ei].next = vertex[v1];
    vertex[v1] = ei;
}

void dfs(int x, int steps, int origin, int dir){
    if (steps > 4) return ;
    if (steps == 4 && x == origin){
        // 說明走一圈可以走到原點。
        // 判斷這個環是否已經被之前記錄過了。
        if (memo[origin][dir]) return ;
        memo[origin][dir] = 1;
        cnt += 1;
        return ;
    }
    for (int index = vertex[x]; index; index = edges[index].next){
        int to = edges[index].to;
        // 只走編號比自己大的點。
        if (to >= origin || (to == origin && steps + 1 == 4)){
            if (vis[to]) continue;
            vis[to] = 1;
            dfs(to, steps + 1, origin, dir);
            vis[to] = 0;
        }
    }
}

void work(int x){
    for (int index = vertex[x]; index; index = edges[index].next){
        int to = edges[index].to;
        if (to <= x) continue;
        int dir = calc_dir(x, to);
        if (dir != -1){
            vis[to] = 1;
            dfs(to, 1, x, dir);
            vis[to] = 0;
        }
    }
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> m >> q;
    for (int i=1; i<=2*q; i++){
        cin >> a >> b >> c >> d;
        int v1 = (a - 1) * m + b;
        int v2 = (c - 1) * m + d;
        add(v1, v2); add(v2, v1);
        // 從v1點開始走四步,看一下能否回到原點
        cnt = 0;
        if (a == c){ 
            // 在同一排
            work(v1); work(v2);
            work(v1 - m); work(v2 - m);
        } else if (b == d){
            // 在同一列
            work(v1); work(v2);
            work(v1 - 1); work(v2 - 1);
        }
        if (i % 2) macw += cnt / 2;
        else alex += cnt / 2;
    }
    cout << macw << " " << alex << endl;
    return 0;
}

另一種寫法如下:

#include <iostream>
#include <unordered_map>
#include <map>
#include <vector>
#include <utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pi;
const int N = 2e5 + 10;
map<pi,map<pi,int>> dis;
int mp[2];
bool check(pi a,pi b,pi c,pi d) {
    return dis[a][b] && dis[a][c] && dis[c][d] && dis[b][d];
}
int main() {
    int n,m,q;
    cin >> n >> m >> q;
    q <<= 1;
    for (int i=0;i<q;i++) {
        pi a,b;
        cin >> a.first >> a.second >> b.first >> b.second;
        if (a > b) swap(a,b);
        dis[a][b]++;
        if (dis[a][b] > 1) {
            continue;
        }
        if (a.first == b.first) {
            pi c = make_pair(a.first + 1,a.second);
            pi d = make_pair(b.first + 1,b.second);
            if (check(a,b,c,d)) {
                mp[i%2]++;
            }
            pi e = make_pair(a.first - 1,a.second);
            pi f = make_pair(b.first - 1,b.second);
            if (check(e,f,a,b)) {
                mp[i%2]++;
            }
        }else {
            pi c = make_pair(a.first,a.second + 1);
            pi d = make_pair(b.first,b.second + 1);
            if (check(a,c,b,d)) {
                mp[i%2]++;
            }
            pi e = make_pair(a.first,a.second - 1);
            pi f = make_pair(b.first,b.second - 1);
            if (check(e,a,f,b)) {
                mp[i%2]++;
            }
        }
    }
    cout << mp[0] << " " << mp[1] << endl;
    return 0;
}

T6 - 最優政府大樓選址-2

題目連結跳轉:點選跳轉

解題思路

本題有好多解決方法,可以使用帶權中位數寫 \(N = 10^5\)​,為了考慮到樸素的模擬退火演算法,本題的資料範圍被適當降低了。如果學過模擬退火演算法做這道題就非常的簡單,把模板的評估函式改成計算意向程度的函式即可。

時間複雜度

模擬退火演算法的時間複雜度約為 \(\Theta(k \times N)\)。其中 \(k\) 表示的是在模擬退火過程中計算舉例的 F2 函式的呼叫次數。\(N\) 表示資料規模。

程式碼實現

本題的 C++ 程式碼如下(模擬退火):

#include <iostream>
#include <algorithm>
#include <cmath>
#include <random>
using namespace std;

double n;
struct apart{
    double x, y;
} arr[10005];
double ei[10005];
double answ = 1e18, ansx, ansy;

double dis(double x1, double y1, double x2, double y2, int c){
    return (abs(x1 - x2) + abs(y1 - y2)) * ei[c];
}

double F2(double x, double y){
    double ans = 0;
    for (int i=1; i<=n; i++){
        ans += dis(x, y, arr[i].x, arr[i].y, i);
    }
    return ans;
}

void SA(){
    double T = 3000, cold = 0.999, range = 1e-20;
    answ = F2(ansx, ansy);
    while(T > range){
        double ex = ansx + (rand() * 2.0 - RAND_MAX) * T;
        double ey = ansy + (rand() * 2.0 - RAND_MAX) * T;
        double ea = F2(ex, ey);
        double dx = ea - answ;
        if (dx < 0){
            ansx = ex;
            ansy = ey;
            answ = ea;
        } else if (exp(-dx/T) * RAND_MAX > rand()){
            ansx = ex;
            ansy = ey;
        }
        T *= cold;
    }
    return ;
}

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n;
    for (int i=1; i<=n; i++) cin >> ei[i];
    for (int i=1; i<=n; i++){
        cin >> arr[i].x >> arr[i].y;
        ansx += arr[i].x; ansy += arr[i].y;
    }
    ansx /= n; ansy /= n;
    for (int i=1; i<=1; i++) SA();
    printf("%.5lf\n", answ);
    return 0;
}

使用加權中位數演算法的 C++ 程式碼:

#include <bits/stdc++.h>

constexpr double EPS = 1e-6;

double get_med(const std::vector<double> &a, const std::vector<int> &e) {
    int n = a.size();
    double lo = -1e3, hi = 1e3;
    auto f = [&](double x) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i)
            sum += std::abs(x - a[i]) * e[i];
        return sum;
    };
    while (hi - lo > EPS) {
        double mid = (lo + hi) / 2;
        if (f(mid - EPS) > f(mid) and f(mid) > f(mid + EPS))
            lo = mid;
        else
            hi = mid;
    }
    return lo;
}

int main() {
    int n; std::cin >> n;
    std::vector<int> e(n);
    for (auto &x : e) std::cin >> x;
    std::vector<double> x(n), y(n);
    for (int i = 0; i < n; ++i)
        std::cin >> x[i] >> y[i];
    double ax = get_med(x, e), ay = get_med(y, e);
    double res = 0.0;
    for (int i = 0; i < n; ++i)
        res += (std::abs(ax - x[i]) + std::abs(ay - y[i])) * e[i];
    std::cout << std::setprecision(6) << std::fixed << res << '\n';
    return 0;
}

T7 - 烏龜養殖場

題目連結跳轉:點選跳轉

前置知識:

  1. 瞭解過基本的動態規劃。
  2. 熟練掌握二進位制的位運算。

至於為什麼放了一道模版題,原因是因為需要湊到八道題目,實在湊不到了,找了一個難度適中的。

題解思路

這是一道典型的狀壓動態規劃問題。設 \(dp_{i, j}\) 表示遍歷到第 \(i\) 行的時候,當前行以 \(j_{(base2)}\) 的形式排列烏龜可以構成的方案數。

對於每一行的方案,我們可以用一個二進位制來表示。例如二進位制數字 \(10100\),表示有一個橫向長度為 \(5\) 的場地中,第 \(1, 3\) 號位置分別放置了一隻小烏龜。因此,每一種擺放狀態都可以用一個二進位制數字來表示。我們也可以透過遍歷的方式來遍歷出二進位制的每一種擺放狀態。

首先,我們預處理出橫排所有放置烏龜的合法情況。根據題意,兩個烏龜不能相鄰放置,因此在二進位制中,不能有兩個 \(1\) 相鄰。如何預處理出這種情況呢?我們可以使用位運算的方法:

如果存在一個二進位制數字有兩個 \(1\) 相鄰,那麼如果我們對這個數字 \(x\) 進行位運算操作 (x << 1) & x 的結果或 (x >> 1) & x 的結果必定大於等於 \(1\)。我們透過把這種情況排除在外。同時,我們還需要注意有些格子中不能放置烏龜。這一步也可以透過二進位制的方法預處理掉,如果網箱在第 \(i\) 一個格子中不能放置烏龜,那麼在列舉所有方案數的時候直接忽略掉第 \(i\) 位為 \(1\) 的情況即可。

接下來如何保證上下兩行的烏龜不衝突?假如上一行的擺放狀態是 \(y\),當前行的擺放狀態為 \(j\),如果 i & j 的結果大於等於 \(1\),也可以證明有兩個數字 \(1\) 在同一位置上。因此我們也需要把這種情況排除在外。

綜上所述,我們可以得出狀態轉移方程:\(dp_{i, j} = dp_{i, j} + dp_{i-1, k}\)。其中,\(j\)\(k\) 表示所有橫排合法的方案。答案就是 \(\mathtt{ANS} = \sum_{j=0}^{2^M-1}{dp_{N, j}}\)

狀態的初始化也很簡單,另 \(dp_{0, 0} = 1\)​,表示一隻烏龜都不放有一種擺放方案。

時間複雜度

透過觀察上述程式碼,在列舉所有狀態和轉移狀態的時候有三層迴圈,分別是列舉當前行、列舉當前行的合法擺放情況以及列舉上一行的擺放情況。因此總時間複雜度約為 \(O(n \times 2^M \times 2^M) = O(n \times 2^{M^2}) = O(n \times 4^M)\)。但由於合法的擺放數量遠遠少於 \(2^M\),因此實際情況下程式執行的速度會快許多。

程式碼實現

本題的程式碼實現如下。在輸出的時候需要減一,因為不放置也是一種合法情況,根據題目要求需要把這一合法情況排除。

#include <iostream>
using namespace std;

const int MOD = 1e9+7;
int n, m, ans;
int arr[505][505];
// 所有橫排合法的情況。
int terrain[505];
int ok[1050], cnt;
int dp[505][1050];

int main(){
    cin >> n >> m;
    for (int i=1; i<=n; i++){
        for (int j=1; j<=m; j++){
            cin >> arr[i][j];
        }
    }
    
    // 預處理非法地形。
    for (int i=1; i<=n; i++){
        for (int j=1; j<=m; j++){
            terrain[i] = (terrain[i] << 1) + !arr[i][j];
        }
    }
    
    // 預處理出所有橫排的合法情況。
    for (int i=0; i<(1<<m); i++){
        if (((i<<1)|(i>>1)) & i) continue;
        ok[++cnt] = i;
    }
    dp[0][1] = 1;

    // 列舉。
    for (int i=1; i<=n; i++){
        for (int s1=1; s1<=cnt; s1++){  // 列舉當前行。
            if (ok[s1] & terrain[i]) continue;
            for (int s2=1; s2<=cnt; s2++){  // 列舉上一行。
                if (ok[s2] & terrain[i-1]) continue;
                if (ok[s1] & ok[s2]) continue;
                dp[i][s1] = (dp[i][s1] + dp[i-1][s2]) % MOD;
            }
        }
    }

    // 統計答案。
    int ans = 0;
    for (int i=1; i<=cnt; i++)
        ans = (ans + dp[n][i]) % MOD;
    
    cout << ans - 1 << endl;
    return 0;
}

本題的 Python 程式碼如下,Python 可以透過本題的所有測試點:

MOD = int(1e9 + 7)
n, m, ans = 0, 0, 0
arr = [[0] * 505 for _ in range(505)]
terrain = [0] * 505
ok = [0] * 1050
dp = [[0] * 1050 for _ in range(505)]
cnt = 0

def main():
    global n, m, cnt, ans
    
    # 輸入 n 和 m
    n, m = map(int, input().split())
    
    # 輸入 arr 陣列
    for i in range(1, n + 1):
        arr[i][1:m + 1] = map(int, input().split())
    
    # 預處理非法地形
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            terrain[i] = (terrain[i] << 1) + (1 - arr[i][j])
    
    # 預處理出所有橫排的合法情況
    for i in range(1 << m):
        if ((i << 1) | (i >> 1)) & i:
            continue
        cnt += 1
        ok[cnt] = i
    
    dp[0][1] = 1
    
    # 列舉
    for i in range(1, n + 1):
        for s1 in range(1, cnt + 1):  # 列舉當前行
            if ok[s1] & terrain[i]:
                continue
            for s2 in range(1, cnt + 1):  # 列舉上一行
                if ok[s2] & terrain[i - 1]:
                    continue
                if ok[s1] & ok[s2]:
                    continue
                dp[i][s1] = (dp[i][s1] + dp[i - 1][s2]) % MOD
    
    # 統計答案
    ans = 0
    for i in range(1, cnt + 1):
        ans = (ans + dp[n][i]) % MOD
    
    print(ans - 1)

if __name__ == "__main__":
    main()

再提供一個暴力解法用於對拍:

#include <iostream>
using namespace std;

const int MOD = 1e9+7;
int n, m, ans;
int arr[505][505];
int dx[] = {0, 1, -1, 0};
int dy[] = {1, 0, 0, -1};

// 深度優先搜尋 Brute Force
void dfs(int x, int y){
    if (x > n) {
        ans += 1;
        ans %= MOD;
        return ;
    }
    if (y > m){
        dfs(x+1, 1);
        return ;
    }
    if (arr[x][y] == 0){
        dfs(x, y+1);
        return ;
    }
    // 不放魚
    dfs(x, y+1);

    // 放魚
    for (int i=0; i<4; i++){
        int cx = x + dx[i];
        int cy = y + dy[i];
        if (cx < 1 || cy < 1 || cx > n || cy > m) continue;
        if (arr[cx][cy] == 2) return ;
    }
    arr[x][y] = 2;
    dfs(x, y+1);
    arr[x][y] = 1;
    return ;
}

int main(){
    cin >> n >> m;
    for (int i=1; i<=n; i++){
        for (int j=1; j<=m; j++){
            cin >> arr[i][j];
        }
    }
    // dfs 暴力
    dfs(1, 1);
    cout << ans-1 << endl;
    return 0;
}

T8 - 資料中心能耗分析

題目連結跳轉:點選跳轉

本文僅針對對線段樹有一定了解且並且有著較高的程式設計能力的選手。本文的前置知識如下:

  1. 線段樹的構造與維護 - 可以參考文章 淺入線段樹與區間最值問題
  2. 初中數學 - 完全平方和公式和完全立方和公式。
  3. 取模 - 之如何保證對所有整數取模後的結果為非負整數 - 可以參考本題的 說明/提示 部分。

原本出題的時候我並不知道這道題在許多 OJ 上是有原題的,so sad(下次再改進)。

題目本身應該非常好理解,就是給定一個陣列,讓你設計一個程式,對程式進行區間求立方和和區間修改的操作。但本題的資料量比較大,\(N, M\) 最大可以達到 \(10^5\),對於每一次修改和查詢都是 \(O(N)\) 的時間複雜度,顯然用暴力的方法時間複雜度絕對會超時,最高會到 \(O(N^2 * M)\) (大概需要 \(115\) 天的時間才可以搞定一個測試點)。當看到區間查詢和維護操作的時候,不難想到用線段樹的方法,線上段樹中,單次操作的時間複雜度約為 \(O(log_2 N)\),即使當 \(N\) 非常大的時候線段樹也可以跑得飛起。

解題思路

不得不說,這是一道比較噁心的線段樹區間維護的題目。不光寫起來比較費勁,而且維護操作運算量比較大。稍有不慎就會寫歪(因此寫這道題的時候要集中注意力,稍微一個不起眼的問題就容易爆 \(0\))。

本題的主要難點就是對一個區間內進行批次區間累加的操作。很容易就想到跟完全立方公式的聯絡:\((a+b)^3 = a^3 +3a^2b + 3ab^2 + b^3\)。區間累加操作也只不過是對區間的所有數字都進行該操作,並對所有操作的結果求和就是該區間進行操作後的立方和。化簡可得:

\[\begin{align} \mathtt{ANS} &= \sum_{i=1}^{n}{(a_i+x)^3}\\ &= (a_1+b)^3+(a_2+b)^3+ \cdots + (a_n+b)^3 \\ &= (a_1^3 + 3a_1^2b + 3a_1b^2 + b^3) + (a_2^3 + 3a_2^2b + 3a_2b^2 + b^3) + \cdots + (a_n^3 + 3a_n^2b + 3a_nb^2 + b^3) \\ &= (a_1^3 + a_2^3 + \cdots + a_n^3) + 3b(a_1^2 + a_2^2 + \cdots + a_n^2) + 3b^2(a_1 + a_2 + \cdots + a_n) + nb^3 \\ &= \sum_{i=1}^{n}{a_i^3} + 3b\sum_{i=1}^{n}{a_i^2} + 3b^2\sum_{i=1}^{n}{a_i} + nb^3 \end{align} \]

綜上所述,我們只需要用線段樹維護三個欄位,分別是區間的立方和、區間的平方和以及區間和。在維護平方和的過程中與維護立方和類似,根據完全平方公式 \((a+b)^2 = a^2 + 2ab + b^2\)。經過累加和化簡可得:

\[\begin{align} \mathtt{ANS} &= \sum_{i=1}^{n}{(a_i+x)^2}\\ &= (a_1+b)^2 + (a_2+b)^2 + \cdots + (a_n+b)^2 \\ &= (a_1^2 + 2a_1b + b^2) + (a_2^2 + 2a_2b + b^2) + \cdots + (a_n^2 + 2a_nb + b^2) \\ &= (a_1^2 + a_2^2 + \cdots + a_n^2) + 2b(a_1 + a_2 + \cdots + a_n) + nb^2 \\ &= \sum_{i=1}^{n}{a_i^2} + 2b\sum_{i=1}^{n}{a_i} + nb^2 \end{align} \]

以上三個欄位可以在構造線段樹的時候一併初始化,之後的每次更新直接修改懶標記就可以了。一切都交給 push_down() 函式。在每次區間查詢和修改之前都進行懶標記下放操作,對區間進行維護。具體維護操作如下:

// rt 是父節點,l和r是rt的兩個子節點,len是rt節點區間的長度。
// 其中,(len - len / 2)是l區間的長度,(len / 2)是r區間的長度。
void push_down(Node &rt, Node &l, Node &r, int len){
    if (rt.tag){
        int num = rt.tag;
		// 維護立方和
        l.s3 += 3 * num * l.s2 + 3 * num * num * l.s1 + (len - len / 2) * num * num * num;
        r.s3 += 3 * num * r.s2 + 3 * num * num * r.s1 + (len / 2) * num * num * num;
		
        //維護平方和
        l.s2 += 2 * num * l.s1 + (len - len / 2) * num * num;
        l.s2 += 2 * num * r.s1 + (len / 2) * num * num;

        // 維護區間總和
        l.s1 += (len - len / 2) * num;
        r.s1 += (len / 2) * num;
		
        // 將標記下放到兩個子區間
        l.tag += num;
        r.tag += num;
        rt.tag = 0;  // 清空標記。
    }
    return ;
}

注意事項

  1. 請注意取模,為了保證答案正確性,請在每一步操作的時候都對結果取模。
  2. long long,不然的話只能過前三個測試點(出題人還是挺好的,留了三個小的測試點騙粉)。
  3. 在維護立方和、平方和以及和的時候,請注意維護的順序。應當先維護立方和,再維護平方和,最後再維護區間總和。
  4. 注意線段樹陣列的大小,應當為 \(4 \times N\)
  5. 建議使用讀入最佳化,直接使用 cin 的效率比 std 慢約 \(100\%\)​。

時間複雜度

線段樹單次查詢和修改的複雜度約為 \(O(log_2 N)\),初始化的時間複雜度為 \(\Theta(N)\),因此本程式碼的整體時間複雜度可以用多項式 \(\Theta(N + M \cdot log_2(N))\) 來表示,整體程式碼的時間複雜度就為 \(O(M \cdot log_2(N))\)。在極限資料下,程式只需要 \(160ms\) 就可以完成暴力一整年所需的工作。

程式碼實現

  1. 程式碼使用了宏定義,方便後期進行調式。
  2. 以下程式碼與普通的線段樹無太大區別,但請著重關注 push_down() 下放操作。
#include <iostream>
#include <algorithm>
using namespace std;

const int N = 1e5 + 5;
const int MOD = 1e9 + 7;
// 宏定義:lc和rc分別表示左兒子和右兒子在陣列中的索引。
#define lc root << 1
#define rc root << 1 | 1
#define int long long
int n, m, k, x, y, v;

struct Node {
    // 分別表示總和、平方和與立方和。
    int s1, s2, s3;
    int tag;
} tree[N << 2];

// 合併區間,直接將兩個子區間相加即可。
void push_up(int root) {
    tree[root].s1 = (tree[lc].s1 + tree[rc].s1) % MOD;
    tree[root].s2 = (tree[lc].s2 + tree[rc].s2) % MOD;
    tree[root].s3 = (tree[lc].s3 + tree[rc].s3) % MOD;
    return;
}

// 下放操作,確實碼量比較大。在借鑑的時候需要仔細點。
void push_down(Node &rt, Node &l, Node &r, int len) {
    if (rt.tag) {
        int num = rt.tag;

        l.s3 = (l.s3 + 3 * num % MOD * l.s2 % MOD + 3 * num % MOD * num % MOD * l.s1 % MOD + (len - len / 2) * num % MOD * num % MOD * num % MOD) % MOD;
        l.s3 = (l.s3 + MOD) % MOD;
        r.s3 = (r.s3 + 3 * num % MOD * r.s2 % MOD + 3 * num % MOD * num % MOD * r.s1 % MOD + (len / 2) * num % MOD * num % MOD * num % MOD) % MOD;
        r.s3 = (r.s3 + MOD) % MOD;

        l.s2 = (l.s2 + 2 * num % MOD * l.s1 % MOD + (len - len / 2) * num % MOD * num % MOD) % MOD;
        l.s2 = (l.s2 + MOD) % MOD;
        r.s2 = (r.s2 + 2 * num % MOD * r.s1 % MOD + (len / 2) * num % MOD * num % MOD) % MOD;
        r.s2 = (r.s2 + MOD) % MOD;

        l.s1 = (l.s1 + (len - len / 2) * num % MOD) % MOD;
        l.s1 = (l.s1 + MOD) % MOD;
        r.s1 = (r.s1 + (len / 2) * num % MOD) % MOD;
        r.s1 = (r.s1 + MOD) % MOD;

        l.tag = (l.tag + num) % MOD;
        l.tag = (l.tag + MOD) % MOD;
        r.tag = (r.tag + num) % MOD;
        r.tag = (r.tag + MOD) % MOD;
        rt.tag = 0;
    }
    return;
}

// 非常簡單的造樹操作。
void build(int l, int r, int root) {
    if (l == r) {
        int t; cin >> t;
        tree[root].s1 = t % MOD; 
        tree[root].s2 = t * t % MOD;
        tree[root].s3 = t * t % MOD * t % MOD;
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, lc);
    build(mid + 1, r, rc);
    push_up(root);
    return;
}

// 更新操作。
void update(int l, int r, int v, int L, int R, int root) {
    // 跟push_down()函式基本類似。
    if (L <= l && r <= R) {
        tree[root].tag = (tree[root].tag + v) % MOD;
        tree[root].tag = (tree[root].tag + MOD) % MOD;
        tree[root].s3 = (tree[root].s3 + 3 * v % MOD * tree[root].s2 % MOD + 3 * v % MOD * v % MOD * tree[root].s1 % MOD + (r - l + 1) * v % MOD * v % MOD * v % MOD) % MOD;
        tree[root].s3 = (tree[root].s3 + MOD) % MOD;
        tree[root].s2 = (tree[root].s2 + 2 * v % MOD * tree[root].s1 % MOD + (r - l + 1) * v % MOD * v % MOD) % MOD;
        tree[root].s2 = (tree[root].s2 + MOD) % MOD;
        tree[root].s1 = (tree[root].s1 + (r - l + 1) * v % MOD) % MOD;
        tree[root].s1 = (tree[root].s1 + MOD) % MOD;
        return;
    }
    push_down(tree[root], tree[lc], tree[rc], r - l + 1);
    int mid = (l + r) >> 1;
    if (L <= mid) update(l, mid, v, L, R, lc);
    if (R >= mid + 1) update(mid + 1, r, v, L, R, rc);
    push_up(root);
}

// 區間查詢操作。
int query(int l, int r, int L, int R, int root) {
    if (L <= l && r <= R)
        return (tree[root].s3 + MOD) % MOD;
    int sum = 0;
    push_down(tree[root], tree[lc], tree[rc], r - l + 1);
    int mid = (l + r) >> 1;
    if (L <= mid) sum = (sum + query(l, mid, L, R, lc)) % MOD;
    if (R >= mid + 1) sum = (sum + query(mid + 1, r, L, R, rc)) % MOD;
    return (sum + MOD) % MOD;
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> m;
    build(1, n, 1);
    while (m--) {
        cin >> k >> x >> y;
        if (k == 1) {
            cin >> v;
            update(1, n, v, x, y, 1);
        } else cout << query(1, n, x, y, 1) << endl;
    }
    return 0;
}

以下是本題的 Python 程式碼,但由於 Python 常數過大,沒有辦法透過所有的測試點:

MOD = 10**9 + 7
N = int(1e5 + 5)

class Node:
    def __init__(self):
        self.s1 = 0
        self.s2 = 0
        self.s3 = 0
        self.tag = 0

tree = [Node() for _ in range(N * 4)]

def push_up(root):
    tree[root].s1 = (tree[root * 2].s1 + tree[root * 2 + 1].s1) % MOD
    tree[root].s2 = (tree[root * 2].s2 + tree[root * 2 + 1].s2) % MOD
    tree[root].s3 = (tree[root * 2].s3 + tree[root * 2 + 1].s3) % MOD

def push_down(root, l, r, length):
    if tree[root].tag != 0:
        num = tree[root].tag % MOD
        left_child = root * 2
        right_child = root * 2 + 1
        left_len = length - length // 2
        right_len = length // 2

        # Left child updates
        tree[left_child].s3 = (tree[left_child].s3 + (3 * num * tree[left_child].s2 % MOD) + (3 * num * num % MOD * tree[left_child].s1 % MOD) + (left_len * num % MOD * num % MOD * num % MOD)) % MOD
        tree[left_child].s2 = (tree[left_child].s2 + (2 * num * tree[left_child].s1 % MOD) + (left_len * num % MOD * num % MOD)) % MOD
        tree[left_child].s1 = (tree[left_child].s1 + (left_len * num % MOD)) % MOD
        tree[left_child].tag = (tree[left_child].tag + num) % MOD

        # Right child updates
        tree[right_child].s3 = (tree[right_child].s3 + (3 * num * tree[right_child].s2 % MOD) + (3 * num * num % MOD * tree[right_child].s1 % MOD) + (right_len * num % MOD * num % MOD * num % MOD)) % MOD
        tree[right_child].s2 = (tree[right_child].s2 + (2 * num * tree[right_child].s1 % MOD) + (right_len * num % MOD * num % MOD)) % MOD
        tree[right_child].s1 = (tree[right_child].s1 + (right_len * num % MOD)) % MOD
        tree[right_child].tag = (tree[right_child].tag + num) % MOD

        tree[root].tag = 0

def build(l, r, root):
    if l == r:
        t = data[l - 1] % MOD
        tree[root].s1 = t
        tree[root].s2 = t * t % MOD
        tree[root].s3 = t * t % MOD * t % MOD
        return
    mid = (l + r) // 2
    build(l, mid, root * 2)
    build(mid + 1, r, root * 2 + 1)
    push_up(root)

def update(l, r, v, L, R, root):
    if L <= l and r <= R:
        num = v % MOD
        length = r - l + 1
        tree[root].tag = (tree[root].tag + num) % MOD
        tree[root].s3 = (tree[root].s3 + (3 * num * tree[root].s2 % MOD) + (3 * num * num % MOD * tree[root].s1 % MOD) + (length * num % MOD * num % MOD * num % MOD)) % MOD
        tree[root].s2 = (tree[root].s2 + (2 * num * tree[root].s1 % MOD) + (length * num % MOD * num % MOD)) % MOD
        tree[root].s1 = (tree[root].s1 + (length * num % MOD)) % MOD
        return
    push_down(root, l, r, r - l + 1)
    mid = (l + r) // 2
    if L <= mid:
        update(l, mid, v, L, R, root * 2)
    if R > mid:
        update(mid + 1, r, v, L, R, root * 2 + 1)
    push_up(root)

def query(l, r, L, R, root):
    if L <= l and r <= R:
        return tree[root].s3 % MOD
    push_down(root, l, r, r - l + 1)
    mid = (l + r) // 2
    res = 0
    if L <= mid:
        res = (res + query(l, mid, L, R, root * 2)) % MOD
    if R > mid:
        res = (res + query(mid + 1, r, L, R, root * 2 + 1)) % MOD
    return res % MOD

if __name__ == '__main__':
    import sys
    sys.setrecursionlimit(1 << 25)
    n, m = map(int, sys.stdin.readline().split())
    data = list(map(int, sys.stdin.readline().split()))
    build(1, n, 1)
    for _ in range(m):
        tmp = sys.stdin.readline().split()
        if not tmp:
            continue
        k = int(tmp[0])
        x = int(tmp[1])
        y = int(tmp[2])
        if k == 1:
            v = int(tmp[3])
            update(1, n, v, x, y, 1)
        else:
            print(query(1, n, x, y, 1))

相關文章