模擬賽搬的題,dp 思路很明顯,但難點就在於找到要轉移的點在哪。
暴力
首先我們可以先考慮 \(k=1\) 的情況,這應該很好想,就是對於每一個右括號,找到其匹配的左括號,然後進行轉移即可,這個過程可以用棧維護。
\(dp[i]\) 定義為以 \(i\) 為結尾的合法序列個數。假設當前右括號在 \(i\) 處,匹配的左括號在 \(j\) 處,則:
注意一定是要在保證能找到的情況下,轉移離自己最近的左括號,才能保證所有括號序列都被統計到了。
最後掃一遍把所有的 \(dp[i]\) 累加即可。考場做法拿了 40pts。
正解
上面的做法,我們發現可以擴充到全域性,也就是同時有 \(k\) 個序列的情況。
我們考慮一個括號序列的常用 trick:把左括號看作 \(+1\),把右括號看作 \(-1\),一個括號序列合法,當且僅當其總和為 \(0\) 且任何一段字首和都 \(\ge 0\)。
總和為 \(0\) 很好考慮,我們主要想任何一段字首和 \(\ge 0\) 怎麼搞。
觀察到字首和陣列每次相對前一項的變化量要麼是 \(1\) 要麼是 \(-1\),並且由於先要保證能找到,所以我們先找出可以匹配的左括號的區間左端點。
但是這樣並不好做,因為如果 \([l,r]\) 的和 \(<0\),\([l-1,r]\) 的和卻不一定 \(<0\)。所以固定右括號的方式不可行。
因此,我們才考慮固定左括號,去尋找右括號,並且把 dp 倒著做。
於是找出字首和 \(<0\) 的就很簡單了,對於一個左括號,其最多能匹配到的右括號一定在後面離自己最近的使字首和 \(<0\) 的地方。
這個我們可以透過從後往前掃描,記錄下考慮序列第 \(i\) 位到第 \(n\) 位裡面字首和為每一種數的最小下標,這樣我們就可以快速查詢在左括號後面,第一個使字首和 \(<0\) 的右括號在哪了。
如果找不到,就說明右括號在右邊的哪裡都可以,所以賦為最大值。
算完最大右端點後,我們對於每一列,求出其最大右端點中的最小值,這就是某一列裡可能的匹配範圍。
接下來考慮總和為 \(0\) 的限制,很容易發現對於字首和陣列而言是這樣的:
可得:
一個括號序列合法,必須每一行都滿足這個條件,也就是說對於兩個列而言,每一行的字首和相同,它才可能合法。
所以我們對每一列雜湊,存進 unordered_map,然後統計離自己最近的且在最大右端點左邊的相同位即可。
最後來個 dp 就完事了,時間是 \(O(nk)\) 的,但 unordered_map 可能有點常數。
程式碼
程式碼還是比較好寫的。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pi;
const ll eps=500005,mod=998244353;
int n,k,a[15][50005],f[15][50005],tot[15][110005],r[15][50005],pr[50005],y[50005];
ll hs[50005],dp[50005],ans;
unordered_map<ll,int>mp;
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>k>>n;
//處理原括號序列、字首和陣列、各列的雜湊值
for(int i=1;i<=k;i++)
{
for(int j=1;j<=n;j++)
{
char c;
cin>>c;
if(c=='(')a[i][j]=1;
else a[i][j]=-1;
f[i][j]=f[i][j-1]+a[i][j];
hs[j]=(hs[j]*10007%mod+f[i][j])%mod;
}
}
//統計右邊最遠可達的括號
memset(tot,0x3f,sizeof(tot));
memset(r,0x3f,sizeof(r));
for(int i=1;i<=k;i++)
{
for(int j=n;j>=1;j--)
{
tot[i][f[i][j]+eps]=j;
r[i][j]=tot[i][f[i][j-1]-1+eps];
}
}
//記錄對於每一列而言的右邊界
memset(pr,0x3f,sizeof(pr));
for(int i=1;i<=n;i++)
{
for(int j=1;j<=k;j++)
{
pr[i]=min(pr[i],r[j][i]);
}
}
//找出相同的雜湊值
for(int i=n;i>=1;i--)
{
mp[hs[i]]=i;
y[i]=mp[hs[i-1]];
}
//dp
for(int i=n;i>=1;i--)
{
if(y[i]!=0&&y[i]<=pr[i])
{
dp[i]=dp[y[i]+1]+1;
}
}
//統計答案
for(int i=1;i<=n;i++)
{
ans+=dp[i];
}
cout<<ans;
return 0;
}