接著擴散模型 簡述訓練擴散模型過程中用到的損失函式形式。完整的觀察資料\(x\)的對數似然如下:
\[\begin{aligned}
\mathrm{log}\ p(x)
&\geq \mathbb{E}_{q_{\phi}(z_{1:T}|z_0)} \mathrm{log} \frac{p(z_T)\prod_{t=0}^{T-1}p_{\theta}(z_t|z_{t+1})}{\prod_{t=0}^{T-1}q_{\phi}(z_{t+1}|z_t)} \\
&= \mathbb{E}_{q_{\phi}(z_{1}|z_0)} [\mathrm{log}\ p_{\theta}(z_0|z_1) ] - \mathbb{D}_{KL}(q_{\phi}(z_T|z_0)||p(z_T)) - \sum_{t=2}^{T} \mathbb{E}_{q_{\phi}(z_t|z_0)} [ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ]
\end{aligned}
\tag {1}
\]
其中,\(q_{\phi}(z_{t-1}|z_t,z_0)\)為了便於計算,已經近似為高斯分佈
\[\mathcal N(\mu_q(z_t,z_0), \Sigma_q(t)) \tag {2}\]
\[\mu_q(z_t, z_0) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_0 }{ 1 - \bar {\alpha}_t^2 }
\tag {3}
\]
\[\Sigma_q(t) = \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }I
\tag {4}
\]
形式一
為了使得去噪過程\(p_{\theta}(z_{t-1}|z_t)\)和“真實”的\(q_{\phi}(z_{t-1}|z_t,z_0)\)儘可能接近,因此也可以將\(p_{\theta}(z_{t-1}|z_t)\)建模為一個高斯分佈。又由於所有的\(\alpha\)項在每個時間步都是固定的,因此可以將其方差設計與“真實”的\(q(z_{t-1}|z_t,z_0)\)的方差是一樣的。且這個高斯分佈與初始值\(z_0\)是無關的,因此可以將其均值設計為關於\(z_t, t\)的函式,即設為\(\mu_{\theta}(z_t,t)\).
考慮兩個高斯分佈的KL散度等於
\[\begin{aligned}
& \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(x;\mu_x,\Sigma_x) || \mathcal N(y;\mu_y,\Sigma_y)) \\
& = \frac{1}{2}[log\frac{|\Sigma_y|}{|\Sigma_x|} - d + tr(\Sigma_y^{-1}\Sigma_x) + (\mu_y-\mu_x)^T\Sigma_y^{-1}(\mu_y-\mu_x)]
\end{aligned}
\tag {5}
\]
應用到公式(1)中的第三項,因此有
\[\begin{aligned}
& \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\
& = \frac{1}{2\sigma_{q}^2(t)}||\mu_{\theta}(x_t,t) - \mu_{q}(x_t,x_0)||^2
\end{aligned}
\tag {6}
\]
其中\(\sigma_{q}^2(t)\)是公式(4)前的係數即\(\sigma_{q}^2(t)= \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }\)
由於\(\mu_{\theta}(x_t,t)\)也是\(x_t\)的函式,因此,可以參考公式(3)的形式,將進一步假設
\[\mu_{\theta}(x_t, t) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_{\theta}(z_t, t) }{ 1 - \bar {\alpha}_t^2 }
\tag {7}
\]
這樣公式(6)進一步化簡為
\[\begin{aligned}
& \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\
& = \frac{1}{2\sigma_{q}^2(t)} \frac{\bar{\alpha}_{t-1}^2( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)^2} ||z_{\theta}(z_t,t) - z_0||^2
\end{aligned}
\tag {8}
\]
至此,最佳化VDM就變成了學習一個神經網路,從樣本任意時刻的加噪版本預測出其原來的樣本。最終最小化公式(1)中的第三項,等價於最小化關於時間步的期望,因此有
\[arg min \mathbb{E}_{t \sim U\{2,T\}} [ \mathbb{E}_{q_{\phi}(z_t|z_0)}[ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ] ]
\]
形式二
由
\[z_t = \bar \alpha_t z_0 + \sqrt{1-\bar {\alpha}_t^2} \bar \epsilon_t
\tag {9}
\]
可得
\[z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}_t}{\bar {\alpha}_t}
\tag {10}
\]
再代入公式(3)得
\[\mu_q(x_t,x_0) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \bar \epsilon_t
\tag{11}
\]
參考形式一中的假設方式,可以假設
\[\mu_{\theta}(x_t,t) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \epsilon_{\theta}(z_t, t)
\tag{12}
\]
再代入公式(6)可以得到
\[\begin{aligned}
& \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\
& = \frac{1}{2\sigma_{q}^2(t)} \frac{( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)\alpha_t^2} ||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2
\end{aligned}
\tag {12}
\]
至此,最佳化VDM就變成了學習一個神經網路,從樣本任意時刻的加噪版本預測出按照公式(10)新增的原始噪音。
形式三
由公式(8)和公式(12)可以得到
\[||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2 = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2} ||z_{\theta}(z_t,t) - z_0||^2
\tag{13}
\]
由於\(\bar {\alpha_t}, \sqrt{1-\bar {\alpha_t}^2}\) 分別是\(t\)時間步的加噪訊號公式(9)中的原始訊號和噪音訊號係數,因此將訊雜比SNR(t)定義為係數平方之比,即
\[SNR(t) = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2}
\tag {14}
\]
這個訊雜比在時間步初期其值較大,代表真實訊號佔比多噪音佔比少;在時間步後期其值較小,代表真實訊號佔比少噪音佔比多。因為推理過程是完全從高斯分佈隨機取樣,為了保證推理與訓練保持一致,訓練過程採取特定的\(\bar {\alpha}_t\)使得T步得到的是完全噪音,不包含任何原始訊號。此時訊雜比是0.
當預測傳送在訊雜比接近0(\(\bar \alpha_t \to 0\))時,模型原始預測是噪音\(\bar \epsilon\),因此根據公式(10)預估對應的原始訊號
\[\bar z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}}{\bar {\alpha}_t}
\]
這樣網路預測的微小差異就會被放大很多倍,因此在論文[3]模型蒸餾過程,這就不是一個穩定的設計。為了避免這個問題,作者提出了3種解決辦法。
- 直接預測\(z\),而非噪音\(\epsilon\)
- 同時預測\(z, \epsilon\),透過兩個獨立的輸出通道\(z, \epsilon\)。由於根據公式(10)可以再由\(\epsilon\)再推斷出\(z^{'}\),然後可以根據\(\bar \alpha_t^2, 1-\bar \alpha_t^2\)對這兩個值進行差值。
- 預測混合體 \(v=\alpha_t\epsilon - \sqrt{1-\alpha_t^2}z\)
參考
[1]. https://www.cnblogs.com/wolfling/p/17938102
[2]. Understanding Diffusion Models: A Unified Perspective
[3]. Progressive Distillation for Fast Sampling of Diffusion Models