一文詳解擴散模型:DDPM

京東雲開發者發表於2023-03-29

作者:京東零售 劉巖

擴散模型講解

前沿

人工智慧生成內容(AI Generated Content,AIGC)近年來成為了非常前沿的一個研究方向,生成模型目前有四個流派,分別是生成對抗網路(Generative Adversarial Models,GAN),變分自編碼器(Variance Auto-Encoder,VAE),標準化流模型(Normalization Flow, NF)以及這裡要介紹的擴散模型(Diffusion Models,DM)。擴散模型是受到熱力學中的一個分支,它的思想來源是非平衡熱力學(Non-equilibrium thermodynamics)。擴散模型的演算法理論基礎是透過變分推斷(Variational Inference)訓練引數化的馬爾可夫鏈(Markov Chain),它在許多工上展現了超過GAN等其它生成模型的效果,例如最近非常火熱的OpenAI的DALL-E 2,Stability.ai的Stable Diffusion等。這些效果驚豔的模型擴散模型的理論基礎便是我們這裡要介紹的提出擴散模型的文章[1]和非常重要的DDPM[2],擴散模型的實現並不複雜,但其背後的數學原理卻非常豐富。在這裡我會介紹這些重要的數學原理,但省去了這些公式的推導計算,如果你對這些推導感興趣,可以學習參考文獻[4,5,11]的相關內容。我在這裡主要以一個相對簡單的角度來講解擴散模型,幫助你快速入門這個非常重要的生成演算法。

1. 背景知識: 生成模型

生成模型的本質是透過一個已知的機率模型來擬合所給的資料樣本,也就是說,我們往往需要透過模型得到一個帶引數的分佈。即如果訓練資料的分佈是\(p_\text{data}(x)\),生成樣本的分佈為\(p_\text{model}(x)\),我們希望得到的分佈和訓練資料的分佈儘可能相似。目前生成模型主要有圖1所示的四類。其中GAN的原理是透過判別器和生成器的互相博弈來讓生成器生成足以以假亂真的影像。VAE的原理是透過一個編碼器將輸入影像編碼成特徵向量,它用來學習高斯分佈的均值和方差,而解碼器則可以將特徵向量轉化為生成影像,它側重於學習生成能力。流模型是從一個簡單的分佈開始,透過一系列可逆的轉換函式將分佈轉化成目標分佈。擴散模型先透過正向過程將噪聲逐漸加入到資料中,然後透過反向過程預測每一步加入的噪聲,透過將噪聲去掉的方式逐漸還原得到無噪聲的影像,擴散模型本質上是一個馬爾可夫架構,只是其中訓練過程用到了深度學習的BP,但它更屬於數學層面的創新。這也就是為什麼很多計算機的同學看擴散模型相關的論文會如此費力。

DDPM_1.png

圖1:生成模型的四種型別 [4]

擴散模型中最重要的思想根基是馬爾可夫鏈,它的一個關鍵性質是平穩性。即如果一個機率隨時間變化,那麼再馬爾可夫鏈的作用下,它會趨向於某種平穩分佈,時間越長,分佈越平穩。如圖2所示,當你向一滴水中滴入一滴顏料時,無論你滴在什麼位置,只要時間足夠長,最終顏料都會均勻的分佈在水溶液中。這也就是擴散模型的前向過程。

DDPM_2.png

圖2:顏料分子在水溶液中的擴散過程

如果我們能夠在擴散的過程顏料分子的位置、移動速度、方向等移動屬性。那麼也可以根據正向過程的儲存的移動屬性從一杯被溶解了顏料的水中反推顏料的滴入位置。這邊是擴散模型的反向過程。記錄移動屬性的快照便是我們要訓練的模型。

2. 擴散模型

在這一部分我們將集中介紹擴散模型的數學原理以及推導的幾個重要性質,因為推導過程涉及大量的數學知識但是對理解擴散模型本身思想並無太大幫助,所以這裡我會省去推導的過程而直接給出結論。但是我也會給出推導過程的出處,對其中的推導過程比較感興趣的請自行檢視。

2.1 計算原理

擴散模型簡單的講就是透過神經網路學習從純噪聲資料逐漸對資料進行去噪的過程,它包含兩個步驟,如圖3:

  1. 固定的前向過程\(q\):在這一步我們逐漸將高斯噪聲新增到影像中,直到得到一個純噪聲的影像;
  2. 可學習的反向去噪過程\(p_\theta\):在這一步我們從純噪聲影像中逐漸對其進行去噪,直到得到真實的影像。

DDPM_3.png

圖3:DDPM的前向加噪和後向去噪過程

更具體些,對於一個\(T\)步的擴散模型,每一步的索引為\(t\)。在前向過程中,我們從一個真實影像\(\boldsymbol x_0\)開始,在每一步我們隨機生成一些高斯噪聲,然後將生成的噪聲逐步加入到輸入影像中,當\(T\)足夠大時,我們得到的加噪後的影像便接近一個高斯噪聲影像,例如DDPM中 \(T=1000\)。在後向過程中,我們從噪聲影像\(\boldsymbol x_T\)開始(訓練時是真實影像加噪的結果,取樣時是隨機噪聲),透過一個神經網路學習\(\boldsymbol x_{t-1}\)\(\boldsymbol x_t\)新增的噪聲,然後透過逐漸去噪的方式得到最後要生成的影像。

2.1.1 前向過程

前向過程即擴散過程指的是向資料中逐漸新增高斯噪聲直到資料完全變成噪聲的過程。假設\(q(\boldsymbol x_0)\)是真實影像的分佈,我們可以透過從訓練集的真實影像中隨機取樣一張影像,表示為\(\boldsymbol x \sim q(\boldsymbol x_0)\)。那麼前向過程\(q(\boldsymbol x_t | \boldsymbol x_{t-1})\)指的是在前向的每一步透過向影像\(\boldsymbol x_{t-1}\)中新增高斯噪聲得到\(\boldsymbol x_t\)。我們知道,一個高斯分佈由均值\(\mu\)和方差\(\sigma^2\)定義。那麼在每一步向影像中新增高斯噪聲的過程表示為式(1),它是一個均值\(\mu_t = \sqrt{1-\beta_t} \boldsymbol x_{t-1}\),方差\(\sigma^2_t = \beta_t\)的高斯分佈。

\[q(\boldsymbol x_t | \boldsymbol x_{t-1}) = \mathcal N(\boldsymbol x_t | \sqrt{1 - \beta_t} \boldsymbol x_{t-1}, \beta_t \mathbf I) \tag1 \]

具體到每一步的計算時,我們先取樣一個二維標準高斯分佈\(\epsilon \sim \mathcal N(\boldsymbol 0, \mathbf I)\),然後透過引數\(\beta_t\)\(\boldsymbol x_{t-1}\)得到\(\boldsymbol x_t\),表示為\(\boldsymbol x_t = \sqrt{1-\beta_t} \boldsymbol x_{t-1} + \sqrt{\beta_t} \epsilon\)。注意這裡\(\beta_t\)並不是一個常數值,而是一個隨時間變化的變數。\(\beta_t\)的變化情況在擴散模型中被叫做差異時間表(Variance Schedule),常見的時間表策略有線性時間表(DDPM),平方時間表,cosine時間表[6]等。在DDPM擴散模型的前項過程中,我們需要保證\(T\)足夠大並且\(\beta_t\)的時間表配置合理,才能保證我們最終得到的\(\boldsymbol x_t\)也將是一個高斯噪聲影像。擴散模型的全部前向過程可以表示為從\(t=1\)\(t=T\)的時刻的馬爾可夫鏈,如式(2)。

\[q(\boldsymbol x_{0:T}) = q(\boldsymbol x_0) \prod_{t=1}^T q\left( \boldsymbol x_t | \boldsymbol x_{t-1} \right) \tag2 \]

擴散過程一個隱藏的重要特徵是我們可以直接基於原始資料\(\boldsymbol x_0\)來對任意\(t\)步的\(\boldsymbol x_t\)進行取樣。這裡我們定義\(\bar{a}_t = \prod_{s=1}^t \alpha_t\)以及\(\alpha_t = 1 - \beta_t\),透過重引數(Reparamazation)技巧,我們可以得到服從分佈\(q(\boldsymbol x_t| \boldsymbol x_0)\)的任意一個樣本\(\boldsymbol x_t\),我們可以透過反重引數化得到式(3),推導過程見參考文獻[11]的式(61)到式(70)。

\[\begin{aligned} \boldsymbol x_t & \sim q(\boldsymbol x_t | \boldsymbol x_0) \\ & = \sqrt{\bar{\alpha_t}} \boldsymbol x_0 + \sqrt{1 - \bar{\alpha_t}}\epsilon \\ & = \mathcal N(\boldsymbol x_t; \sqrt{\bar{\alpha}_t} \boldsymbol x_0, (1 - \bar{\alpha}_t) \mathbf I) \end{aligned} \tag3 \]

上面推理反應了一個重要的性質,即\(\boldsymbol x_t\)可以看做原始資料\(\boldsymbol x_0\)和隨機噪聲\(\boldsymbol \epsilon\)的線性組合,其中\(\sqrt{\bar{\alpha}_t}\)\(\sqrt{1 - \bar{\alpha}_t}\)是組合係數,它們的平方和為\(1\)。公式(3)在論文中被叫做“Nice Property”。進一步的,我們可以使用\(\bar \alpha\)來定義差異時間表,例如[6]中的cosine時間表便是這麼做的。\(\bar{\alpha}_t\)要比\(\beta\)更直接,例如我們可以透過將\(\bar{\alpha}_T\)設定為一個接近\(0\)的值,來使得最終得到的噪聲更傾向於是一個高斯噪聲。

2.1.2 後向過程

前向過程是將資料噪聲化的過程,那麼擴散模型的後向過程\(p(\boldsymbol x_{t-1} | \boldsymbol x_t)\)則是一個去噪過程。即我們先在\(T\)時刻隨機取樣一個二維高斯噪聲,然後逐步進行去噪,最終得到一個和真實影像分佈一致的生成影像\(\boldsymbol x_0\)

所以擴散模型的核心是如何進行這個去噪過程,因為我們並不知道\(p(\boldsymbol x_{t-1} | \boldsymbol x_t)\)的具體形式是什麼。擴散模型指出,我們可以使用一個神經網路學習這個去噪過程。因為第\(t\)時刻的分佈\(\boldsymbol x_t\)是已知的,因此我們這個神經網路的目標是根據\(\boldsymbol x_t\)去學習\(\boldsymbol x_{t-1}\)的機率分佈。綜上,擴散模型的後向過程表示為\(p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t)\),其中\(\theta\)是神經網路的引數,我們可以用SGD等策略對該網路進行最佳化。

因為前向過程我們新增的噪聲是高斯噪聲,為了簡化模型的訓練難度,我們假設反向的去噪過程去掉的噪聲也是高斯噪聲。因為一個高斯分佈是透過均值\(\mu_\theta\)和方差\(\sum_\theta\)決定的,那麼\(p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t)\)可以表示為式(4)的形式。

\[p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t) = \mathcal N(\boldsymbol x_{t-1}; \mu_\theta(\boldsymbol x_t,t), \sum_\theta(\boldsymbol x_t, t)) \tag4 \]

其中均值和方差均是根據模型計算得到的。綜合所有時間步,我們也可以透過馬爾可夫鏈得到擴散模型的後向過程,如式(5)。

\[p_\theta(\boldsymbol x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta (\boldsymbol x_{t-1} | \boldsymbol x_t) \tag5 \]

其中\(p(\boldsymbol x_t) = \mathcal N(\boldsymbol x_T; \boldsymbol 0, \mathbf I)\)是隨機取樣的高斯噪聲,\(p_\theta (\boldsymbol x_{t-1} | \boldsymbol x_t)\)是一個均值和方差需要計算的高斯分佈。

2.1.3 目標函式

那麼問題來了,我們究竟使用什麼樣的最佳化目標才能比較好的預測高斯噪聲的分佈呢?一個比較複雜的方式是使用變分自編碼器的最大化證據下界(Evidence Lower Bound, ELBO)的思想來推導,如式(6),推導詳細過程見論文[11]的式(47)到式(58),這裡主要用到了貝葉斯定理和琴生不等式。

\[\begin{aligned} \mathcal L & = - \log p(\boldsymbol x) \\ & = - \log \int \frac{p_\theta(\boldsymbol x_{0:T})q(\boldsymbol x_{1:T} | \boldsymbol x_0)}{q(\boldsymbol x_{1:T} | \boldsymbol x_0)} d \boldsymbol x_{1:T} \\ & \leq - \mathbb E_{q(\boldsymbol x_{1:T} | \boldsymbol x_0)} \left[ \frac{p_\theta(\boldsymbol x_{0:T})}{q(\boldsymbol x_{1:T} | \boldsymbol x_0)}\right] \\ & = - \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]}_{\text {重構項}} + \underbrace{D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right) \| p\left(\boldsymbol{x}_T\right)\right)}_{\text {先驗匹配項}} + \sum_{t=2}^T \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right) \| p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)\right)\right]}_{\text {去噪匹配項}} \end{aligned} \tag6 \]

式(6)的推導細節並不重要,我們需要重點關注的是它的最終等式的三個組成部分,下面我們分別介紹它們:

(1) 重構項 \(\mathcal L_0 = \mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]\),它的作用是對原始資料進行重構,最佳化的是負log似然。DDPM提供的計算方式是:它首先將離散的影像畫素從\([0, 255]\)歸一化到了\([-1,1]\)的範圍,然後用式(3)中估計的$\mathcal N(\boldsymbol x_t; \sqrt{\bar{\alpha}_t} \boldsymbol x_0, (1 - \bar{\alpha}_t) \boldsymbol I) $構建一個離散的解碼器來計算,如式(7)。它計算的是高斯分佈落在以Ground Truth為中心,且範圍大小為\(2/255\)時的機率積分,即累積分佈函式(Cumulative Distribution Function,CDF)。

\[\begin{aligned} p_\theta\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right) & =\prod_{i=1}^D \int_{\delta_{-}\left(x_0^i\right)}^{\delta_{+}\left(x_0^i\right)} \mathcal{N}\left(x ; \mu_\theta^i\left(\boldsymbol{x}_1, 1\right), \sigma_1^2\right) d x \\ \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)= \begin{cases}-\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1\end{cases} \right. \end{aligned} \tag7 \]

其中\(D\)是整張影像,\(i\)是影像上的畫素點座標。

(2) 先驗重構項 \(\mathcal L_T=D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right) \| p\left(\boldsymbol{x}_T\right)\right)\):它使用KL散度計算了最後的噪聲輸入和標準的高斯先驗的接近程度,因為這一部分沒有可以訓練的引數,我們可以將它視作常數0。

(3) 去噪匹配項 \(\mathcal L_{t-1} = \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right) \| p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)\right)\right]\):它計算的是真實後驗分佈\(q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)\)和預測的分佈\(p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)\)之間的KL散度。因為我們希望真實的去噪過程和模型預測的去噪過程完全一致,如圖4。

DDPM_4.png

圖4:擴散模型的去噪匹配項在每一步都要擬合噪音的真實後驗分佈和估計分佈

真實後驗分佈可以使用貝葉斯定理進行推導,最終結果如式(8),推導過程見論文[11]的式(71)到式(84)。

\[\begin{aligned} q(\boldsymbol x_{t-1} | \boldsymbol x_t, \boldsymbol x_0) & = \frac{q(\boldsymbol x_{t} | \boldsymbol x_{t-1}, \boldsymbol x_0) q(\boldsymbol x_{t-1} | \boldsymbol x_0)}{q(\boldsymbol x_{t} | \boldsymbol x_0)} \\ & \propto \mathcal N \left( \boldsymbol x_{t-1}; \frac{\sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t) \boldsymbol x_0}{ 1- \bar{\alpha}_t}, \frac{(1 - \alpha_t)(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf I \right) \\ & = \mathcal N(\boldsymbol x_{t-1}; \mu_q(\boldsymbol x_t, \boldsymbol x_0), \Sigma_q(t)) \end{aligned} \tag8 \]

其中\(\mu_q(\boldsymbol x_t, \boldsymbol x_0) = \frac{\sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t) \boldsymbol x_0}{ 1- \bar{\alpha}_t}\)是均值,\(\Sigma_q(t) = \frac{(1 - \alpha_t)(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf I\)是方差。可以看出均值是和\(\boldsymbol x_t\)\(\boldsymbol x_0\)相關的,但是方差與資料無關。假設預測的分佈也服從正態分佈,表示為式(9)。

\[p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right) = \mathcal N(\boldsymbol x_{t-1}; \mu_\theta(\boldsymbol x_t, t), \Sigma_q(t)) \tag9 \]

為了進一步化簡\(\mathcal L_{t-1}\)我們需要用到式(10)的兩個高斯分佈的KL散度計算公式。

\[\begin{aligned} & D_\text{KL}(\mathcal N(\boldsymbol x; \boldsymbol \mu_x, \boldsymbol \Sigma_x), \mathcal N(\boldsymbol y; \boldsymbol \mu_y, \boldsymbol \Sigma_y) \\ = & \frac{1}{2}\left[ \log \frac{|\boldsymbol \Sigma_x|}{|\boldsymbol \Sigma_y|} - d + \text{tr}(\boldsymbol \Sigma_y^{-1} \boldsymbol \Sigma_x) + (\boldsymbol \mu_y - \boldsymbol \mu_x)^\intercal \boldsymbol \sigma_y^{-1}(\boldsymbol \mu_y - \boldsymbol \mu_x)\right]) \end{aligned} \tag{10} \]

在兩個分佈均是高斯分佈的前提下,我們可以使用公式(10)繼續對\(\mathcal L_{t-1}\)進行進行計算,這一部分完整的推導流程參考論文[11]的式(87)到式(92)。從這裡我們可以看出,當兩個高斯分佈方差相同時,求它們之間的KL散度既是求兩個分佈的均值的l2距離。

\[\begin{aligned} \mathop{\arg\min}_\theta D_\text{KL}(q(\boldsymbol x_{t-1} | \boldsymbol x_t, \boldsymbol x_0) || p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t)) = \mathop{\arg\min}_\theta \frac{1}{2\sigma_q^2(t)} \left[\|\boldsymbol \mu_\theta (\boldsymbol x_t, \boldsymbol x_0) - \boldsymbol \mu_q(\boldsymbol x_t, t) \|^2_2\right] \\ \end{aligned} \tag{11} \]

透過式(3)我們可以得到\(\boldsymbol x_0 = \frac{\boldsymbol x_t - \sqrt{1 - \bar\alpha_t}\boldsymbol \epsilon_0}{\sqrt {\bar \alpha_t}}\),將它代入到\(\mu_q(\boldsymbol x_t, \boldsymbol x_0) = \frac{\sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t) \boldsymbol x_0}{ 1- \bar{\alpha}_t}\)中,我們可以得到式(12)的最終結果,推導過程見參考文獻(11)的式(116)到式(124)。

\[\begin{aligned} \boldsymbol \mu_q(\boldsymbol x_t, \boldsymbol x_0) & = \frac{\sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t) \boldsymbol x_0}{ 1- \bar{\alpha}_t} \\ & = \frac{\sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t) \frac{\boldsymbol x_t - \sqrt{1 - \bar\alpha_t}\boldsymbol \epsilon_0}{\sqrt {\bar \alpha_t}}}{ 1- \bar{\alpha}_t} \\ & = \frac{1}{\sqrt{\alpha_t}} \boldsymbol x_t - \frac{1-\alpha_t}{\sqrt{1 - \bar{\alpha_t}}{\sqrt{\alpha_t}}} \boldsymbol \epsilon_0 \end{aligned} \tag{12} \]

同理,我們也可以用這種方式計算$ \mu_\theta (\boldsymbol x_t, \boldsymbol x_0)$,如式(13):

\[\mu_\theta (\boldsymbol x_t, t) = \frac{1}{\sqrt{\alpha_t}} \boldsymbol x_t - \frac{1-\alpha_t}{\sqrt{1 - \bar{\alpha_t}}{\sqrt{\alpha_t}}} \hat {\boldsymbol \epsilon }_\theta(\boldsymbol x_t, t) \tag{13} \]

將式(12)和式(13)帶入到式(11)中,我們可以得到:

\[\begin{aligned} & \mathop{\arg\min}_\theta D_\text{KL}(q(\boldsymbol x_{t-1} | \boldsymbol x_t, \boldsymbol x_0) || p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t)) \\ = \; & \mathop{\arg\min}_\theta \frac{1}{2\sigma^2_q(t)} \frac{(1-\alpha_t)^2}{(1-\bar{\alpha}_t)\alpha_t} \left[ \| \boldsymbol \epsilon_0 - \hat{\boldsymbol \epsilon}_\theta (\boldsymbol x_t, t)\|^2_2 \right] \\ = \; & \mathop{\arg\min}_\theta \frac{1}{2\sigma^2_q(t)} \frac{(1-\alpha_t)^2}{(1-\bar{\alpha}_t)\alpha_t} \left[ \| \boldsymbol \epsilon_0 - \hat{\boldsymbol \epsilon}_\theta (\sqrt{\bar{\alpha}_t} \boldsymbol x_0 + \sqrt{ 1- \bar{\alpha}_t} \boldsymbol \epsilon , t)\|^2_2 \right] \end{aligned} \tag{14} \]

最終我們可以將擴散模型的損失函式簡化為式(15)的形式。其中\(\boldsymbol \epsilon_t\)是新增的高斯噪聲,\(\hat{\boldsymbol \epsilon}_\theta(\boldsymbol x_t, t)\)是一個神經網路,用於預測從\(\boldsymbol x_0\)\(\boldsymbol x_t\)時刻新增的噪聲。

\[\mathcal L_\text{simple} = \mathbb E_{t, \boldsymbol x_0, \epsilon} \| \boldsymbol \epsilon_t - \hat{\boldsymbol \epsilon}_\theta (\sqrt{\bar{\alpha}_t} \boldsymbol x_0 + \sqrt{ 1- \bar{\alpha}_t} \boldsymbol \epsilon , t)\|^2 \tag{15} \]

2.1.4 模型訓練

在第2.1.1節我們講到我們可以直接基於原始資料\(\boldsymbol x_0\)來對任意\(t\)步的\(\boldsymbol x_t\)進行取樣。那麼在實際訓練過程中,我們不必將所有的時間片都拿來訓練。而採取直接取樣到時刻\(t\),然後得到該時刻的\(\boldsymbol x_t\)並使用神經網路預測新增的噪聲即可。因為擴散模型的\(T\)是一個非常大的值,使用這種方式將大幅提升訓練速度。它的訓練過程為:

  1. 從分佈為\(q(\boldsymbol x_0)\)的資料集隨機取樣一個樣本\(\boldsymbol x_0 \sim q(\boldsymbol x_0)\);
  2. \(1\)\(T\)中隨機取樣一個值\(t\),用於表示新增噪聲的水平;
  3. 隨機取樣一個二維高斯噪音\(\epsilon\),然後使用上面介紹的“Nice Property”對\(\boldsymbol x_0\)施加\(t\)級別的噪聲;
  4. 訓練神經網路根據加噪之後的\(\boldsymbol x_t\)預測作用到\(\boldsymbol x_0\)之上的噪聲。

雖然上面我們介紹了很多內容,並給出了大量公式,但得益於推匯出的幾個重要性質,擴散模型的訓練並不複雜,它的訓練虛擬碼見演算法1。

DDPM_ag1.png

2.1.5 樣本生成

正如我們所介紹的,擴散模型的生成過程是一個反向去噪的過程,它的虛擬碼見演算法2。具體的講,我們從\(T\)時刻開始,首先隨機取樣一個高斯噪聲。然後我們使用神經網路預測的噪聲逐漸對其去噪,直到\(0\)時刻停止。透過式(9)我們得到了\(\boldsymbol x_{t-1} \sim p_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t)\),那麼我們可以進一步得到從\(\boldsymbol x_t\)\(\boldsymbol x_{t-1}\)的計算公式,如式(16)。

\[\boldsymbol x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left( \boldsymbol x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\epsilon(\boldsymbol x_t, t)\right) + \sigma_t \boldsymbol z \tag{16} \]

其中\(\boldsymbol z\sim \mathcal N(\mathbf 0, \mathbf I)\)是一個二維標準高斯分佈。

DDPM_ag2.png

2.2 演算法實現

2.2.1模型結構

DDPM在預測施加的噪聲時,它的輸入是施加噪聲之後的影像,預測內容是和輸入影像相同尺寸的噪聲,所以它可以看做一個Img2Img的任務。DDPM選擇了U-Net[9]作為噪聲預測的模型結構。U-Net是一個U形的網路結構,它由編碼器,解碼器以及編碼器和解碼器之間的跨層連線(殘差連線)組成。其中編碼器將影像降取樣成一個特徵,解碼器將這個特徵上取樣為目標噪聲,跨層連線用於拼接編碼器和解碼器之間的特徵。

DDPM_5.png

圖5:U-Net的網路結構

下面我們介紹DDPM的模型結構的重要元件。首先在U-Net的卷積部分,DDPM使用了寬殘差網路(Wide Residual Network,WRN)[12]作為核心結構,WRN是一個比標準殘差網路層數更少,但是通道數更多的網路結構。也有作者復現發現ConvNeXT作為基礎結構會取得非常顯著的效果提升[13,14]。這裡我們可以根據訓練資源靈活的調整卷積結構以及具體的層數等超參。因為我們在擴散過程的整個流程中都共享同一套引數,為了區分不同的時間片,作者借鑑了Transformer [15]的位置編碼的思想,採用了正弦位置嵌入對時間\(t\)進行了編碼,這使得模型在預測噪聲時知道它預測的是批次中分別是哪個時間片新增的噪聲。在卷積層之間,DDPM新增了一個注意力層。這裡我們可以使用Transformer中提出的自注意力機制或是多頭自注意力機制。[13]則提出了一個線性注意力機制的模組,它的特點是消耗的時間以及佔用的記憶體和序列長度是線性相關的,對比傳統注意力機制的平方相關要高效很多。在進行歸一化時,DDPM選擇了組歸一化(Group Normalization,GN)[16]。最後,對於U-Net中的降取樣和上取樣操作,DDPM分別選擇了步長為2的卷積以及反摺積。

確定了這些元件,我們便可以搭建用於DDPM的U-Net的模型了。從第2.1節的介紹我們知道,模型的輸入為形狀為(batch_size, num_channels, height, width)的噪聲影像和形狀為(batch_size,1)的噪聲水平,返回的是形狀為(batch_size, num_channels, height, width)的預測噪聲,我們搭建的用於噪聲預測的模型結構如下:

  1. 首先在噪聲影像\(\boldsymbol x_0\)上應用卷積層,併為噪聲水平\(t\)計算時間嵌入;
  2. 接下來是降取樣階段。採用的模型結構依次是兩個卷積(WRNS或是ConvNeXT)+GN+Attention+降取樣層;
  3. 在網路的最中間,依次是卷積層+Attention+卷積層;
  4. 接下來是上取樣階段。它首先會使用Short-cut拼接來自降取樣中同樣尺寸的卷積,再之後是兩個卷積+GN+Attention+上取樣層。
  5. 最後是使用WRNS或是ConvNeXT作為輸出層的卷積。

U-Net類的forword函式如下面程式碼片段所示,完整的實現程式碼參照[3]。

def forward(self, x, time):
    x = self.init_conv(x)
    t = self.time_mlp(time) if exists(self.time_mlp) else None
    h = []
    # downsample
    for block1, block2, attn, downsample in self.downs:
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)
        h.append(x)
        x = downsample(x)
    # bottleneck
    x = self.mid_block1(x, t)
    x = self.mid_attn(x)
    x = self.mid_block2(x, t)
    # upsample
    for block1, block2, attn, upsample in self.ups:
        x = torch.cat((x, h.pop()), dim=1)
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)
        x = upsample(x)
    return self.final_conv(x)

2.2.2 前向加噪

擴散模型的前向過程是逐漸向影像中新增噪聲的過程,這個從時刻\(0\)到時刻\(T\)\(t\)的變化情況叫做差異時間表。DDPM使用的是線性時間表,即我們首先對\(T\)個時間步做均勻拆分,得到\(\beta_t\)。接下來我們根據\(\beta_t\)計算我們需要的其它變數,例如\(\alpha\)\(\bar{\alpha}\)等,它們最好被儲存起來以避免重複計算。接下來我們需要準備輸入影像\(\boldsymbol x_0\)以及隨機噪聲\(\epsilon\),其中影像的處理主要包括resize到模型輸入大小以及進行歸一化,DDPM的策略是將它們線性歸一化到\([-1,1]\)之間,隨機噪聲即隨機生成一個和輸入影像相同尺寸的二維高斯噪聲。最後我們根據式(3)將它們合成一張影像即可。圖6是參考文獻[3]中給出的輸入影像依次經過0次,50次,100次,150次以及199次加噪後的效果圖,可以看出隨著逐漸新增噪聲,影像越來越難以區分,直到徹底變成一個二維高斯噪聲。

DDPM_6.png

圖6:一張圖依次經過0次,50次,100次,150次以及199次加噪後的效果圖

根據式(14)我們知道,擴散模型的損失函式計算的是兩張影像的相似性,因此我們可以選擇使用迴歸演算法的所有損失函式,以MSE為例,前向過程的核心程式碼如下面程式碼片段。

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
   	# 1. 根據時刻t計算隨機噪聲分佈,並對影像x_start進行加噪
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    # 2. 根據噪聲影像以及時刻t,預測新增的噪聲
    predicted_noise = denoise_model(x_noisy, t)
    # 3. 對比新增的噪聲和預測的噪聲的相似性
    loss = F.mse_loss(noise, predicted_noise)
    return loss

2.2.3 樣本生成

根據2.1.5節介紹的樣本生成流程,它的核心程式碼片段所示,關於這段程式碼的講解我透過註釋新增到了程式碼片段中。

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    # 使用式(13)計算模型的均值
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        return model_mean
    else:
      	# 獲取儲存的方差
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # 演算法2的第4行
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# 演算法2的流程,但是我們儲存了所有中間樣本
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device
    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

最後我們看下在人臉影像資料集下訓練的模型,一批隨機噪聲經過逐漸去噪變成人臉影像的示例。

DDPM_7.gif

圖7:擴散模型由隨機噪聲透過去噪逐漸生成人臉影像

3. 總結

這裡我們以DDPM為例介紹了另一個派系的生成演算法:擴散模型。擴散模型是一個基於馬爾可夫鏈的數學模型,它透過預測每個時間片新增的噪聲來進行模型的訓練。作為近日來引發熱烈討論的ControlNet, Stable Diffusion等模型的底層演算法,我們十分有必要對其有所瞭解。DDPM的實現並不複雜,這得益於大量數學界大佬透過大量的數學推導將整個擴散過程和反向去噪過程進行了精彩的化簡,這才有了DDPM的大道至簡的實現。DDPM作為一個擴散模型的基石演算法,它有著很多早期演算法的共同問題:

  1. 取樣速度慢:DDPM的去噪是從時刻\(T\)到時刻\(1\)的一個完整的馬爾可夫鏈的計算,尤其是DDPM還需要一個比較大的\(T\)才能保證比較好的效果,這就導致了DDPM的取樣過程註定是非常慢的;
  2. 生成效果差:DDPM的效果並不能說是非常好,尤其是對於高解析度影像的生成。這一方面是因為它的計算速度限制了它擴充套件到更大的模型;另一方面它的設計還有一些問題,例如逐畫素的計算損失並使用相同權值而忽略影像中的主體並不是非常好的策略。
  3. 內容不可控:我們可以看出,DDPM生成的內容完全還是取決於它的訓練集。它並沒有引入一些先驗條件,因此並不能透過控制影像中的細節來生成我們制定的內容。

我們現在已經知道,DDPM的這些問題已大幅得到改善,現在基於擴散模型生成的影像已經達到甚至超過人類多數的畫師的效果,我也會在之後逐漸給出這些最佳化方案的講解。

Reference

[1] Sohl-Dickstein, Jascha, et al. "Deep unsupervised learning using nonequilibrium thermodynamics." International Conference on Machine Learning. PMLR, 2015.

[2] Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.

[3] https://huggingface.co/blog/annotated-diffusion

[4] https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#simplification

[5] https://openai.com/blog/generative-models/

[6] Nichol, Alexander Quinn, and Prafulla Dhariwal. "Improved denoising diffusion probabilistic models." International Conference on Machine Learning. PMLR, 2021.

[7] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).

[8] Hinton, Geoffrey E., and Ruslan R. Salakhutdinov. "Reducing the dimensionality of data with neural networks." science 313.5786 (2006): 504-507.

[9] Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation[C]//International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015: 234-241.

[10] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for semantic segmentation." Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.

[11] Luo, Calvin. "Understanding diffusion models: A unified perspective." arXiv preprint arXiv:2208.11970 (2022).

[12] Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv preprint arXiv:1605.07146 (2016).

[13] https://github.com/lucidrains/denoising-diffusion-pytorch

[14] Liu, Zhuang, et al. "A convnet for the 2020s." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.

[15] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).

[16] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018.