從DDPM到DDIM(三) DDPM的訓練與推理

txdt發表於2024-07-25

從DDPM到DDIM(三) DDPM的訓練與推理

前情回顧

首先還是回顧一下之前討論的成果。

擴散模型的結構和各個機率模型的意義。下圖展示了DDPM的雙向馬爾可夫模型。
img

其中\(\mathbf{x}_T\)代表純高斯噪聲,\(\mathbf{x}_t, 0 < t < T\) 代表中間的隱變數, \(\mathbf{x}_0\) 代表生成的影像。

  • \(q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)\) 加噪過程的單步轉移機率,服從高斯分佈,這很好理解。
  • \(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}\right)\) 是真正的取樣過程的單步轉移機率,但是求解它比較困難。
  • \(p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}\right)\) 代表的是神經網路擬合的機率,我們希望神經網路能更好地擬合取樣過程的單步轉移機率。
  • \(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)\),給定最終生成結果 \(\mathbf{x}_{0}\) 的條件下,生成過程的單步轉移機率。\(\mathbf{x}_{0}\) 就像有監督學習中的標籤,指導著生成的方向。我們採用此機率來替代 \(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}\right)\) 做神經網路的擬合。如果無法理解,把它當作一個無物理意義的數學上的中間變數即可。

\(p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_{t})\) 來表示。之所以這裡增加一個 \(\theta\) 下標,是因為 \(p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_{t})\) 是用神經網路來逼近的轉移機率, \(\theta\) 代表神經網路引數。

聯合機率表示 擴散模型的聯合機率和前向條件聯合機率為:

\[\begin{aligned} {p\left(\mathbf{x}_{0:T}\right) = p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)} \end{aligned} \tag{1} \]

\[\begin{aligned} {q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) = \prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \\ \end{aligned} \tag{2} \]

機率分佈的具體表示式 之前提到的各種條件機率的具體表示式為:

\[\begin{aligned} q(\mathbf{x}_{t} | \mathbf{x}_{t-1}) &= \mathcal{N}(\mathbf{x}_t; \sqrt{\alpha_t} \mathbf{x}_{t-1}, (1 - \alpha_t) \mathbf{I}) \\ q(\mathbf{x}_{t} | \mathbf{x}_{0}) &= \mathcal{N}(\mathbf{x}_{t}; \sqrt{\overline{\alpha}_t} \mathbf{x}_{0}, (1 - \overline{\alpha}_t) \mathbf{I}) \\ q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right) &= \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right) , \tilde{\bm{\Sigma}} \left(t\right))\\ \end{aligned} \tag{3} \]

其中

\[\begin{aligned} \tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right) &= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{0} \\ \tilde{\bm{\Sigma}} \left(t\right) &= \frac{\left(1 - \alpha_t\right) \left( 1 - \overline{\alpha}_{t-1} \right)}{ 1 - \overline{\alpha}_{t} } \mathbf{I} = \sigma^2 \left(t\right) \mathbf{I}\\ \end{aligned} \]

另外,\(p\left(\mathbf{x}_{T}\right)\) 服從標準高斯分佈,\(p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\) 是我們要訓練的神經網路。

根據貝葉斯公式,我們要改造的條件機率如下:

\[\begin{aligned} q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}, \mathbf{x}_{0}\right) = \frac{\textcolor{blue}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)} q\left(\mathbf{x}_{t} | \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{0}\right)} \end{aligned} \tag{4} \]

證據下界 我們原本要對生成的影像分佈進行極大似然估計,但直接估計無法計算。於是我們改為最大化證據下界,然後對證據下界進行化簡,現在,我們採用 \(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)\) 重新最佳化證據下界:

\[\begin{aligned} \log p\left(\mathbf{x}_0\right) &\geq \mathcal{L} \\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{\prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)}\right] \\ \end{aligned} \tag{5} \]

3.5、利用 \(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0})\) 重新推導證據下界

  書接上回。我們化簡證據下界的一個想法是,我們希望將 \(p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\)\(q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)\) 的每一項一一對齊;並且將含有 \(\left(\mathbf{x}_{0}, \mathbf{x}_{1}\right)\) 的項與其他項分開來。因為 \(\mathbf{x}_{0}\) 是影像,而其他隨機變數是隱變數。還有一種解釋是,這次我們採用了 \(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0})\),而當 \(t = 1\) 時,\(q(\mathbf{x}_{0} | \mathbf{x}_{1}, \mathbf{x}_{0})\) 看起來好像是無意義的。所以我們要將含有 \(\left(\mathbf{x}_{0}, \mathbf{x}_{1}\right)\) 的項與其他項分開。

\[\begin{aligned} \mathcal{L} &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{\prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)}\right] \\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \prod_{t=2}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right) \prod_{t=2}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)}\right] \\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right)}{q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right)}\right] + \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{\prod_{t=2}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{\prod_{t=2}^T \textcolor{blue}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}, \mathbf{x}_{0}\right)}}\right]\\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right)}{q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right)}\right] + \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{\prod_{t=2}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{\prod_{t=2}^T \textcolor{blue}{\frac{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right) q\left(\mathbf{x}_{t} | \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{0}\right)}}}\right]\\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right)}{q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right)}\right] + \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \prod_{t=2}^{T} \frac{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{0}\right)}\right] + \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{\prod_{t=2}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{\prod_{t=2}^T \textcolor{blue}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)}}\right]\\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right)}{\textcolor{red}{q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right)}}\right] + \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{\textcolor{red}{q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right)}}{q\left(\mathbf{x}_{T} | \mathbf{x}_{0}\right)}\right] + \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{\prod_{t=2}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{\prod_{t=2}^T q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)}\right]\\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{0}\right)}\right] + \sum_{t=2}^{T} \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)}\right] \quad\quad 紅色消去,同時第三項乘法變加法\\ &= \textcolor{Skyblue}{\mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right)\right]} + \textcolor{Darkred}{\mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{0}\right)}\right]} + \textcolor{Darkgreen}{\sum_{t=2}^{T} \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)}\right]} \\ \end{aligned} \]

與之前一樣,上式三項也分別代表三部分:重建項先驗匹配項一致項

  • 重建項。顧名思義,這是對最後構圖的預測機率。給定預測的最後一個隱變數 \(\mathbf{x}_{1}\),預測生成的影像 \(\mathbf{x}_{0}\) 的對數機率。
  • 先驗匹配項。 這一項描述的是擴散過程的最後一步生成的高斯噪聲與純高斯噪聲的相似度,與之前相比,這一項的 \(q\) 部分的條件改為了 \(\mathbf{x}_{0}\)。同樣,這一項並沒有神經網路引數,所以不需要最佳化,後續網路訓練的時候可以將這一項捨去。
  • 一致項。這一項與之前有兩點不同。其一,與之前相比,不再有錯位比較。其二,這匹配目標改為了由 \(p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)\)\(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)\) 匹配,而之前是和擴散過程的單步轉移機率 \(q\left(\mathbf{x}_{t}|\mathbf{x}_{t-1}\right)\) 匹配。更加合理。

類似地,與之前的操作一樣,我們將上式的數學期望下角標中的無關的隨機變數約去(積分為1),然後轉化成KL散度的形式。我們看 先驗匹配項一致項

\[\begin{aligned} \textcolor{Darkred}{\mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{0}\right)}\right]} &= \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{0}\right)} d \mathbf{x}_{1:T}\\ &= \int q\left(\mathbf{x}_{T} | \mathbf{x}_{0}\right) \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{0}\right)} d \mathbf{x}_{T}\\ &= -\mathbb{D}_{\text{KL}} \left(q\left(\mathbf{x}_{T} | \mathbf{x}_{0}\right) \| p\left(\mathbf{x}_{T}\right)\right) \end{aligned} \]

\[\begin{aligned} \textcolor{Darkgreen}{\mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)}\right]} &= \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log \frac{p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)} d \mathbf{x}_{1:T}\\ &= \int q\left(\mathbf{x}_{t}, \mathbf{x}_{t-1} | \mathbf{x}_{0}\right) \log \frac{p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)} d \mathbf{x}_{t} d\mathbf{x}_{t-1}\\ &= \int q\left(\mathbf{x}_{t} | \mathbf{x}_{0}\right) q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right) \log \frac{p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)} d \mathbf{x}_{t} d\mathbf{x}_{t-1}\\ &= \int q\left(\mathbf{x}_{t} | \mathbf{x}_{0}\right) \mathbb{D}_{\text{KL}} \left(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\right) d \mathbf{x}_{t} \\ &= -\mathbb{E}_{q\left(\mathbf{x}_{t} | \mathbf{x}_{0}\right)} \left[ \mathbb{D}_{\text{KL}} \left(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\right)\right] \end{aligned} \]

重建項也類似,期望下角標的機率中,除了隨機變數 \(\mathbf{x}_1\) 之外都可以約掉。最後,我們終於得出證據下界的KL散度形式:

\[\begin{aligned} \mathcal{L} &= \textcolor{Skyblue}{\mathbb{E}_{q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right)} \left[ \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right)\right]} - \textcolor{Darkred}{\mathbb{D}_{\text{KL}} \left(q\left(\mathbf{x}_{T} | \mathbf{x}_{0}\right) \| p\left(\mathbf{x}_{T}\right)\right)} \\ &- \textcolor{Darkgreen}{\sum_{t=2}^{T} \mathbb{E}_{q\left(\mathbf{x}_{t} | \mathbf{x}_{0}\right)} \left[ \mathbb{D}_{\text{KL}} \left(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\right)\right]} \\ \end{aligned} \tag{6} \]

  下面聊聊數學期望的下角標的物理意義是。以重建項為例,下角標為 \(q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right)\),代表用 \(\mathbf{x}_{0}\) 加噪一步生成 \(\mathbf{x}_{1}\),然後用 \(\mathbf{x}_{1}\) 輸入到神經網路中得到估計的 \(\mathbf{x}_{0}\) 的分佈,然後最大化這個對數似然機率。而數學期望代表了多個圖片,一個 epoch 之後取平均作為期望。一致項也類似,只是用 \(\mathbf{x}_{0}\) 生成 \(\mathbf{x}_{t}\),然後透過神經網路計算與 \(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)\) 的KL散度。這實際上就是蒙特卡洛估計。

所以,我們需要計算loss的項有兩個,一個是重建項中的對數部分,一個是一致項中的KL散度。至於數學期望和下角標,我們並不需要展開計算,而是在訓練的時候用多個圖片並分別新增不同程度的噪聲來替代。

4、訓練過程

  下面我們利用 (3) 式對證據下界 (6) 式做進一步展開。從DDPM到DDIM(二) 這篇文章講過,在 \(\beta_t\) 很小的前提下,\(p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\) 也服從高斯分佈。因為 \(p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\) 的訓練目標是匹配 \(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)\),我們也寫成高斯分佈的形式,並與 \(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)\) 的形式做對比。

\[\begin{aligned} q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right) &= \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right) , \tilde{\bm{\Sigma}} \left(t\right))\\ &= \frac{1}{\sqrt{2 \pi} \sigma \left(t\right)} \exp \left[ -\frac{1}{2} \left(\mathbf{x}_{t-1} - \tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\right)^T \tilde{\bm{\Sigma}}^{-1} \left(t\right) \left(\mathbf{x}_{t-1} - \tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\right)\right]\\ p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right) &= \mathcal{N}(\mathbf{x}_{t-1}; \textcolor{blue}{\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right)} , \tilde{\bm{\Sigma}} \left(t\right))\\ &= \frac{1}{\sqrt{2 \pi} \sigma \left(t\right)} \exp \left[ -\frac{1}{2} \left(\mathbf{x}_{t-1} - \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)^T \tilde{\bm{\Sigma}}^{-1} \left(t\right) \left(\mathbf{x}_{t-1} - \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)\right]\\ \end{aligned} \]

這裡 \(p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\) 的均值 \(\textcolor{blue}{\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right)}\) 是神經網路輸出,方差我們採用和 \(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)\) 一樣的方差。神經網路 \(\textcolor{blue}{\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right)}\) 的輸入有兩個,第一個是 \(\mathbf{x}_{t}\),這是顯然的,還有一個輸入時時刻 \(t\),因為當然,方差也可以作為神經網路來訓練,但是DDPM原文中做過實驗,這樣效果並不顯著。因此,上述兩個均值兩個方差中,只有藍色的 \(\textcolor{blue}{\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right)}\) 是未知的,另外三個量都是已知量。

根據 (6) 式,我們只需要計算 重建項一致項,先驗匹配項沒有訓練引數。下面分別計算:

\[\begin{aligned} \textcolor{Skyblue}{ \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) } &= -\frac{1}{2 \sigma^2 \left(t\right)} \Vert \mathbf{x}_{0} - \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{1}, 1\right) \Vert_2^2 + \text{const}\\ \end{aligned} \]

其中 \(\text{const}\) 代表某個常數。

  下面計算一致項,即KL散度。高斯分佈的KL散度是有公式的,我們不加證明地給出,若需要證明,可以查閱維基百科。兩個 \(d\) 維隨機變數服從高斯分佈 \(Q = \mathcal{N}(\bm{\mu}_1, \bm{\Sigma}_1)\) , $P = \mathcal{N}(\bm{\mu}_2, \bm{\Sigma}_2) $,其中 \(\bm{\mu}_1, \bm{\mu}_2 \in \mathbb{R}^{d}, \bm{\Sigma}_1, \bm{\Sigma}_2 \in \mathbb{R}^{d \times d}\) 二者的Kullback-Leibler 散度(KL散度)可以用以下公式計算:

\[\begin{aligned} \mathbb{D}_{\text{KL}}(Q \| P) = \frac{1}{2} \left[\log \frac{\det \bm{\Sigma}_2}{\det \bm{\Sigma}_1} - d + \text{tr}(\bm{\Sigma}_2^{-1} \bm{\Sigma}_1) + (\bm{\mu}_2 - \bm{\mu}_1)^T \bm{\Sigma}_2^{-1} (\bm{\mu}_2 - \bm{\mu}_1)\right]\\ \end{aligned} \]

下面我們將一致項代入上述公式:

\[\begin{aligned} \textcolor{Darkgreen}{\mathbb{D}_{\text{KL}} \left(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\right)} &= \frac{1}{2} \left[\log 1 - d + d + \Vert\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) - \tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\Vert_2^2 / \sigma^2 \left(t\right)\right]\\ &= \frac{1}{2 \sigma^2 \left(t\right)} \Vert\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) - \tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\Vert_2^2\\ \end{aligned} \tag{7} \]

從上兩個式子可以看出,\(\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right)\)\(t > 0\) 的時候,目標是匹配 \(\tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\)。我們的研究哲學是,只要有解析形式,我們就將解析形式展開,直到某個變數沒有解析解,這時候才會用神經網路擬合,這樣可以最大化地保證擬合的效果。比如我們為了擬合一個二次函式 \(f(x) = a x^2 + 3 x + 2\),其中 \(a\) 是未知量,我們應該設計一個神經網路來估計 \(a\),而不應該用神經網路來估計 \(f(x)\),因為前者確保了神經網路估計出來的函式是二次函式,而後者則有更多的不確定性。

  為了更好地匹配,我們展開 \(\tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\) 中的解析形式。

\[\begin{aligned} \tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right) &= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{0} \\ \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) &= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right) \\ \end{aligned} \tag{8} \]

\(\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right)\) 展開的形式與 \(\tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\) 相同。第一項是與 \(\mathbf{x}_{t}\) 相關的,因為 \(\mathbf{x}_{t}\) 是輸入,所以保持不變,但是 \(\mathbf{x}_{0}\) 是未知量,所以我們還是用神經網路來替代,神經網路的輸入同樣也是 \(\mathbf{x}_{t}\)\(t\)。將 (8) 式代入 (7) 式,有:

\[\begin{aligned} \textcolor{Darkgreen}{\mathbb{D}_{\text{KL}} \left(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\right)} &= \frac{1}{2 \sigma^2 \left(t\right)} \Vert\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}\right) - \tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\Vert_2^2 \\ &= \frac{1}{2 \sigma^2 \left(t\right)} \frac{\left(1 - \alpha_t\right)^2 \overline{\alpha}_{t-1}}{\left( 1 - \overline{\alpha}_{t} \right)^2} \Vert\mathbf{x}_{0} - \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right)\Vert_2^2 \\ \end{aligned} \]

重建項也可以繼續化簡,注意到 \(\beta_0 = 0, \alpha_0 = 1, \overline{\alpha}_{0} = 1, \overline{\alpha}_{1} = \alpha_1\)

\[\begin{aligned} \textcolor{Skyblue}{ \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) } &= -\frac{1}{2 \sigma^2 \left(t\right)} \Vert \mathbf{x}_{0} - \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{1}, 1\right) \Vert_2^2 + \text{const}\\ &= -\frac{1}{2 \sigma^2 \left(t\right)} \Vert \mathbf{x}_{0} - \frac{\left( 1 - \overline{\alpha}_{0} \right) \sqrt{\alpha_1}}{\left( 1 - \overline{\alpha}_{1} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_1\right) \sqrt{\overline{\alpha}_{0}}}{\left( 1 - \overline{\alpha}_{1} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{1}, t\right) \Vert_2^2 + \text{const}\\ &= -\frac{1}{2 \sigma^2 \left(t\right)} \Vert \mathbf{x}_{0} - \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{1}, t\right) \Vert_2^2 + \text{const}\\ &= -\frac{1}{2 \sigma^2 \left(t\right)} \frac{\left(1 - \alpha_1\right)^2 \overline{\alpha}_{0}}{\left( 1 - \overline{\alpha}_{1} \right)^2} \Vert\mathbf{x}_{0} - \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right)\Vert_2^2 + \text{const} \\ \end{aligned} \]

上式最後一行是為了與KL散度的形式保持一致。經過這麼長時間的努力,我們終於將證據下界化為最簡形式。我們把我們計算出的重建項和一致項代入到 (6) 式,並捨棄和神經網路引數無關的先驗匹配項,有:

\[\begin{aligned} \mathcal{L} &= - \sum_{t=1}^{T} \mathbb{E}_{q\left(\mathbf{x}_{t} | \mathbf{x}_{0}\right)} \left[ \frac{1}{2 \sigma^2 \left(t\right)} \frac{\left(1 - \alpha_t\right)^2 \overline{\alpha}_{t-1}}{\left( 1 - \overline{\alpha}_{t} \right)^2} \Vert\mathbf{x}_{0} - \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right)\Vert_2^2 \right] \\ \end{aligned} \tag{9} \]

因為前面有個負號,所以最大化證據下界等價於最小化以下損失函式:

\[\textcolor{blue}{\boldsymbol{\theta}^*=\underset{\boldsymbol{\theta}}{\operatorname{argmin}} \sum_{t=1}^T \frac{1}{2 \sigma^2(t)} \frac{\left(1-\alpha_t\right)^2 \overline{\alpha}_{t-1}}{\left(1-\overline{\alpha}_t\right)^2} \mathbb{E}_{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\left[\Vert\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\Vert_2^2\right]} \]

理解上式也很簡單。首先我們看每一項的權重 \(\frac{1}{2 \sigma^2(t)} \frac{\left(1-\alpha_t\right)^2 \overline{\alpha}_{t-1}}{\left(1-\overline{\alpha}_t\right)^2}\),這表示了馬爾可夫鏈每一個階段預測損失的權重,DDPM論文的實驗證明,忽略此權重影響不大,所以我們繼續簡化為:

\[\textcolor{blue}{\boldsymbol{\theta}^*=\underset{\boldsymbol{\theta}}{\operatorname{argmin}} \sum_{t=1}^T \mathbb{E}_{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\left[\Vert\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\Vert_2^2\right]} \]

  其實現方式就是給你一張影像 \(\mathbf{x}_0\),然後分別按照不同的步驟加噪,最多加到 \(T\) 步噪聲,得到 \(\mathbf{x}_1, \mathbf{x}_2, ..., \mathbf{x}_T\) 個隱變數。如下圖所示,由於多步轉移機率的性質,我們可以從 \(\mathbf{x}_0\) 一步加噪到任意一個噪聲階段。

img

  然後將這些隱變數分別送入神經網路,輸出與 \(\mathbf{x}_0\) 計算二範數Loss,然後所有的Loss取平均。然而,實際實現的時候,我們不僅僅只有一張圖,而是有很多張圖。送入神經網路的時候也是以一個 batch 的形式處理的,如果每張圖片都加這麼多次噪聲,那訓練的工作量就會非常巨大。所以實際上我們採用這樣的方式:假設一個batch中有 \(N\) 張圖片,對於這 \(N\) 張圖片分別新增不同階段的高斯噪聲,影像新增噪聲的程度也是隨機的,比如第一張影像加噪 \(10\) 步,第二張影像加噪 \(910\) 步,等等。然後分別輸入加噪後的隱變數和時刻資訊,神經網路的輸出與每一張原始影像分別做二範數loss,最後平均。這樣相比於只給一張影像加 \(1000\) 種不同的噪聲的優勢是防止在一張影像上過擬合,然後陷入區域性極小。下面我們給出具體的訓練演算法流程:


Algorithm 1 . Training a Deniosing Diffusion Probabilistic Model. (Version: Predict image)

For every image \(\mathbf{x}_0\) in your training dataset:

  • Repeat the following steps until convergence.
  • Pick a random time stamp \(t \sim \text{Uniform}[1, T]\).
  • Draw a sample \(\mathbf{x}_{t} \sim q\left(\mathbf{x}_{t} | \mathbf{x}_{t}\right)\), i.e.

\[\mathbf{x}_{t} = \sqrt{\overline{\alpha}_t} \mathbf{x}_{0} + \sqrt{1 - \overline{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \]

  • Take gradient descent step on

\[\nabla_{\boldsymbol{\theta}} \Vert \tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0 \Vert_2^2 \]

You can do this in batches, just like how you train any other neural networks. Note that, here, you are training one denoising network \(\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\) for all noisy conditions.


採用batch來訓練的話,就對每個圖片分別同時進行上述操作,值得注意的是,神經網路引數只有一個,無論是哪一個 \(t\) 步去噪,其不同只有輸入的不同,而神經網路只有 \(\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\) 一個。訓練示意圖如下:

img

  說句題外話,其實DDPM的原文很具有誤導性,如下圖的DDPM的原圖。從這張圖上看,或許有些同學以為神經網路是輸入 \(\mathbf{x}_{t}\) 來預測 \(\mathbf{x}_{t-1}\),實際上並非如此。是輸入 \(\mathbf{x}_{t}\) 來預測 \(\mathbf{x}_{0}\),原因就是我們採用 \(q\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0}\right)\) 來作為擬合目標,目標是匹配其均值 \(\tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\),而不是匹配 \(\mathbf{x}_{t-1}\)。而 \(\tilde{\bm{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\) 恰好是 \(\mathbf{x}_{0}\) 的函式,所以我們在訓練上的時候實際上是輸入 \(\mathbf{x}_{t}\) 用神經網路來預測 \(\mathbf{x}_{0}\)。而取樣過程才是一步一步取樣的。正因為訓練時候神經網路擬合的物件並不是 \(\mathbf{x}_{t-1}\),所以就給了我們在取樣過程中的加速的空間,這就是後話了。

img

5、推理過程

  大家先別翻論文,你覺得最簡單的一個生成影像的想法是什麼。我當時就想過,既然神經網路 \(\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\) 是輸入 \(\mathbf{x}_{t}\) 來預測 \(\mathbf{x}_{0}\),那麼我們直接給他一個隨機噪聲,一步生成影像不行嗎?這個問題存疑,因為最新的研究確實有單步影像生成的,不過筆者還沒有精讀,就暫不評價。

  按照馬爾可夫性質,還是用 \(p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}\right)\) 一步一步做蒙特卡洛生成:

\[\begin{aligned} \mathbf{x}_{t-1} &\sim p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}\right) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) , \sigma^2 \left(t\right) \mathbf{I}) \\ \mathbf{x}_{t-1} &= \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) + \sigma^2 \left(t\right) \bm{\epsilon} \\ &= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right) + \sigma^2 \left(t\right) \bm{\epsilon} \end{aligned} \tag{9} \]

其中 \(\sigma^2 \left(t\right) = \frac{\left(1 - \alpha_t\right) \left( 1 - \overline{\alpha}_{t-1} \right)}{ 1 - \overline{\alpha}_{t} }\)

  擴散模型給我的感覺就是,訓練過程和推理過程的差別很大。或許這就是生成模型,訓練演算法和推理演算法的形式有很大的區別,包括文字的自迴歸生成也是如此。他不像影像分類,推理的時候跟訓練時是一樣的計算方式,只是最後來一個取機率最大的類別就行。訓練過程和推理過程的極大差異決定了此推理形式不是唯一的形式,一定有更優的推理演算法。

這個推理過程由如下演算法描述。


Algorithm 2. Inference on a Deniosing Diffusion Probabilistic Model. (Version: Predict image)

Input: the trained model \(\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\).

  • You give us a white noise vector \(\mathbf{x}_{T} \sim \mathcal{N}\left(\mathbf{0}, \mathbf{I}\right)\)
  • Repeat the following for \(t = T, T − 1, ... , 1\).
  • Update according to

\[\mathbf{x}_{t-1} = \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right) + \sigma^2 \left(t\right) \bm{\epsilon}, \quad \bm{\epsilon} \sim \mathcal{N}\left(\mathbf{0}, \mathbf{I}\right) \]

Output: \(\mathbf{x}_{0}\).


img

  • 推理輸出的 \(\mathbf{x}_{0}\) 還需要進行去歸一化和離散化到 0 到 255 之間,這個我們留到下一篇文章講。
  • 另外,在DDPM原文中,並沒有直接預測 \(\mathbf{x}_{0}\),而是對 \(\mathbf{x}_{0}\) 進行了重引數化,讓神經網路預測噪聲 \(\bm{\epsilon}\),這是怎麼做的呢,我們也留到下一篇文章講。

下一篇文章 《從DDPM到DDIM(四) 預測噪聲與生圖後處理》

相關文章