首先考慮對於一個序列直接怎麼 \(O(n)\) 求 \(f\) 值,就發現考慮維護兩個指標 \(l,r\),如果 \(a_l=a_r\),則 \(l+1,r-1\),否則我們就讓小的那一個分裂,那麼每次操作一定可以減少長度,所以最優。
然後就可以 \(O(n^3)\),考慮換一種可最佳化的方式計算 \(f\),透過猜想大概就是看一下字首字尾集合有多少個不相等的數字。
發現能夠過拍。
考慮最佳化,我們考慮列舉區間,區間的和為 \(val\),且這個區間作為某個字首,其左端點為 \(l\),右端點為 \(r\)。
問題就是找有多少個 \(j(j\ge l)\) 滿足:
\(\forall i(l\le i)~ s_{i,j} \ne val\)
於是我們考慮列舉了 \(l,r\) 之後看看有多少個 \(j\) 不滿足條件即可。
然後發現某一個 \(j\) 所對應的 \(i\) 也一定只有至多一個滿足 \(s_{i,j}=val\)。
問題就是問有多少個區間 \([i,j]\) 滿足和 \(s_{i,j}=s_{l,r},(l\le i,r\le j)\)。
於是我們用 map
維護一個 int, vector<pair<int, int>>
表示值為 \(val\) 對應的區間。
然後按照左端點排序,對於每個區間看看它後面有多少個右端點大於它的右端點即可。
用樹狀陣列維護。
時間複雜度:\(O(n^2\log n)\)。
Code:
#include <bits/stdc++.h>
using namespace std;
#define rep(i,l,r) for (int i = (l); i <= (r); i++)
#define per(i,r,l) for (int i = (r); i >= (l); i--)
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
#define maxn(a, b) a = max(a, b)
#define minn(a, b) a = min(a, b)
typedef vector<int> VI;
typedef long long ll;
typedef pair<int,int> PII;
typedef double db;
const ll mod = 998244353;
mt19937 gen(114514);
ll rp(ll a,ll b) {ll res=1%mod;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll gcd(ll a,ll b) { return b?gcd(b,a%b):a;}
template<class T>
struct BIT {
vector<T> bit;
int size;
void init(int n) {
bit.clear(); bit.resize(n + 1, 0);
size = n;
}
void update(int x, T v) {
assert(x);
while (x <= size) {
bit[x] += v;
x += (x & -x);
}
}
T query(int x) {
assert(x <= size);
T res = 0;
while (x) {
res += bit[x];
x -= (x & -x);
}
return res;
}
int find(T k) {
assert(k); k--;
int x = 0;
for (int i = log2(size); i >= 0; i--) {
int y = x + (1 << i);
if (y <= size && bit[y] <= k) {
x = y;
k -= bit[y];
}
}
return x + 1;
}
};
BIT<int> t;
const int N = 2043;
int n;
int a[N];
map<int, vector<PII>> p;
void solve() {
// init();
scanf("%d", &n);
rep(i,1,n) scanf("%d", &a[i]);
ll ans = 0;
p.clear();
per(i,n,1) {
int sum = 0;
rep(j,i,n) {
sum += a[j];
p[sum].eb(i, j);
}
}
for (auto [val, ve] : p) {
t.init(n);
// cout << "val: " << val << endl;
for (auto [l, r] : ve) {
ans += n - r - t.query(n - r + 1);
// cout << l << ' ' << r << endl;
t.update(n - r + 1, 1);
}
// cout << ans << endl;
}
printf("%lld\n", ans);
}
int main() {
int T;
scanf("%d", &T);
while (T--)
solve();
return 0;
}