在一篇部落格GAN網路從入門教程(一)之GAN網路介紹中,簡單的對GAN網路進行了一些介紹,介紹了其是什麼,然後大概的流程是什麼。
在這篇部落格中,主要是介紹其數學公式,以及其演算法流程。當然數學公式只是簡單的介紹,並不會設計很複雜的公式推導。如果想詳細的瞭解GAN網路的原理,推薦去看李宏毅老師的課程。B站和Youtube上面都有。
概率分佈
生成器
首先我們是可以知道真實圖片的分佈函式\(p_{data}(x)\),同時我們把假的圖片也看成一個概率分佈,稱之為\(p_g = (x,\theta)\)。那麼我們的目標是什麼呢?我們的目標就是使得\(p_g(x,\theta)\)儘量的去逼近\(p_{data}(x)\)。在GAN中,我們使用神經網路去逼近\(p_g = (x,\theta)\)。
在生成器中,我們有如下模型:
其中\(z \sim P_{z}(z)\),因此\(G(z)\)也是一個針對於\(z\)概率密度分佈函式。
判別器
針對於判別器,我們有\(D(x,\theta)\),其代表某一張z圖片\(x\)為真的概率。
目標函式
在Generative Adversarial Nets論文中給出了以下的目標函式,也就是GAN網路需要優化的東西。
\[\begin{equation}\min _{G} \max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]\end{equation}
\]
公式看起來很複雜,但是我們分開來看還是比較簡單的。
\(D^*\)
\(D\)網路的目標是什麼?能夠辨別真假,也就是說,給定一張真的圖片\(x\),\(D\)網路能夠給出一個高分,也就是\(D(x)\)儘量大一點。而針對於生成器\(G\)生成的圖片\(G(z)\),我們希望判別器\(D\)儘量給低分,也就是\(D(G(z))\)儘量的小一點。因此\(D\)網路的目標函式如下所示:
\[\begin{equation}\max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]\end{equation}
\]
在目標函式中,\(x\)代表的是真實資料(也就是真的圖片),\(G(z)\)代表的是生成器生成的圖片。
\(G^*\)
\(G\)網路的目標就是使得\(D(G(z))\)儘量得高分,因此其目標函式可以寫成:
\[\begin{equation}\max _{G} V(D, G)=\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (D(G(\boldsymbol{z})))]\end{equation}
\]
\(D(G(z))\)儘量得高分(分數在\([0,1]\)之間),等價於\(1 - D(G(z))\)儘量的低分,因此,上述目標函式等價於:
\[\begin{equation}\min _{G} V(D, G)=\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]\end{equation}
\]
因此我們優化\(D^*\)和優化\(G^*\)結合起來,也就是變成了論文中的目標函式:
\[\begin{equation}\min _{G} \max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]\end{equation}
\]
證明存在全域性最優解
上面的公式看起來很合理,但是如果不存在最優解的話,一切也都是無用功。
D最優解
首先,我們固定G,來優化D,目標函式為:
\(\begin{equation} V(G, D)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]\end{equation}\)
我們可以寫做:
\[\begin{equation}\begin{aligned}
V(G, D) &=\int_{\boldsymbol{x}} p_{\text {data }}(\boldsymbol{x}) \log (D(\boldsymbol{x})) d x+\int_{\boldsymbol{z}} p_{\boldsymbol{z}}(\boldsymbol{z}) \log (1-D(g(\boldsymbol{z}))) d z \\
&=\int_{\boldsymbol{x}} [ p_{\text {data }}(\boldsymbol{x}) \log (D(\boldsymbol{x}))+p_{g}(\boldsymbol{x}) \log (1-D(\boldsymbol{x}))] d x
\end{aligned}\end{equation}
\]
我們設(\(D\)代表\(D(x)\),可以代表任何函式):
\[f(D) = P_{data}(x) log D + P_G(x)log(1-D)
\]
對於每一個固定的\(x\)而言,為了使\(V\)最大,我們當然是希望\(f(D)\)越大越好,這樣積分後的值也就越大。因為固定了\(G\),因此\(p_g(x)\)是固定的,而\(P_{data}\)是客觀存在的,則值也是固定的。我們對\(f(D)\)求導,然後令\(f'(D) = 0\),可得:
\[\begin{equation}D^{*}=\frac{P_{d a t a}(x)}{P_{d a t a}(x)+P_{G}(x)}\end{equation}
\]
下圖表示了,給定三個不同的 \(G1,G3,G3\) 分別求得的令 \(V(G,D)\)最大的那個$ D^∗\(,橫軸代表了\)P_{data}$,藍色曲線代表了可能的 \(P_G\),綠色的距離代表了 \(V(G,D)\):
G最優解
同理,我們可以求\(\underset{D}{max}\ V(G,D)\),我們將前面的得到的\(D^{*}=\frac{P_{d a t a}(x)}{P_{d a t a}(x)+P_{G}(x)}\)帶入可得:
\[% <![CDATA[
\begin{align}
& \underset{D}{min}\ V(G,D) \\
& = V(G,D^{* })\\
& = E_{x \sim P_{data} } \left [\ log\ D^{* }(x) \ \right ] + E_{x \sim P_{G} } \left [\ log\ (1-D^{* }(x)) \ \right ] \\
& = E_{x \sim P_{data} } \left [\ log\ \frac{P_{data}(x)}{P_{data}(x)+P_G(x)} \ \right ] + E_{x \sim P_{G} } \left [\ log\ \frac{P_{G}(x)}{P_{data}(x)+P_G(x)} \ \right ]\\
& = \int_{x} P_{data}(x) log \frac{P_{data}(x)}{P_{data}(x)+P_G(x)} dx+ \int_{x} P_G(x)log(\frac{P_{G}(x)}{P_{data}(x)+P_G(x)})dx \\
& = \int_{x} P_{data}(x) log \frac{\frac{1}{2}P_{data}(x)}{\frac{P_{data}(x)+P_G(x)}{2} } dx+ \int_{x} P_{G}(x) log \frac{\frac{1}{2}P_{G}(x)}{\frac{P_{data}(x)+P_G(x)}{2} } dx \\
& = \int_{x}P_{data}(x)\left ( log \frac{1}{2}+log \frac{P_{data}(x)}{\frac{P_{data}(x)+P_G(x)}{2} } \right ) dx \\
& \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ + \int_{x}P_{G}(x)\left ( log \frac{1}{2}+log \frac{P_{G}(x)}{\frac{P_{data}(x)+P_G(x)}{2} } \right ) dx \\
& = \int_{x}P_{data}(x) log \frac{1}{2} dx + \int_{x}P_{data}(x) log \frac{P_{data}(x)}{\frac{P_{data}(x)+P_G(x)}{2} } dx \\
& \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ + \int_{x}P_{G}(x) log \frac{1}{2} dx + \int_{x}P_{G}(x) log \frac{P_{G}(x)}{\frac{P_{data}(x)+P_G(x)}{2} } dx \\
& = 2 log \frac{1}{2} + \int_{x}P_{data}(x) log \frac{P_{data}(x)}{\frac{P_{data}(x)+P_G(x)}{2} } dx + \int_{x}P_{G}(x) log \frac{P_{G}(x)}{\frac{P_{data}(x)+P_G(x)}{2} } dx\\
& = 2 log \frac{1}{2} + 2 \times \left [ \frac{1}{2} KL\left( P_{data}(x) || \frac{P_{data}(x)+P_{G}(x)}{2}\right )\right ] \\
& \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ + 2 \times \left [ \frac{1}{2} KL\left( P_{G}(x) || \frac{P_{data}(x)+P_{G}(x)}{2}\right )\right ] \\
& = -2 log 2 + 2 JSD \left ( P_{data}(x) || P_G(x) \right)
\end{align} %]]>
\]
其中\(JSD ( P_{data}(x) || P_G(x))\)的取值範圍是從 \(0\)到\(log2\),其中當\(P_{data} = P_G\)是,\(JSD\)取最小值0。也就是說$ V(G,D)$的取值範圍是\(0\)到\(-2log2\),也就是說$ V(G,D)\(存在最小值,且此時\)P_{data} = P_G$。
演算法流程
上述我們從理論上討論了全域性最優值的可行性,但實際上樣本空間是無窮大的,也就是我們沒辦法獲得它的真實期望(\(\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}\)和\(\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}(\boldsymbol{z})\)是未知的),因此我們使用估測的方法來進行。
\[\tilde V = \frac{1}{m}\sum_{i=1}^{m} log D(x^i) + \frac{1}{m}\sum_{i=1}^{m} log (1-D(\tilde x^i))
\]
演算法流程圖如下所示(來自生成對抗網路——原理解釋和數學推導):
總結
上述便是GAN網路的數學原理,以及推導流程還有演算法。我也是剛開始學,參考瞭如下的部落格,其中生成對抗網路——原理解釋和數學推導非常值得一看,裡面非常詳細的對GAN進行了推導,同時,bilibili——【機器學習】白板推導系列(三十一) ~ 生成對抗網路(GAN)中的視訊也不錯,手把手白板的對公式進行了推導。如有任何問題,或文章有任何錯誤,歡迎在評論區下方留言。
參考