CF908D-New Year and Arbitrary Arrangement

HANGRY_Sol&Cekas發表於2024-10-13

CF908D-New Year and Arbitrary Arrangement

前言

不是這題為啥星 \(2200\) 啊,感覺做的很多 \(3000\) 左右的題都比這道題水吧。

簡化題意

給定空字串,每次在串尾加入 \(a\)\(b\) ,各有一定機率。

若其中有 \(\ge k\)\(ab\) 子序列 , 則停止加入。

問至加入結束時,含有 \(ab\) 子序列個數的期望值。

\(k \le 1000\)

題解

感覺一眼機率 \(dp\) .

狀態設計的話, \(k\) 這麼小,感覺就像是二維。

那麼我們分析一下,如果 \(k = 1 , p_a = p_b = \frac{1}{2}\)

存在什麼情況捏,分兩類。

  1. 第一位是 a : \(ab , aab , aaab , aaaaaaaaaaaaaaa \dots b\)

注意好像 \(a\) 可以有無限個。

  1. 第一位是 b : \(bab , bbab , baab , \dots\)

我們發現這種情況其實就是第一位是 a 的情況前面不知道加幾個 b

所以我們只用統計第一種情況即可。

我們發現只要不存在逆天 \(aaaaaaaaaaaa \dots\)

\(a\) 的個數應該是 \(k\) 之內。如果存在較多 \(a\) , 那較多 \(a\) 處應該在串尾。

那我們證一下,如果存在前面 \(k + 1\)\(a\) , 那加一個 \(b\) 直接就停止加入了,所以其在串尾。

有些長得帥的小夥伴就問了,好像 \(b\) 的個數更為確定在 \(k\) 內吧,為什麼不用 \(b\) 轉移?

欸,其實我最早想的就是 \(b\) , 但是由於無法判斷新的機率值,所以沒法轉移。

所以 \(dp\) 狀態很明顯了, \(dp_{i , j}\) 表示 \(i\)\(a\) , \(j\) 個答案子序列的機率和。

\[dp_{i , j} = dp_{i - 1 , j} \times p_a + dp_{i , j - i} \times p_b \]

好的,那機率有了,怎麼統計答案?

呃呃呃這個時候就要想一想,如果每個位置都加 \(aaaaa \dots b\) , 那可能會重複的。

例: \(aabab\)\(aab\)

那怎麼整?

我們只需將 \(a\) 還沒用滿的時候,後面加個 \(b\) , \(a\) 用滿後,在加 \(aaaaa \dots b\) .

那這個逆天長串答案怎麼統計?

設原串中 \(a\)\(x\) 個,顯然答案為:

\[\begin{aligned} &= dp_{i , j} \times \left(\sum{p_b p_a^i \times (i + x)} \right) \\ &= dp_{i , j} \times \left(\frac{x p_b}{1 - p_a} + \frac{p_a + p_b}{(1 - p_a)^2}\right) \end{aligned}\]

做完啦!!!

code

CODE
#include <bits/stdc++.h>
using namespace std ; 
typedef long long ll ; 
const int N = 1e3 + 10 ; 
const int mod = 1e9 + 7 ; 

ll k , p11 , p22 , p1 , p2 , dp[N][N] ; 

inline ll Quick_Pow(ll a , ll b) {
	ll ans = 1 ; 

	while (b) {
		if (b & 1) ans = (ans * a) % mod ; 
		b >>= 1 , a = (a * a) % mod ; 
	}

	return ans ; 
}

ll ans = 0 ; 

signed main() {
	ios::sync_with_stdio(0) , cin.tie(0) , cout.tie(0) ; 
	cin >> k >> p11 >> p22 ; 

	p1 = (p11 * Quick_Pow(p11 + p22 , mod - 2)) % mod , p2 = (p22 * Quick_Pow(p11 + p22 , mod - 2)) % mod ; 
	dp[0][0] = 1 ; 

	for (int i = 1 ; i <= k ; ++ i) {
		for (int j = 0 ; j < k ; ++ j) {
			dp[i][j] = (dp[i - 1][j] * p1) % mod ; 

			if (j >= i) dp[i][j] = (dp[i][j] + dp[i][j - i] * p2 % mod) % mod ; 
			if (i != k && i + j >= k) ans = (ans + (((dp[i][j] * p2) % mod) * ((1ll * i + j) % mod))) % mod ; 
		}
	}

	ll sum1 , sum2 , ny ; 

	for (int i = 0 ; i < k ; ++ i) {
		ny = Quick_Pow((1 - p1 + mod) % mod , mod - 2) ; 
		sum1 = ((k * p2) % mod * ny) % mod , sum2 = (((p1 * p2) % mod) * ((ny * ny) % mod)) % mod ; 
		ans = (ans + (dp[k][i] * (sum1 + sum2 + i)) % mod) % mod ; 
	}

	ans = (ans * Quick_Pow((1 + mod - p2) % mod , mod - 2)) % mod ; 
	cout << ans ; 
}