CF2004 EDU169 F. Make a Palindrome

weirdoX發表於2024-08-19

首先考慮對於一個序列直接怎麼 \(O(n)\)\(f\) 值,就發現考慮維護兩個指標 \(l,r\),如果 \(a_l=a_r\),則 \(l+1,r-1\),否則我們就讓小的那一個分裂,那麼每次操作一定可以減少長度,所以最優。
然後就可以 \(O(n^3)\),考慮換一種可最佳化的方式計算 \(f\),透過猜想大概就是看一下字首字尾集合有多少個不相等的數字。
發現能夠過拍。

考慮最佳化,我們考慮列舉區間,區間的和為 \(val\),且這個區間作為某個字首,其左端點為 \(l\),右端點為 \(r\)
問題就是找有多少個 \(j(j\ge l)\) 滿足:
\(\forall i(l\le i)~ s_{i,j} \ne val\)
於是我們考慮列舉了 \(l,r\) 之後看看有多少個 \(j\) 不滿足條件即可。

然後發現某一個 \(j\) 所對應的 \(i\) 也一定只有至多一個滿足 \(s_{i,j}=val\)
問題就是問有多少個區間 \([i,j]\) 滿足和 \(s_{i,j}=s_{l,r},(l\le i,r\le j)\)
於是我們用 map 維護一個 int, vector<pair<int, int>> 表示值為 \(val\) 對應的區間。
然後按照左端點排序,對於每個區間看看它後面有多少個右端點大於它的右端點即可。
用樹狀陣列維護。
時間複雜度:\(O(n^2\log n)\)

Code:

#include <bits/stdc++.h>
using namespace std;
#define rep(i,l,r) for (int i = (l); i <= (r); i++)
#define per(i,r,l) for (int i = (r); i >= (l); i--)
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
#define maxn(a, b) a = max(a, b)
#define minn(a, b) a = min(a, b)
typedef vector<int> VI;
typedef long long ll;
typedef pair<int,int> PII;
typedef double db;
const ll mod = 998244353;
mt19937 gen(114514);
ll rp(ll a,ll b) {ll res=1%mod;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll gcd(ll a,ll b) { return b?gcd(b,a%b):a;}
template<class T>
struct BIT {
    vector<T> bit;
    int size;
    void init(int n) {
        bit.clear(); bit.resize(n + 1, 0);
        size = n;
    }
    void update(int x, T v) {
        assert(x);
        while (x <= size) {
            bit[x] += v;
            x += (x & -x);
        }
    }

    T query(int x) {
        assert(x <= size);
        T res = 0;
        while (x) {
            res += bit[x];
            x -= (x & -x);
        }
        return res;
    }

    int find(T k) {
        assert(k); k--;
        int x = 0;
        for (int i = log2(size); i >= 0; i--) {
            int y = x + (1 << i);
            if (y <= size && bit[y] <= k) {
                x = y;
                k -= bit[y];
            }
        }
        return x + 1;
    }
};
BIT<int> t;
const int N = 2043;
int n;
int a[N];
map<int, vector<PII>> p;

void solve() {
    // init();
    scanf("%d", &n);
    rep(i,1,n) scanf("%d", &a[i]);
    ll ans = 0;

    p.clear();
    per(i,n,1) {
    	int sum = 0;
    	rep(j,i,n) {
    		sum += a[j];
    		p[sum].eb(i, j);
    	}
    }
    for (auto [val, ve] : p) {
        t.init(n);
        // cout << "val: " << val << endl;
    	for (auto [l, r] : ve) {
            ans += n - r - t.query(n - r + 1);
            // cout << l << ' ' << r << endl;
            t.update(n - r + 1, 1);
    	}
        // cout << ans << endl;
    }
    printf("%lld\n", ans);
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--)
        solve();
    return 0;
}

相關文章