從DDPM到DDIM(四) 預測噪聲與後處理
前情回顧
下圖展示了DDPM的雙向馬爾可夫模型。
訓練目標。最大化證據下界等價於最小化以下損失函式:
\[\boldsymbol{\theta}^*=\underset{\boldsymbol{\theta}}{\operatorname{argmin}} \sum_{t=1}^T \frac{1}{2 \sigma^2(t)} \frac{\left(1-\alpha_t\right)^2 \overline{\alpha}_{t-1}}{\left(1-\overline{\alpha}_t\right)^2} \mathbb{E}_{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\left[\Vert\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\Vert_2^2\right] \tag{1}
\]
推理過程。推理過程利用馬爾可夫鏈蒙特卡羅方法。
\[\begin{aligned}
\mathbf{x}_{t-1} &\sim p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}\right) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) , \sigma^2 \left(t\right) \mathbf{I}) \\
\mathbf{x}_{t-1} &= \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) + \sigma \left(t\right) \bm{\epsilon} \\
&= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right) + \sigma \left(t\right) \bm{\epsilon}
\end{aligned} \tag{2}
\]
1、預測噪聲
上一篇文章我們提到,擴散模型的神經網路用於預測 \(\mathbf{x}_{0}\),然而DDPM並不是這樣做的,而是用神經網路預測噪聲。這也是DDPM 第一個字母 D(Denoising)的含義。為什麼採用預測噪聲的引數化方法?DDPM作者在原文中提到去噪分數匹配(denoising score matching, DSM),並說這樣訓練和DSM是等價的。可見應該是收了DSM的啟發。另外一個解釋我們一會來講。
按照上一篇文章的化簡技巧,對於神經網路的預測輸出 \(\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)\),也可以進行進一步引數化(parameterization):
已知:
\[\begin{aligned}
\mathbf{x}_{t} = \sqrt{\overline{\alpha}_t} \mathbf{x}_{0} + \sqrt{1 - \overline{\alpha}_t} \bm{\epsilon}
\end{aligned} \tag{3}
\]
於是:
\[\begin{aligned}
\mathbf{x}_{0} = \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{\sqrt{1 - \overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}} \bm{\epsilon}
\end{aligned} \tag{4}
\]
\[\begin{aligned}
\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) = \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{\sqrt{1 - \overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)
\end{aligned} \tag{5}
\]
這裡我們解釋以下為什麼採用預測噪聲的方式的第二個原因。從(4)(5)兩式可見,噪聲項可以看作是 \(\mathbf{x}_{0}\) 與 \(\mathbf{x}_{t}\) 的殘差項。回顧經典的Resnet結構:
\[\left[\mathbf{y}=\mathbf{x}+\mathcal{F}\left(\mathbf{x}, W_i\right)\right]
\]
Resnet也是用神經網路學習的殘差項。DDPM採用預測噪聲的方法和Resnet殘差學習由異曲同工之妙。
下面我們將(3)(4)兩式代入(1)式,繼續化簡,有:
\[\begin{aligned}
\Vert\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\Vert_2^2 &= \frac{1 - \overline{\alpha}_t}{\overline{\alpha}_t} \Vert\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\bm{\epsilon}\Vert_2^2
\end{aligned}
\]
注意 \(\overline{\alpha}_t\) = \(\overline{\alpha}_{t-1} \alpha_t\)於是可以得出新的最佳化方程:
\[\boldsymbol{\theta}^*=\underset{\boldsymbol{\theta}}{\operatorname{argmin}} \sum_{t=1}^T \frac{1}{2 \sigma^2(t)} \frac{\left(1-\alpha_t\right)^2}{\left(1-\overline{\alpha}_t\right) \alpha}_t \mathbb{E}_{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\left[\Vert\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\sqrt{\overline{\alpha}_t} \mathbf{x}_{0} + \sqrt{1 - \overline{\alpha}_t} \bm{\epsilon}, t\right)-\bm{\epsilon}\Vert_2^2\right] \tag{6}
\]
(6) 式表示,我們的神經網路 \(\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\sqrt{\overline{\alpha}_t} \mathbf{x}_{0} + \sqrt{1 - \overline{\alpha}_t} \bm{\epsilon}, t\right)\) 被用於預測最初始的噪聲 \(\bm{\epsilon}\)。忽略掉前面的係數,對應的訓練演算法如下:
Algorithm 3 . Training a Deniosing Diffusion Probabilistic Model. (Version: Predict noise)
Repeat the following steps until convergence.
- For every image \(\mathbf{x}_0\) in your training dataset \(\mathbf{x}_0 \sim q\left(\mathbf{x}_0\right)\)
- Pick a random time step \(t \sim \text{Uniform}[1, T]\).
- Generate normalized Gaussian random noise \(\bm{\epsilon} \sim \mathcal{N} \left(\mathbf{0}, \mathbf{I}\right)\)
- Take gradient descent step on
\[\nabla_{\boldsymbol{\theta}} \Vert\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\sqrt{\overline{\alpha}_t} \mathbf{x}_{0} + \sqrt{1 - \overline{\alpha}_t} \bm{\epsilon}, t\right)-\bm{\epsilon}\Vert_2^2
\]
You can do this in batches, just like how you train any other neural networks. Note that, here, you are training one denoising network \(\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\) for all noisy conditions.
推理的過程依然從馬爾可夫鏈蒙特卡洛(MCMC)開始,因為這裡是預測噪聲,而推理的過程中也需要加噪聲,為了區分,我們將推理過程中新增的噪聲用 \(\mathbf{z} \sim \mathcal{N} \left(\mathbf{0}, \mathbf{I}\right)\) 來表示。推理過程中每次推理的噪聲 \(\mathbf{z}\) 都是不同的,但訓練過程中要擬合的最初的目標噪聲 \(\bm{\epsilon}\) 是相同的。
\[\begin{aligned}
\mathbf{x}_{t-1} &\sim p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}\right) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) , \sigma^2 \left(t\right) \mathbf{I}) \\
\mathbf{x}_{t-1} &= \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) + \sigma \left(t\right) \mathbf{z} \\
&= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right) + \sigma \left(t\right) \mathbf{z}
\end{aligned} \tag{7}
\]
將(5)式代入:
\[\begin{aligned}
\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) &= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right) \\
&= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \left( \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{\sqrt{1 - \overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) \right) \\
&= \text{some algebra calculation} \\
&= \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{1 - \alpha_t}{ \sqrt{ \left( 1 - \overline{\alpha}_{t} \right)\alpha}_t} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)
\end{aligned}
\]
所以推理的表示式為:
\[\begin{aligned}
\mathbf{x}_{t-1} &= \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{1 - \alpha_t}{ \sqrt{ \left( 1 - \overline{\alpha}_{t} \right)\alpha}_t} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) + \sigma \left(t\right) \mathbf{z}
\end{aligned} \tag{7}
\]
下面可以寫出採用擬合噪聲策略的推理演算法:
Algorithm 4 . Inference on a Deniosing Diffusion Probabilistic Model. (Version: Predict noise)
You give us a white noise vector \(\mathbf{x}_T \sim \mathcal{N} \left(\mathbf{0}, \mathbf{I}\right)\)
Repeat the following for \(t = T, T − 1, ... , 1\).
- Generate \(\mathbf{z} \sim \mathcal{N} \left(\mathbf{0}, \mathbf{I}\right)\) if \(t > 1\) else \(\mathbf{z} = \mathbf{0}\)
\[\mathbf{x}_{t-1} = \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{1 - \alpha_t}{ \sqrt{ \left( 1 - \overline{\alpha}_{t} \right)\alpha}_t} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) + \sigma \left(t\right) \mathbf{z}
\]
Return \(\mathbf{x}_{0}\)
2、後處理
首先要注意到,在推理演算法的最後一步,生成影像的時候,並沒有新增噪聲,而是直接採用預測的均值作為 \(\mathcal{x}_0\) 的估計值。
另外,生成的影像原本是歸一化到 \([-1, 1]\) 之間的,所以要反歸一化到 \([0, 255]\)。這裡比較簡單,直接看 diffusers 庫中的程式碼:
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
3、總結
我們最初的目標是估計影像的機率分佈,採用極大似然估計法,求 \(\log p\left(\mathbf{x}_0\right)\)。但是直接求解,很難求:
\[\begin{aligned}
p\left(\mathbf{x}_0\right) = \int p\left(\mathbf{x}_{0:T}\right) d \mathbf{x}_{1:T} \\
\end{aligned} \\
\]
而且 \(p\left(\mathbf{x}_{0:T}\right)\) 也不知道。於是我們選擇估計它的證據下界。在計算證據下界的過程中,我們解析了雙向馬爾可夫鏈中的很多分佈和變數,最終推匯出證據下界的表示式,以KL散度的方式來表示。這樣做本質上是用已知的分佈 \(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 來對未知的分佈做逼近。這其實是 變分推斷 的思想。變分法是尋找一個函式使得這個函式最能滿足條件,而變分推斷是尋找一個分佈使之更加逼近已知的分佈。
於是我們而在高斯分佈的假設下,KL散度恰好等價於二範數的平方。最大似然估計等價於最小化二範數loss。之後就順理成章地推匯出了訓練方法,並根據馬爾可夫鏈蒙特卡洛推匯出推理演算法。關於變分推斷和馬爾可夫鏈蒙特卡洛相關的知識,讀者可以自行查詢,有時間我也會寫篇文章來介紹。
以上就是DDPM的全部內容了,我用了四篇文章對DDPM進行了詳細推導,寫文章的過程中也弄懂了自己之前不懂的一些細節。我的最大的感受是,初學者千萬不要相信諸如《一文讀懂DDPM》之類的文章,如果要真正搞懂DDPM,只有自己把所有公式手推一邊才是正道。
下一篇我們開始介紹DDPM的一個經典的推理加速方法:DDIM