題目連結:http://codeforces.com/problemset/problem/148/D
題意:
一個袋子中有w只白老鼠,b只黑老鼠。
公主和龍輪流從袋子裡隨機抓一隻老鼠出來,不放回,公主先拿。
公主每次抓一隻出來。龍每次在抓一隻出來之後,會隨機有一隻老鼠跳出來(被龍嚇的了。。。)。
先抓到白老鼠的人贏。若兩人最後都沒有抓到白老鼠,則龍贏。
問你公主贏的概率。
題解:
表示狀態:
dp[i][j] = probability to win(當前公主先手,公主贏的概率)
i:剩i只白老鼠
j:剩j只黑老鼠
找出答案:
ans = dp[w][b]
邊界條件:
if i==0 dp[i][j] = 0 (沒有白老鼠了,不可能贏)
else if j==0 dp[i][j] = 1 (有且只有白老鼠,一定贏)
else if j==1 dp[i][j] = i/(i+1) (如果公主拿了黑老鼠,那麼龍一定會拿到白老鼠,公主輸。所以公主一下就要拿到白老鼠)
如何轉移:
對於dp[i][j],有兩種贏的方法:
(1)公主在這個回合一次就抓到了白老鼠。
(2)公主和龍都各抓了一隻黑老鼠,然後公主在下一個回合贏了。
P(一次就抓到了白老鼠) = i/(i+j)
P(進入下個回合,即兩人都抓到黑老鼠) = P(公主抓到黑老鼠) * P(龍抓到黑老鼠) = j/(i+j) * (j-1)/(i+j-1)
所以dp[i][j] = P(一次就抓到了白老鼠) + P(進入下個回合) * P(在下個回合贏)
那麼考慮下個回合可能的狀態。
因為公主和龍都已經抓走了兩隻黑老鼠,那麼下個回合取決於跳出來的老鼠,有三種可能:
(1)跳出來白老鼠
(2)跳出來黑老鼠
(3)老鼠已經抓完了,沒有老鼠跳出來
對於情況(3),原狀態(i,j)只可能為:(1,1) , (0,2) , (2,0),均包含在邊界條件中,所以不作考慮。
剩下兩種情況的可能性:
(1)P(跳出來白老鼠) = i/(i+j-2) (i>=1 and j>=2)
(2)P(跳出來黑老鼠) = (j-2)/(i+j-2) (j>=3)
所以P(在下個回合贏) = P(跳出來白老鼠) * dp[i-1][j-2] + P(跳出來黑老鼠) * dp[i][j-3]
總方程:
nex = 0
if i>=1 and j>=2 nex += i/(i+j-2)*dp[i-1][j-2]
if j>=3 nex += (j-2)/(i+j-2)*dp[i][j-3]
dp[i][j] = i/(i+j) + j/(i+j) * (j-1)/(i+j-1) * nex
另外,這道題的題解有兩個版本,一種記憶化搜尋,一種for迴圈版,都差不多。
AC Code(記憶化搜尋):
1 // state expression: 2 // dp[i][j] = probability to win 3 // i: i white mice 4 // j: j black mice 5 // 6 // find the answer: 7 // ans = dp[w][b] 8 // 9 // transferring: 10 // if i>=1 and j>=2 nex += i/(i+j-2)*dp[i-1][j-2] 11 // if j>=3 nex += (j-2)/(i+j-2)*dp[i][j-3] 12 // dp[i][j] = i/(i+j) + j/(i+j) * (j-1)/(i+j-1) * nex 13 // 14 // boundary: 15 // if i==0 dp[i][j] = 0 16 // if j==0 dp[i][j] = 1 17 // if j==1 dp[i][j] = i/(i+1) 18 #include <iostream> 19 #include <stdio.h> 20 #include <string.h> 21 #define MAX_N 1005 22 23 using namespace std; 24 25 int w,b; 26 bool vis[MAX_N][MAX_N]; 27 double ans; 28 double dp[MAX_N][MAX_N]; 29 30 double dfs(int i,int j) 31 { 32 if(vis[i][j]) return dp[i][j]; 33 vis[i][j]=true; 34 if(i==0) return dp[i][j]=0; 35 if(j==0) return dp[i][j]=1; 36 if(j==1) return dp[i][j]=(double)i/(i+1); 37 double nex=0; 38 nex+=(double)i/(i+j-2)*dfs(i-1,j-2); 39 if(j>=3) nex+=(double)(j-2)/(i+j-2)*dfs(i,j-3); 40 return dp[i][j]=(double)i/(i+j)+(double)j/(i+j)*(j-1)/(i+j-1)*nex; 41 } 42 43 void read() 44 { 45 cin>>w>>b; 46 } 47 48 void solve() 49 { 50 memset(vis,false,sizeof(vis)); 51 ans=dfs(w,b); 52 } 53 54 void print() 55 { 56 printf("%.9f\n",ans); 57 } 58 59 int main() 60 { 61 read(); 62 solve(); 63 print(); 64 }
AC Code(for迴圈):
1 #include <iostream> 2 #include <stdio.h> 3 #include <string.h> 4 #define MAX_N 1005 5 6 using namespace std; 7 8 int w,b; 9 double ans; 10 double dp[MAX_N][MAX_N]; 11 12 void read() 13 { 14 cin>>w>>b; 15 } 16 17 void solve() 18 { 19 memset(dp,0,sizeof(dp)); 20 for(int i=0;i<=w;i++) 21 { 22 for(int j=0;j<=b;j++) 23 { 24 if(i==0) 25 { 26 dp[i][j]=0; 27 continue; 28 } 29 if(j==0) 30 { 31 dp[i][j]=1; 32 continue; 33 } 34 if(j==1) 35 { 36 dp[i][j]=(double)i/(i+1); 37 continue; 38 } 39 double nex=(double)i/(i+j-2)*dp[i-1][j-2]; 40 if(j>=3) nex+=(double)(j-2)/(i+j-2)*dp[i][j-3]; 41 dp[i][j]=(double)i/(i+j)+(double)j/(i+j)*(j-1)/(i+j-1)*nex; 42 } 43 } 44 } 45 46 void print() 47 { 48 printf("%.9f\n",dp[w][b]); 49 } 50 51 int main() 52 { 53 read(); 54 solve(); 55 print(); 56 }