[ABC234G] Divide a Sequence

wyl123ly發表於2024-09-26

[ABC234G] Divide a Sequence

給定長度為 \(N\) 的序列 \(A\),我們定義一種將 \(A\) 劃分為若干段的方案的價值為每一段的最大值減去最小值的差的乘積,你需要求出所有劃分方案的價值的總和,答案對 \(998244353\) 取模。

  • $ 1\ \leq\ N\ \leq\ 3\ \times\ 10^5 $
  • $ 1\ \leq\ A_i\ \leq\ 10^9 $

先考慮樸素 \(dp\),設 \(dp_i\) 為劃分序列 \(A\) 的前 \(i\) 項所有劃分方案的總和,則容易得到轉移方程式:

\[dp_i = \sum^{i - 1}_{x = 1} dp_x \times \{ \max^{i}_{k = x + 1}\{a_k\} + \min^{i}_{k = x + 1}\{a_k\} \} \]

複雜度 \(O(n^2)\) 考慮最佳化。

首先拆分式子:

\[dp_i = \sum^{i - 1}_{x = 1}dp_x \times \{\max^{i}_{k = x + 1}\{a_k\}\} + \sum^{i - 1}_{x = 1} dp_x \times \{\min^{i}_{k = x + 1}\{a_k\}\} \]

分成 \(max\)\(min\) 的兩個子問題處理。

我們這裡單單考慮 \(max\) 的情況:

對於 \(dp_i \to dp_{i + 1}\) 轉移的情況,容易發現:

\[\max^{i}_{k = x + 1}\{a_k\} \]

這個式子的值只會在 \(x > pos\) (其中 \(pos\)\(a_{i + 1}\)第一個\(a_{i + 1}\) 大的數的下標位置)的那些 \(x\) 值時會改變。

\(a_{i + 1}\)第一個\(a_{i + 1}\) 大的數的下標位置我們可以用單調棧 \(O(n)\) 的複雜度預處理完成。

程式碼:

#include<iostream>
#include<algorithm>
#include<stack>
using namespace std;
#define int long long
const int MOD = 998244353;
const int MAXN = 3e5 + 7;
int n;
int a[MAXN];
int idmax[MAXN],idmin[MAXN];
stack<int> st;
void init(){
	for(int i = 1;i <= n;i++){
		while(!st.empty() && a[st.top()] <= a[i]){
			st.pop();
		}
		if(st.empty()) idmax[i] = 0;
		else idmax[i] = st.top();
		st.push(i);
	}
	while(!st.empty()) st.pop();
	for(int i = 1;i <= n;i++){
		while(!st.empty() && a[st.top()] >= a[i]){
			st.pop();
		}
		if(st.empty()) idmin[i] = 0;
		else idmin[i] = st.top();
		st.push(i);
	}
	while(!st.empty()) st.pop();
}
int dp[MAXN];
int predp[MAXN];
int maxx[MAXN],minn[MAXN];
signed main(){
	// ios::sync_with_stdio(false);
	// cin.tie(0),cout.tie(0);
	cin>>n;
	for(int i = 1;i <= n;i++) cin>>a[i];
	init();
	dp[0] = predp[0] = 1;
	// for(int i = 1;i <= n;i++) cout<<idmin[i]<<" ";
	for(int i = 1;i <= n;i++){
		if(idmax[i]) maxx[i] = (maxx[idmax[i]] + (predp[i - 1] - predp[idmax[i] - 1]) * a[i]) % MOD;
		else maxx[i] = (predp[i - 1] * a[i]) % MOD;
		if(idmin[i]) minn[i] = (minn[idmin[i]] + (predp[i - 1] - predp[idmin[i] - 1]) * a[i]) % MOD;
		else minn[i] = (predp[i - 1] * a[i]) % MOD;
		dp[i] = ((maxx[i] - minn[i]) % MOD + MOD) % MOD;
		predp[i] = (predp[i - 1] + dp[i]) % MOD;
	}
	// for(int i = 1;i <= n;i++) cout<<dp[i]<<" "<<predp[i]<<endl;
	cout<<dp[n];
	return 0;
}

相關文章