從DDPM到DDIM (一) 極大似然估計與證據下界

txdt發表於2024-07-23

從DDPM到DDIM (一) 極大似然估計與證據下界

  現在網路上關於DDPM和DDIM的講解有很多,但無論什麼樣的講解,都不如自己推到一遍來的痛快。筆者希望就這篇文章,從頭到尾對擴散模型做一次完整的推導。本文的很多部分都參考了 Calvin Luo[1] 和 Stanley Chan[2] 寫的經典教程。也推薦大家取閱讀學習。

  DDPM[3]是一個雙向馬爾可夫模型,其分為擴散過程和取樣過程。

  擴散過程是對於圖片不斷加噪的過程,每一步新增少量的高斯噪聲,直到影像完全變為純高斯噪聲。為什麼逐步新增小的高斯噪聲,而不是一步到位,直接新增很強的噪聲呢?這一點我們留到之後來探討。

  取樣過程則相反,是對純高斯噪聲影像不斷去噪,逐步恢復原始影像的過程。

下圖展示了DDPM原文中的馬爾可夫模型。
img

其中\(\mathbf{x}_T\)代表純高斯噪聲,\(\mathbf{x}_t, 0 < t < T\) 代表中間的隱變數, \(\mathbf{x}_0\) 代表生成的影像。從 \(\mathbf{x}_0\) 逐步加噪到 \(\mathbf{x}_T\) 的過程是不需要神經網路引數的,簡單地講高斯噪聲和影像或者隱變數進行線性組合即可,單步加噪過程用\(q(\mathbf{x}_t | \mathbf{x}_{t-1})\)來表示。但是去噪的過程,我們是不知道的,這裡的單步去噪過程,我們用 \(p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_{t})\) 來表示。之所以這裡增加一個 \(\theta\) 下標,是因為 \(p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_{t})\) 是用神經網路來逼近的轉移機率, \(\theta\) 代表神經網路引數。

  擴散模型首先需要大量的圖片進行訓練,訓練的目標就是估計影像的機率分佈。訓練完畢後,生成影像的過程就是在計算出的機率分佈中取樣。因此生成模型一般都有訓練演算法和取樣演算法,VAE、GAN、diffusion,還有如今大火的大預言模型(LLM)都不例外。本文討論的DDPM和DDIM在訓練方法上是一樣的,只是DDIM在取樣方法上與前者有所不同[4]

  估計生成樣本的機率分佈的最經典的方法就是極大似然估計,我們從極大似然估計開始。

1、從極大似然估計開始

  首先簡單回顧一下機率論中的一些基本概念,邊緣機率密度、聯合機率密度、機率乘法公式和馬爾可夫鏈,最後回顧一個強大的數學工具:Jenson 不等式。對這些熟悉的同學可以不需要看1.1節。

1.1、概念回顧

邊緣機率密度和聯合機率密度: 大家可能還記得機率論中的邊緣機率密度,忘了也不要緊,我們簡單回顧一下。對於二維隨機變數\((X, Y)\),其聯合機率密度函式是\(f(x, y)\),那麼我不管\(Y\),單看\(X\)的機率密度,就是\(X\)的邊緣機率密度,其計算方式如下:

\[\begin{aligned} f_{X}(t) = \int_{-\infty}^{\infty} f(x, y) d y \\ \end{aligned} \\ \]

機率乘法公式: 對於聯合機率\(P(A_1 A_2 ... A_{n})\),若\(P(A_1 A_2 ... A_{n-1}) 0\),則:

\[\begin{aligned} P(A_1 A_2 ... A_{n}) &= P(A_1 A_2 ... A_{n-1}) P(A_n | A_1 A_2 ... A_{n-1}) \\ &= P(A_1) P(A_2 | A_1) P(A_3 | A_1 A_2) ... P(A_n | A_1 A_2 ... A_{n-1}) \end{aligned} \\ \]

機率乘法公式可以用條件機率的定義和數學歸納法證明。

馬爾可夫鏈定義: 隨機過程 \(\left\{X_n, n = 0,1,2,...\right\}\)稱為馬爾可夫鏈,若隨機過程在某一時刻的隨機變數 \(X_n\) 只取有限或可列個值(比如非負整數集,若不另外說明,以集合 \(\mathcal{S}\) 來表示),並且對於任意的 \(n \geq 0\) ,及任意狀態 \(i, j, i_0, i_1, ..., i_{n-1} \in \mathcal{S}\),有

\[\begin{aligned} P(X_{n+1} = j | X_{0} = i_{0}, X_{1} = i_{1}, ... X_{n} = i) = P(X_{n+1} = j | X_{n} = i) \\ \end{aligned} \\ \]

其中 \(X_n = i\) 表示過程在時刻 \(n\) 處於狀態 \(i\)。稱 \(\mathcal{S}\) 為該過程的狀態空間。上式刻畫了馬爾可夫鏈的特性,稱為馬爾可夫性。

Jenson 不等式。Jenson 不等式有多種形式,我們這裡採用其積分形式:

\(f(x)\) 為凸函式,另一個函式 \(q(x)\) 滿足:

\[\int_{-\infty}^{\infty} q(x) d x = 1 \\ \]

則有:

\[f \left[ \int_{-\infty}^{\infty} q(x) x d x \right] \leq \int_{-\infty}^{\infty} q(x) f(x) d x \\ \]

更進一步地,若 \(x\) 為隨機變數 \(X\) 的取值,\(q(x)\) 為隨機變數 \(X\) 的機率密度函式,則 Jenson不等式為:

\[f \left[ \mathbb{E}(x) \right] \leq \mathbb{E}[f(x)] \\ \]

關於 Jenson 不等式的證明,用凸函式的定義證明即可。網上有很多,這裡不再贅述。

1.2、機率分佈表示

  生成模型的主要目標是估計需要生成的資料的機率分佈。這裡就是\(p\left(\mathbf{x}_0\right)\),如何估計\(p\left(\mathbf{x}_0\right)\)呢。一個比較直接的想法就是把\(p\left(\mathbf{x}_0\right)\)當作整個馬爾可夫模型的邊緣機率:

\[\begin{aligned} p\left(\mathbf{x}_0\right) = \int p\left(\mathbf{x}_{0:T}\right) d \mathbf{x}_{1:T} \\ \end{aligned} \\ \]

這裡\(p\left(\mathbf{x}_{0:T}\right)\)表示\(\mathbf{x}_{0}, \mathbf{x}_{1}, ..., \mathbf{x}_{T}\) 多個隨機變數的聯合機率分佈。\(d \mathbf{x}_{1:T}\) 表示對\(\mathbf{x}_{1}, \mathbf{x}_{2}, ..., \mathbf{x}_{T}\)\(T\) 個隨機變數求多重積分。

  顯然,這個積分很不好求。Sohl-Dickstein等人在2015年的擴散模型的開山之作[5]中,採用的是這個方法:

\[\begin{aligned} p\left(\mathbf{x}_0\right) &= \int p\left(\mathbf{x}_{0:T}\right) \textcolor{blue}{\frac{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} d \mathbf{x}_{1:T} \quad\quad 積分內部乘1\\ &= \int \textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \frac{p\left(\mathbf{x}_{0:T}\right)}{\textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} d \mathbf{x}_{1:T} \\ &= \mathbb{E}_{\textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} \left[\frac{p\left(\mathbf{x}_{0:T}\right)}{\textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}}\right] \quad\quad隨機變數函式的期望\\ \end{aligned} \tag{1} \]

  Sohl-Dickstein等人借鑑的是統計物理中的技巧:退火重要性取樣(annealed importance sampling) 和 Jarzynski equality。這兩個就涉及到筆者的知識盲區了,感興趣的同學可以自行找相關資料學習。(果然數學物理基礎不牢就搞不好科研~)。

  這裡有的同學可能會有疑問,為什麼用分子分母都為 \(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 的因子乘進去?這裡筆者嘗試給出另一種解釋,就是我們在求邊緣分佈的時候,可以嘗試將聯合機率分佈拆開,然後想辦法乘一個已知的並且與其類似的項,然後將這些項分別放在分子與分母的位置,讓他們分別進行比較。因為這是KL散度的形式,而KL散度是比較好算的。\(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 的好處就是也可以按照貝葉斯公式和馬爾可夫性質拆解成多個條件機率的連乘積,這些條件機率與 \(p\left(\mathbf{x}_{0:T}\right)\) 拆解之後的條件機率幾乎可以一一對應,而且每個條件機率表示的都是擴散過程的單步轉移機率,這我們都是知道的。那麼為什麼不用 \(q\left(\mathbf{x}_{0:T}\right)\) 呢?其實 \(p\)\(q\) 本質上是一種符號,\(q\left(\mathbf{x}_{0:T}\right)\)\(p\left(\mathbf{x}_{0:T}\right)\) 其實表示的是一個東西。

  這裡自然就引出了問題,這麼一堆隨機變數的聯合機率密度,我們還是不知道啊,\(p\left(\mathbf{x}_{0:T}\right)\)\(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 如何該表示?

  利用機率乘法公式,有:

\[\begin{aligned} p\left(\mathbf{x}_{0:T}\right) &= p\left(\mathbf{x}_{T}\right) p\left(\mathbf{x}_{T-1}|\mathbf{x}_{T}\right) p\left(\mathbf{x}_{T-2}|\mathbf{x}_{T-1},\mathbf{x}_{T}\right) ... p\left(\mathbf{x}_{0}|\mathbf{x}_{1:T}\right)\\ \end{aligned} \tag{2} \]

我們這裡是單獨把 \(p\left(\mathbf{x}_{T}\right)\),單獨提出來,這是因為 \(\mathbf{x}_{T}\) 服從高斯分佈,這是我們知道的分佈;如果反方向的來表示,這麼表示的話:

\[\begin{aligned} p\left(\mathbf{x}_{0:T}\right) &= p\left(\mathbf{x}_{0}\right) p\left(\mathbf{x}_{1}|\mathbf{x}_{0}\right) p\left(\mathbf{x}_{2}|\mathbf{x}_{1},\mathbf{x}_{0}\right) ... p\left(\mathbf{x}_{T}|\mathbf{x}_{0:T-1}\right)\\ \end{aligned} \tag{3} \]

(3)式這樣表示明顯不如(2)式,因為我們最初就是要求 \(p\left(\mathbf{x}_{0}\right)\) ,而計算(3)式則需要知道 \(p\left(\mathbf{x}_{0}\right)\),這樣就陷入了死迴圈。因此學術界採用(2)式來對聯合機率進行拆解。

因為擴散模型是馬爾可夫鏈,某一時刻的隨機變數只和前一個時刻有關,所以:

\[\begin{aligned} p\left(\mathbf{x}_{t-1}|\mathbf{x}_{\leq t}\right) = p\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\\ \end{aligned} \\ \]

於是有:

\[\begin{aligned} p\left(\mathbf{x}_{0:T}\right) = p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)\\ \end{aligned} \\ \]

文章一開始說到,在擴散模型的取樣過程中,單步轉移機率是不知道的,需要用神經網路來擬合,所以我們給取樣過程的單步轉移機率都加一個下標 \(\theta\),這樣就得到了最終的聯合機率:

\[\begin{aligned} \textcolor{blue}{p\left(\mathbf{x}_{0:T}\right) = p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)} \end{aligned} \tag{4} \]

類似地,我們來計算 \(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 的拆解表示:

\[\begin{aligned} q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) &= q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right) q\left(\mathbf{x}_{2} | \mathbf{x}_{0:1}\right) ... q\left(\mathbf{x}_{T} | \mathbf{x}_{0:T-1}\right) \quad\quad 機率乘法公式\\ &= \prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right) \quad\quad 馬爾可夫性質\\ \end{aligned} \\ \]

於是得到了以\(\mathbf{x}_0\) 為條件的擴散過程的聯合機率分佈:

\[\begin{aligned} \textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) = \prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \\ \end{aligned} \tag{5} \]

數學推導的一個很重要的事情就是分清楚哪些是已知量,哪些是未知量。\(q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)\) 是已知的,因為根據擴散模型的定義,\(q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)\) 是服從高斯分佈的;另外 \(p\left(\mathbf{x}_{T} \right)\) 也是已知的,因為 \(\mathbf{x}_{T}\) 代表的是純高斯噪聲,因此 \(p\left(\mathbf{x}_{T} \right)\) 也是服從高斯分佈的。

1.3、極大似然估計

  既然我們知道了\(p\left(\mathbf{x}_{0:T}\right)\)\(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 的表示式,我們可以繼續對 (1) 式進行化簡了。首先我們要對(1)式進行縮放,(1)式我們難以計算,我們可以計算(1)式的下界,然後極大化它的下界就相當於極大化(1)式了。這確實是一種巧妙的方法,這個方法在VAE推導的時候就已經用過了,大家不懂VAE也沒關係,我們在這裡重新推導一遍。

  計算(1)式的下界一般有兩種辦法。分別是Jenson不等式法KL散度方法。下面我們分別給出兩種方法的推導。

Jenson不等式法。在進行極大似然估計的時候,一般會對機率分佈取對數,於是我們對(1)式取對數可得:

\[\begin{aligned} \log p\left(\mathbf{x}_0\right) &= \log \int p\left(\mathbf{x}_{0:T}\right) {\frac{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} d \mathbf{x}_{1:T}\\ &= \log \int {q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \frac{p\left(\mathbf{x}_{0:T}\right)}{{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} d \mathbf{x}_{1:T} \\ &= \log \left\{\mathbb{E}_{{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} \left[\frac{p\left(\mathbf{x}_{0:T}\right)}{{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}}\right] \right\}\\ &\geq \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{0:T}\right)}{{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}}\right] \quad\quad \text{Jenson不等式,log()為concave函式}\\ &= \mathcal{L} \end{aligned} \]

\(\mathcal{L} = \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{0:T}\right)}{{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}}\right]\) 被稱為隨機變數 \(\mathbf{x}_0\)證據下界 (Evidence Lower BOund, ELBO)。

KL散度方法。當然,我們也可以不採用Jenson不等式,利用KL散度的非負性,同樣也可以得出證據下界。將證據下界中的數學期望展開,寫為積分形式為:

\[\begin{aligned} \mathcal{L} = \int {q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \log \frac{p\left(\mathbf{x}_{0:T}\right)}{{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}} d \mathbf{x}_{1:T}\\ \end{aligned} \]

另外,我們定義一個KL散度:

\[\begin{aligned} \text{KL}\left(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) || p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\right) = \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log \frac{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}{p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} d \mathbf{x}_{1:T}\\ \end{aligned} \]

下面我們將驗證:

\[\begin{aligned} \log p\left(\mathbf{x}_0\right) &= \mathcal{L} + \text{KL}\left(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) || p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\right)\\ \end{aligned} \tag{6} \]

具體地,有:

\[\begin{aligned} \mathcal{L} + \text{KL}\left(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) || p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\right) &= \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log \frac{p\left(\mathbf{x}_{0:T}\right)}{p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} d \mathbf{x}_{1:T}\\ &= \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log \frac{p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) p\left(\mathbf{x}_{0}\right)}{p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} d \mathbf{x}_{1:T}\\ &= \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log p\left(\mathbf{x}_{0}\right) d \mathbf{x}_{1:T}\\ &= \log p\left(\mathbf{x}_{0}\right) \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) d \mathbf{x}_{1:T} \quad\quad 機率密度積分為1\\ &= \log p\left(\mathbf{x}_{0}\right)\\ \end{aligned} \]

因此,(6)式成立。由於\(\text{KL}\left(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) || p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\right) \geq 0\),所以:

\[\log p\left(\mathbf{x}_{0}\right) \geq \mathcal{L} \]

  個人還是更喜歡Jenson不等式法,因為此方法的思路一氣呵成;而KL散度法像是先知道最終的答案,然後取驗證答案的正確性。而且KL的非負性也是可以用Jenson不等式證明的,所以二者在數學原理上本質是一樣的。KL散度法有一個優勢,就是能讓我們知道 \(\log p\left(\mathbf{x}_{0}\right)\) 與證據下界的差距有多大,二者僅僅相差一個KL散度:\(\text{KL}\left(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) || p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\right)\)。關於這兩個機率分佈的物理意義,筆者認為可以這樣理解:\(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 是真實的,在給定圖片 \(\mathbf{x}_0\) 作為條件下的前向聯合機率,而 \(p\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 是我們估計的條件前向聯合機率。而KL散度描述的是兩個機率分佈之間的距離,如果我們估計的比較準確的話,二者的距離是比較小的。因此這也印證了我們採用證據下界來替代 \(\log p\left(\mathbf{x}_0\right)\) 做極大似然估計的合理性。

  下面,我們的方向就是逐步簡化證據下界,直到簡化為我們程式設計可實現的形式。

2、簡化證據下界

  對證據下界的化簡,需要用到三個我們之前推匯出來的表示式。為了方便閱讀,我們把(4)式,(5)式,還有證據下界重寫到這裡。

\[\begin{aligned} \textcolor{blue}{p\left(\mathbf{x}_{0:T}\right) = p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)} \end{aligned} \tag{4} \]

\[\begin{aligned} \textcolor{blue}{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) = \prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \\ \end{aligned} \tag{5} \]

\[\begin{aligned} \textcolor{blue}{\mathcal{L} = \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{0:T}\right)}{{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}}\right]}\\ \end{aligned} \tag{7} \]

下面我們將(4)式和(5)式代入到(7)式中,有:

\[\begin{aligned} \mathcal{L} &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{0:T}\right)}{{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)}}\right]\\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{\prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)}\right]\\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \prod_{t=2}^{T} p_{\theta}\left(\mathbf{x}_{t-1}|\mathbf{x}_{t}\right)}{\prod_{t=1}^T q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)}\right]\\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \prod_{t=1}^{T-1} \textcolor{blue}{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}}{\prod_{t=1}^T \textcolor{blue}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)}}\right] \\ \end{aligned} \]

藍色部分的操作,是將分子分母的表示的隨機變數機率保持一致,對同一個隨機變數的機率分佈的描述才具備可比性。而且,我們希望分子分母的連乘號下標保持一致,這樣才能進一步化簡。下面我們繼續:

\[\begin{aligned} \mathcal{L} &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \prod_{t=1}^{T-1} \textcolor{blue}{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}}{\prod_{t=1}^T \textcolor{blue}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)}}\right] \\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \prod_{t=1}^{T-1} {p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right) \prod_{t=1}^{T-1} {q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)}}\right] \\ &= \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right) }\right] + \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \prod_{t=1}^{T-1} \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \right]\\ &= \textcolor{skyblue}{\mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \right]} + \textcolor{darkred}{\mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} \right]} + \textcolor{darkgreen}{\sum_{t=1}^{T-1} \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \right]} \\ \end{aligned} \tag{8} \]

上式第一項,第二項,第三項分別被稱為 重建項(Reconstruction Term)先驗匹配項(Prior Matching Term)一致項(Consistency Term)

  • 重建項。顧名思義,這是對最後構圖的預測機率。給定預測的最後一個隱變數 \(\mathbf{x}_1\),預測生成的影像 \(\mathbf{x}_0\) 的對數機率。
  • 先驗匹配項。 這一項描述的是擴散過程的最後一步生成的高斯噪聲與純高斯噪聲的相似度,因為這一項並沒有神經網路引數,所以不需要最佳化,後續網路訓練的時候可以將這一項捨去。
  • 一致項。這一項描述的是取樣過程的單步轉移機率 \(p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)\) 和擴散過程的單步轉移機率 \(q\left(\mathbf{x}_{t}|\mathbf{x}_{t-1}\right)\) 的距離。由於 \(q\left(\mathbf{x}_{t}|\mathbf{x}_{t-1}\right)\) 是服從高斯分佈的(加噪過程自己定義的),所以我們希望取樣過程的單步轉移機率 \(p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)\) 也服從高斯分佈,這樣才能使得二者的KL散度更加接近。我們之後會看到,最小化二者的KL散度等價於最大似然估計。

到這裡我們透過觀察可以發現,乘上 \(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\),的期望中,存在很多無關的隨機變數,因此這三項可以進一步化簡。以上式重建項為例,我們將上式重建項寫成積分形式:

\[\begin{aligned} \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \right] = \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \textcolor{blue}{d \mathbf{x}_{0:T}} \end{aligned} \]

可能有的同學搞不清楚積分微元是哪個變數,如果不知道的話就把所有的隨機變數都算為積分微元。如果哪個微元不需要的話,是可以被積分積掉的。注意到,\(\mathbf{x}_0\) 是真實影像,而不是隨機變數,所以隨機變數最多就是 \(\mathbf{x}_{1:T}\)。 具體我們來看:

\[\begin{aligned} \textcolor{Skyblue}{\mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \right]} &= \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \textcolor{blue}{d \mathbf{x}_{1:T}} \\ &= \int \int ... \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \textcolor{blue}{d \mathbf{x}_{1} d \mathbf{x}_{2} ... d \mathbf{x}_{T}} \\ &= \int \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \textcolor{blue}{\left[\int ... \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) d \mathbf{x}_{2} d \mathbf{x}_{3} ... d \mathbf{x}_{T} \right]}d \mathbf{x}_{1} \\ &= \int \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \textcolor{blue}{q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right) }d \mathbf{x}_{1} \\ &= \mathbb{E}_{q\left(\mathbf{x}_{1} | \mathbf{x}_{0}\right)} \left[ \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \right] \end{aligned} \]

類似地,我們看 先驗匹配項一致項。先驗匹配項中除了 \(\mathbf{x}_T\)\(\mathbf{x}_{T-1}\) 這兩個隨機變數之外,其他的隨機變數都會被積分積掉,於是就是:

\[\begin{aligned} \textcolor{Darkred}{\mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} \right]} &= \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} d \mathbf{x}_{1:T}\\ &= \int q\left(\mathbf{x}_{T-1}, \mathbf{x}_{T} | \mathbf{x}_{0}\right) \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} d \mathbf{x}_{T-1} d\mathbf{x}_{T}\\ &= \mathbb{E}_{q\left(\mathbf{x}_{T-1}, \mathbf{x}_{T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} \right]\\ \end{aligned} \]

一致項也用類似的操作:

\[\begin{aligned} \textcolor{Darkgreen}{\sum_{t=1}^{T-1} \mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \right]} &= \sum_{t=1}^{T-1} \int q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right) \log \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} d \mathbf{x}_{1:T}\\ &= \sum_{t=1}^{T-1} \int q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t}, \mathbf{x}_{t+1} | \mathbf{x}_{0}\right) \log \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} d \mathbf{x}_{t-1} d \mathbf{x}_{t} d \mathbf{x}_{t+1}\\ &= \sum_{t=1}^{T-1} \mathbb{E}_{q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t}, \mathbf{x}_{t+1} | \mathbf{x}_{0}\right)} \left[ \log \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \right] \end{aligned} \]

下面我們繼續化簡,化簡乘KL散度的形式。因為兩個高斯分佈的KL散度可以寫成二範數Loss的形式,這是我們程式設計可實現的。我們先給出KL散度的定義。設兩個機率分佈和 \(Q\)\(P\),在連續隨機變數的情況下,他們的機率密度函式分別為 \(q(x)\) \(p(x)\),那麼二者的KL散度為:

\[\begin{aligned} \mathbb{D}_{\text{KL}}\left(Q || P\right) = - \int q(x) \log \frac{p(x)}{q(x)} dx \end{aligned} \]

注意,KL散度沒有對稱性,即 \(\text{KL}\left(Q || P\right)\)\(\text{KL}\left(P || Q\right)\) 是不同的。

下面,我們以 一致項 中的其中一項為例子,來寫成KL散度形式:

\[\begin{aligned} \textcolor{darkgreen}{\mathbb{E}_{q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t}, \mathbf{x}_{t+1} | \mathbf{x}_{0}\right)} \left[ \log \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \right]} &= \int q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t}, \mathbf{x}_{t+1} | \mathbf{x}_{0}\right) \log \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} d \mathbf{x}_{t-1} d \mathbf{x}_{t} d \mathbf{x}_{t+1} \\ &= \int \textcolor{red}{q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t+1} | \mathbf{x}_{0}\right) q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} \log \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} d \mathbf{x}_{t-1} d \mathbf{x}_{t} d \mathbf{x}_{t+1} \\ &= \int q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t+1} | \mathbf{x}_{0}\right) \left\{\int q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right) \log \frac{p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)}{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)} d \mathbf{x}_{t} \right\} d \mathbf{x}_{t-1} d \mathbf{x}_{t+1} \\ &= - \int q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t+1} | \mathbf{x}_{0}\right) \mathbb{D}_{\text{KL}}\left(q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right) || p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)\right) d \mathbf{x}_{t-1} d \mathbf{x}_{t+1} \\ &= - \mathbb{E}_{q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t+1} | \mathbf{x}_{0}\right)} \left[ \mathbb{D}_{\text{KL}}\left(q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right) || p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)\right) \right] \\ \end{aligned} \]

上式紅色的部分,是參考的開篇提到的那兩個教程。筆者自己推了一下,並沒有得出相應的結果。

如果有好的解釋,歡迎討論。不過這裡是否嚴格並不重要,之後我們會解釋,事實上我們使用的是另外一種推導方式。

類似地,先驗匹配項 也可以用類似的方法表示成KL散度的形式:

\[\begin{aligned} \textcolor{Darkred}{\mathbb{E}_{q\left(\mathbf{x}_{T-1}, \mathbf{x}_{T} | \mathbf{x}_{0}\right)} \left[ \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} \right]} &= \int q\left(\mathbf{x}_{T-1}, \mathbf{x}_{T} | \mathbf{x}_{0}\right) \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} d \mathbf{x}_{T-1} d\mathbf{x}_{T}\\ &= \int \textcolor{red}{q\left(\mathbf{x}_{T-1} | \mathbf{x}_{0}\right) q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} d \mathbf{x}_{T-1} d\mathbf{x}_{T}\\ &= \int q\left(\mathbf{x}_{T-1} | \mathbf{x}_{0}\right) \left\{\int q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right) \log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right)} d \mathbf{x}_{T} \right\} d\mathbf{x}_{T-1}\\ &= - \int q\left(\mathbf{x}_{T-1} | \mathbf{x}_{0}\right) \mathbb{D}_{\text{KL}}\left(q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right) || p\left(\mathbf{x}_{T}\right)\right) d\mathbf{x}_{T-1}\\ &= - \mathbb{E}_{q\left(\mathbf{x}_{T-1} | \mathbf{x}_{0}\right)} \left[ \mathbb{D}_{\text{KL}}\left(q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right) || p\left(\mathbf{x}_{T}\right)\right) \right] \\ \end{aligned} \]

這裡紅色的部分,我們可以詳細驗證一下:

\[\begin{aligned} q\left(\mathbf{x}_{T-1} | \mathbf{x}_{0}\right) q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right) &= q\left(\mathbf{x}_{T-1} | \mathbf{x}_{0}\right) q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}, \mathbf{x}_{0}\right) \quad\quad 馬爾可夫性質 \\ &= q\left(\mathbf{x}_{T-1}, \mathbf{x}_{T} | \mathbf{x}_{0}\right) \quad\quad 條件機率公式 \\ \end{aligned} \]

沒有什麼問題。

  下面我們整理一下結果。我們簡化的證據下界為:

\[\begin{aligned} \mathcal{L} = \textcolor{skyblue}{\mathbb{E}_{q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)} \left[ \log p_{\theta}\left(\mathbf{x}_{0}|\mathbf{x}_{1}\right) \right]} &- \textcolor{darkred}{\mathbb{E}_{q\left(\mathbf{x}_{T-1} | \mathbf{x}_{0}\right)} \left[ \mathbb{D}_{\text{KL}}\left(q\left(\mathbf{x}_{T} | \mathbf{x}_{T-1}\right) || p\left(\mathbf{x}_{T}\right)\right) \right]} \\ &- \textcolor{darkgreen}{\sum_{t=1}^{T-1} \mathbb{E}_{q\left(\mathbf{x}_{t-1}, \mathbf{x}_{t+1} | \mathbf{x}_{0}\right)} \left[ \mathbb{D}_{\text{KL}}\left(q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right) || p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)\right) \right]} \\ \end{aligned} \tag{9} \]

  我們看第三項,也就是 一致項。我們發現兩個機率分佈 \(q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}\right)\)\(p_{\theta}\left(\mathbf{x}_{t}|\mathbf{x}_{t+1}\right)\) 有時序上的錯位,如下圖展示的粉色線和綠色線。這相當於最小化兩個錯一位的機率分佈,這顯然不是我們希望的結果。雖然錯一位影響也不大,但總歸是不完美的。而且,這兩個轉移機率的方向也不一樣。因此,下面我們就要想辦法對證據下界進行最佳化,看看能不能推匯出兩個在時序上完全對齊的兩個機率分佈的KL散度。

img

如何最佳化證據下界呢。我們放到下一篇文章中來講:

從DDPM到DDIM (二) 前向過程與反向過程的機率分佈。


  1. Luo C. Understanding diffusion models: A unified perspective[J]. arXiv preprint arXiv:2208.11970, 2022. ↩︎

  2. Chan S H. Tutorial on Diffusion Models for Imaging and Vision[J]. arXiv preprint arXiv:2403.18103, 2024. ↩︎

  3. Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models[J]. Advances in neural information processing systems, 2020, 33: 6840-6851. ↩︎

  4. Song J, Meng C, Ermon S. Denoising diffusion implicit models[J]. arXiv preprint arXiv:2010.02502, 2020. ↩︎

  5. Sohl-Dickstein J, Weiss E, Maheswaranathan N, et al. Deep unsupervised learning using nonequilibrium thermodynamics[C]//International conference on machine learning. PMLR, 2015: 2256-2265. ↩︎

相關文章