DDMP中的損失函式

星辰大海,绿色星球發表於2024-06-16

接著擴散模型 簡述訓練擴散模型過程中用到的損失函式形式。完整的觀察資料\(x\)的對數似然如下:

\[\begin{aligned} \mathrm{log}\ p(x) &\geq \mathbb{E}_{q_{\phi}(z_{1:T}|z_0)} \mathrm{log} \frac{p(z_T)\prod_{t=0}^{T-1}p_{\theta}(z_t|z_{t+1})}{\prod_{t=0}^{T-1}q_{\phi}(z_{t+1}|z_t)} \\ &= \mathbb{E}_{q_{\phi}(z_{1}|z_0)} [\mathrm{log}\ p_{\theta}(z_0|z_1) ] - \mathbb{D}_{KL}(q_{\phi}(z_T|z_0)||p(z_T)) - \sum_{t=2}^{T} \mathbb{E}_{q_{\phi}(z_t|z_0)} [ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ] \end{aligned} \tag {1} \]

其中,\(q_{\phi}(z_{t-1}|z_t,z_0)\)為了便於計算,已經近似為高斯分佈

\[\mathcal N(\mu_q(z_t,z_0), \Sigma_q(t)) \tag {2}\]

\[\mu_q(z_t, z_0) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_0 }{ 1 - \bar {\alpha}_t^2 } \tag {3} \]

\[\Sigma_q(t) = \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }I \tag {4} \]

形式一

為了使得去噪過程\(p_{\theta}(z_{t-1}|z_t)\)和“真實”的\(q_{\phi}(z_{t-1}|z_t,z_0)\)儘可能接近,因此也可以將\(p_{\theta}(z_{t-1}|z_t)\)建模為一個高斯分佈。又由於所有的\(\alpha\)項在每個時間步都是固定的,因此可以將其方差設計與“真實”的\(q(z_{t-1}|z_t,z_0)\)的方差是一樣的。且這個高斯分佈與初始值\(z_0\)是無關的,因此可以將其均值設計為關於\(z_t, t\)的函式,即設為\(\mu_{\theta}(z_t,t)\).

  考慮兩個高斯分佈的KL散度等於

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(x;\mu_x,\Sigma_x) || \mathcal N(y;\mu_y,\Sigma_y)) \\ & = \frac{1}{2}[log\frac{|\Sigma_y|}{|\Sigma_x|} - d + tr(\Sigma_y^{-1}\Sigma_x) + (\mu_y-\mu_x)^T\Sigma_y^{-1}(\mu_y-\mu_x)] \end{aligned} \tag {5} \]

應用到公式(1)中的第三項,因此有

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\ & = \frac{1}{2\sigma_{q}^2(t)}||\mu_{\theta}(x_t,t) - \mu_{q}(x_t,x_0)||^2 \end{aligned} \tag {6} \]

其中\(\sigma_{q}^2(t)\)是公式(4)前的係數即\(\sigma_{q}^2(t)= \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }\)

由於\(\mu_{\theta}(x_t,t)\)也是\(x_t\)的函式,因此,可以參考公式(3)的形式,將進一步假設

\[\mu_{\theta}(x_t, t) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_{\theta}(z_t, t) }{ 1 - \bar {\alpha}_t^2 } \tag {7} \]

這樣公式(6)進一步化簡為

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\ & = \frac{1}{2\sigma_{q}^2(t)} \frac{\bar{\alpha}_{t-1}^2( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)^2} ||z_{\theta}(z_t,t) - z_0||^2 \end{aligned} \tag {8} \]

至此,最佳化VDM就變成了學習一個神經網路,從樣本任意時刻的加噪版本預測出其原來的樣本。最終最小化公式(1)中的第三項,等價於最小化關於時間步的期望,因此有

\[arg min \mathbb{E}_{t \sim U\{2,T\}} [ \mathbb{E}_{q_{\phi}(z_t|z_0)}[ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ] ] \]

形式二

\[z_t = \bar \alpha_t z_0 + \sqrt{1-\bar {\alpha}_t^2} \bar \epsilon_t \tag {9} \]

可得

\[z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}_t}{\bar {\alpha}_t} \tag {10} \]

再代入公式(3)得

\[\mu_q(x_t,x_0) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \bar \epsilon_t \tag{11} \]

參考形式一中的假設方式,可以假設

\[\mu_{\theta}(x_t,t) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \epsilon_{\theta}(z_t, t) \tag{12} \]

再代入公式(6)可以得到

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\ & = \frac{1}{2\sigma_{q}^2(t)} \frac{( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)\alpha_t^2} ||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2 \end{aligned} \tag {12} \]

至此,最佳化VDM就變成了學習一個神經網路,從樣本任意時刻的加噪版本預測出按照公式(10)新增的原始噪音。

形式三

由公式(8)和公式(12)可以得到

\[||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2 = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2} ||z_{\theta}(z_t,t) - z_0||^2 \tag{13} \]

由於\(\bar {\alpha_t}, \sqrt{1-\bar {\alpha_t}^2}\) 分別是\(t\)時間步的加噪訊號公式(9)中的原始訊號和噪音訊號係數,因此將訊雜比SNR(t)定義為係數平方之比,即

\[SNR(t) = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2} \tag {14} \]

這個訊雜比在時間步初期其值較大,代表真實訊號佔比多噪音佔比少;在時間步後期其值較小,代表真實訊號佔比少噪音佔比多。因為推理過程是完全從高斯分佈隨機取樣,為了保證推理與訓練保持一致,訓練過程採取特定的\(\bar {\alpha}_t\)使得T步得到的是完全噪音,不包含任何原始訊號。此時訊雜比是0.

當預測傳送在訊雜比接近0(\(\bar \alpha_t \to 0\))時,模型原始預測是噪音\(\bar \epsilon\),因此根據公式(10)預估對應的原始訊號

\[\bar z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}}{\bar {\alpha}_t} \]

這樣網路預測的微小差異就會被放大很多倍,因此在論文[3]模型蒸餾過程,這就不是一個穩定的設計。為了避免這個問題,作者提出了3種解決辦法。

  • 直接預測\(z\),而非噪音\(\epsilon\)
  • 同時預測\(z, \epsilon\),透過兩個獨立的輸出通道\(z, \epsilon\)。由於根據公式(10)可以再由\(\epsilon\)再推斷出\(z^{'}\),然後可以根據\(\bar \alpha_t^2, 1-\bar \alpha_t^2\)對這兩個值進行差值。
  • 預測混合體 \(v=\alpha_t\epsilon - \sqrt{1-\alpha_t^2}z\)

參考

[1]. https://www.cnblogs.com/wolfling/p/17938102
[2]. Understanding Diffusion Models: A Unified Perspective
[3]. Progressive Distillation for Fast Sampling of Diffusion Models

相關文章