題目描述
給定一個所有數互不相同的長度為 \(N\) 的序列 \(P\),你可以執行以下操作任意次:
- 選擇一對 \(1\le l < r\le N\),並把其中除最小值外的所有元素刪除。
求最終可以得到的不同序列數量。
思路
我們考慮怎樣透過刪除最少的元素來刪除 \(i\),很明顯,就是選擇區間 \([l,i]\) 或 \([i,r]\),這裡 \(l\) 是最大的滿足 \(l<i且 P_l<P_i\) 的下標,而 \(r\) 是最小的滿足 \(r>i且 P_r<P_i\) 的下標。這個可以使用單調棧求出。
我們每次都只刪除這種最小區間,因為這樣會使最終的方案更多。
令 \(dp_{0/1,i}\) 表示考慮前 \(i\) 個數,最後一個數選/不選的種類數。
我們有 \(dp_{0,i}\leftarrow dp_{0/1,l_i}\)(因為沒有任何影響,隨便怎麼選,但如果不刪除 \(i\),那麼 \(l_i+1\) 到 \(i\) 都無法刪除)。以及 \(dp_{1,i}\leftarrow dp_{1,j}(l_i\le j<i),dp_{0,l_i}\),因為刪掉 \(i\) 就意味著刪掉 \(l_i+1\) 到 \(i\),而 \(l_i\) 是既可以刪也可以不刪。這個可以使用字首和最佳化。
時空複雜度均為 \(O(N)\)。
程式碼
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 300001, MOD = 998244353;
int t, n, a[MAXN], stk[MAXN], top, dp[2][MAXN], sum[MAXN];
void Solve() {
cin >> n;
for(int i = 1; i <= n; ++i) {
cin >> a[i];
}
top = 0;
dp[1][0] = sum[0] = 1;
for(int i = 1; i <= n; ++i) {
for(; top && a[stk[top]] >= a[i]; --top) {
}
int j = stk[top];
dp[0][i] = (j ? (dp[1][j] + dp[0][j]) % MOD : 0);
dp[1][i] = (0ll + dp[0][j] + sum[i - 1] - (j ? sum[j - 1] : 0) + MOD) % MOD;
sum[i] = (sum[i - 1] + dp[1][i]) % MOD;
stk[++top] = i;
}
cout << (dp[0][n] + dp[1][n]) % MOD << "\n";
}
int main() {
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
for(cin >> t; t--; Solve()) {
}
return 0;
}