從DDPM到DDIM

txdt發表於2024-07-17

從DDPM到DDIM (一)

現在網路上關於DDPM和DDIM的講解有很多,但無論什麼樣的講解,都不如自己推到一邊來的痛快。筆者希望就這篇文章,從頭到尾對擴散模型做一次完整的推導。

DDPM是一個雙向馬爾可夫模型,其分為擴散過程和取樣過程。

擴散過程是對於圖片不斷加噪的過程,每一步新增少量的高斯噪聲,直到影像完全變為純高斯噪聲。為什麼逐步新增小的高斯噪聲,而不是一步到位,直接新增很強的噪聲呢?這一點我們留到之後來探討。

取樣過程則相反,是對純高斯噪聲影像不斷去噪,逐步恢復原始影像的過程。

下圖展示了DDPM原文中的馬爾可夫模型。
img

其中\(\mathbf{x}_T\)代表純高斯噪聲,\(\mathbf{x}_t, 0 < t < T\) 代表中間的隱變數, \(\mathbf{x}_0\) 代表生成的影像。從 \(\mathbf{x}_0\) 逐步加噪到 \(\mathbf{x}_T\) 的過程是不需要神經網路引數的,簡單地講高斯噪聲和影像或者隱變數進行線性組合即可,單步加噪過程用\(q(\mathbf{x}_t | \mathbf{x}_{t-1})\)來表示。但是去噪的過程,我們是不知道的,這裡的單步去噪過程,我們用 \(p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_{t})\) 來表示。之所以這裡增加一個 \(\theta\) 下標,是因為 \(p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_{t})\) 是用神經網路來逼近的轉移機率, \(\theta\) 代表神經網路引數。

擴散模型首先需要大量的圖片進行訓練,訓練的目標就是估計影像的機率分佈。訓練完畢後,生成影像的過程就是在計算出的機率分佈中取樣。因此生成模型一般都有訓練演算法和取樣演算法,VAE、GAN、diffusion,還有如今大火的大預言模型(LLM)都不例外。本文討論的DDPM和DDIM在訓練方法上是一樣的,只是DDIM在取樣方法上與前者有所不同。

而訓練演算法的最經典的方法就是極大似然估計,我們從極大似然估計開始。

1、從極大似然估計開始

首先簡單回顧一下機率論中的機率論中的一些基本概念。

1.1、概念回顧

邊緣機率密度和聯合機率密度: 大家可能還記得機率論中的邊緣機率密度,忘了也不要緊,我們簡單回顧一下。對於二維隨機變數\((X, Y)\),其聯合機率密度函式是\(f(x, y)\),那麼我不管\(Y\),單看\(X\)的機率密度,就是\(X\)的邊緣機率密度,其計算方式如下:

\[\begin{aligned} f_{X}(t) = \int_{-\infty}^{\infty} f(x, y) d y \\ \end{aligned} \\ \]

機率乘法公式: 對於聯合機率\(P(A_1 A_2 ... A_{n})\),若\(P(A_1 A_2 ... A_{n-1}) 0\),則:

\[\begin{aligned} P(A_1 A_2 ... A_{n}) &= P(A_1 A_2 ... A_{n-1}) P(A_n | A_1 A_2 ... A_{n-1}) \\ &= P(A_1) P(A_2 | A_1) P(A_3 | A_1 A_2) ... P(A_n | A_1 A_2 ... A_{n-1}) \end{aligned} \\ \]

機率乘法公式可以用條件機率的定義和數學歸納法證明。

馬爾可夫鏈定義: 隨機過程 \(\left\{X_n, n = 0,1,2,...\right\}\)稱為馬爾可夫鏈,若隨機過程在某一時刻的隨機變數 \(X_n\) 只取有限或可列個值(比如非負整數集,若不另外說明,以集合 \(\mathcal{S}\) 來表示),並且對於任意的 \(n \geq 0\) ,及任意狀態 \(i, j, i_0, i_1, ..., i_{n-1} \in \mathcal{S}\),有

\[\begin{aligned} P(X_{n+1} = j | X_{0} = i_{0}, X_{1} = i_{1}, ... X_{n} = i) = P(X_{n+1} = j | X_{n} = i) \\ \end{aligned} \\ \]

其中 \(X_n = i\) 表示過程在時刻 \(n\) 處於狀態 \(i\)。稱 \(\mathcal{S}\) 為該過程的狀態空間。上式刻畫了馬爾可夫鏈的特性,稱為馬爾可夫性。

1.2、機率分佈表示

  生成模型的主要目標是估計需要生成的資料的機率分佈。這裡就是\(p\left(\mathbf{x}_0\right)\),如何估計\(p\left(\mathbf{x}_0\right)\)呢。一個比較直接的想法就是把\(p\left(\mathbf{x}_0\right)\)當作整個馬爾可夫模型的邊緣機率:

\[\begin{aligned} p\left(\mathbf{x}_0\right) = \int p\left(\mathbf{x}_{0:T}\right) d \mathbf{x}_{1:T} \\ \end{aligned} \\ \]

這裡\(p\left(\mathbf{x}_{0:T}\right)\)表示\(\mathbf{x}_{0}, \mathbf{x}_{1}, ..., \mathbf{x}_{T}\) 多個隨機變數的聯合機率分佈。\(d \mathbf{x}_{1:T}\) 表示對\(\mathbf{x}_{1}, \mathbf{x}_{2}, ..., \mathbf{x}_{T}\)\(T\) 個隨機變數求多重積分。

  顯然,這個積分很不好求。Sohl-Dickstein等人在2015年的擴散模型的開山之作[1]中,採用的是這個方法:

\[\begin{aligned} p\left(\mathbf{x}_0\right) &= \int p\left(\mathbf{x}_{0:T}\right) \textcolor{blue}{\frac{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} d \mathbf{x}_{1:T} \quad\quad 積分內部乘1\\ &= \int \textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \frac{p\left(\mathbf{x}_{0:T}\right)}{\textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} d \mathbf{x}_{1:T} \\ &= \mathbb{E}_{\textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} \left[\frac{p\left(\mathbf{x}_{0:T}\right)}{\textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}}\right] \quad\quad隨機變數函式的期望\\ \end{aligned} \tag{1} \]

  Sohl-Dickstein等人借鑑的是統計物理中的技巧:退火重要性取樣(annealed importance sampling) 和 Jarzynski equality。這兩個就涉及到筆者的知識盲區了,感興趣的同學可以自行找相關資料學習。(果然數學物理基礎不牢就搞不好科研~)。

  這裡有的同學可能會有疑問,為什麼用分子分母都為 \(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 的因子乘進去?這裡筆者嘗試給出另一種解釋,就是我們在求邊緣分佈的時候,可以嘗試將聯合機率分佈拆開,然後想辦法乘一個已知的並且與其類似的項,然後將這些項分別放在分子與分母的位置,讓他們分別進行比較。因為這是KL散度的形式,而KL散度是比較好算的。\(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 的好處就是也可以按照貝葉斯公式和馬爾可夫性質拆解成多個條件機率的連乘積,這些條件機率與 \(p\left(\mathbf{x}_{0:T}\right)\) 拆解之後的條件機率幾乎可以一一對應,而且每個條件機率表示的都是擴散過程的單步轉移機率,這我們都是知道的。那麼為什麼不用 \(q\left(\mathbf{x}_{0:T}\right)\) 呢?其實 \(p\)\(q\) 本質上是一種符號,\(q\left(\mathbf{x}_{0:T}\right)\)\(p\left(\mathbf{x}_{0:T}\right)\) 其實表示的是一個東西。

  這裡自然就引出了問題,這麼一堆隨機變數的聯合機率密度,我們還是不知道啊,\(p\left(\mathbf{x}_{0:T}\right)\)\(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 如何該表示?

  利用機率乘法公式,有:

\[\begin{aligned} p\left(\mathbf{x}_{0:T}\right) &= p\left(\mathbf{x}_{T}\right) p\left(\mathbf{x}_{T-1}|\mathbf{x}_{T}\right) p\left(\mathbf{x}_{T-2}|\mathbf{x}_{T-1},\mathbf{x}_{T}\right) ... p\left(\mathbf{x}_{0}|\mathbf{x}_{1:T}\right)\\ \end{aligned} \tag{2} \]

我們這裡是單獨把 \(p\left(\mathbf{x}_{T}\right)\),單獨提出來,這是因為 \(\mathbf{x}_{T}\) 服從高斯分佈,這是我們知道的分佈;如果反方向的來表示,這麼表示的話:

\[\begin{aligned} p\left(\mathbf{x}_{0:T}\right) &= p\left(\mathbf{x}_{0}\right) p\left(\mathbf{x}_{1}|\mathbf{x}_{0}\right) p\left(\mathbf{x}_{2}|\mathbf{x}_{1},\mathbf{x}_{0}\right) ... p\left(\mathbf{x}_{T}|\mathbf{x}_{0:T-1}\right)\\ \end{aligned} \tag{3} \]

(3)式這樣表示明顯不如(2)式,因為我們最初就是要求 \(p\left(\mathbf{x}_{0}\right)\) ,而計算(3)式則需要知道 \(p\left(\mathbf{x}_{0}\right)\),這樣就陷入了死迴圈。因此學術界採用(2)式來對聯合機率進行拆解。

因為擴散模型是馬爾可夫鏈,某一時刻的隨機變數只和前一個時刻有關,所以:

\[\begin{aligned} p\left(\mathbf{x}_{t-1}|\mathbf{x}_{\leq t}\right) = p\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\\ \end{aligned} \\ \]

於是有:

\[\begin{aligned} p\left(\mathbf{x}_{0:T}\right) = p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\\ \end{aligned} \\ \]

文章一開始說到,在擴散模型的取樣過程中,單步轉移機率是不知道的,需要用神經網路來擬合,所以我們給取樣過程的單步轉移機率都加一個下標 \(\theta\),這樣就得到了最終的聯合機率:

\[\begin{aligned} \textcolor{blue}{p\left(\mathbf{x}_{0:T}\right) = p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)} \end{aligned} \tag{4} \]

類似地,我們來計算 \(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 的拆解表示:

\[\begin{aligned} q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) &= q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right) q\left(\mathbf{x}_{2} | \mathbf{x}_{0:1}\right) ... q\left(\mathbf{x}_{T} | \mathbf{x}_{0:T-1}\right) \quad\quad 機率乘法公式\\ &= \prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right) \quad\quad 馬爾可夫性質\\ \end{aligned} \\ \]

於是得到了以\(\mathbf{x}_0\) 為條件的擴散過程的聯合機率分佈:

\[\begin{aligned} \textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) = \prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \\ \end{aligned} \tag{5} \]


  1. Sohl-Dickstein J, Weiss E, Maheswaranathan N, et al. Deep unsupervised learning using nonequilibrium thermodynamics[C]//International conference on machine learning. PMLR, 2015: 2256-2265. ↩︎

相關文章