InDepth Guide to Denoising Diffusion Probabilistic Models DDPM:DDPM擴散機率模型去噪深度指南——理論到實現

张天明發表於2024-11-09

An In-Depth Guide to Denoising Diffusion Probabilistic Models DDPM – Theory to Implementation
中文翻譯:DDPM擴散機率模型去噪深度指南——理論到實現

https://learnopencv.com/denoising-diffusion-probabilistic-models/#forward-diffusion-equation

https://github.com/spmallick/learnopencv/tree/master/Guide-to-training-DDPMs-from-Scratch

擴散機率模型是一個令人興奮的新研究領域,在影像生成方面顯示出巨大的前景。回想起來,基於擴散的生成模型於2015年首次引入,並於2020年推廣,當時Ho等人發表了論文“去噪擴散機率模型”(DDPM)。DDPM負責使擴散模型實用。在本文中,我們將重點介紹DDPM背後的關鍵概念和技術,並在“花”資料集上從頭開始訓練DDPM,以實現無條件影像生成。

無條件影像生成

在DDPM中,作者改變了公式和模型訓練程式,這有助於提高和實現與GAN相媲美的“影像保真度”,並確立了這些新生成演算法的有效性。

完全理解“去噪擴散機率模型”的最佳方法是複習理論(+一些數學)和底層程式碼。考慮到這一點,讓我們探索學習路徑,其中:

  • 我們將首先解釋什麼是生成模型以及為什麼需要它們。
  • 我們將從理論的角度討論基於擴散的生成模型中使用的方法
  • 我們將探索理解去噪擴散機率模型所需的所有數學。
  • 最後,我們將討論DDPM中用於影像生成的訓練和推理,並在PyTorch中從頭開始進行編碼。

1. 生成模型的必要性

基於影像的生成模型的工作是生成相似的新影像,換句話說,是我們原始影像集的“代表”。

我們需要建立和訓練生成模型,因為可以用(256x256x3)影像表示的所有可能影像的集合是巨大的。影像必須具有正確的畫素值組合來表示有意義的東西(我們可以理解的東西)。

An image of a Sunflower.

An RGB image of a Sunflower

例如,為了使上面的影像代表“向日葵”,影像中的畫素需要處於正確的配置中(它們需要具有正確的值)。而這些影像存在的空間只是(256x256x3)影像空間所表示的整個影像集的一小部分。

現在,如果我們知道如何從這個子空間中獲取/取樣一個點,我們就不需要構建“生成模型”。然而,在這個時間點,我們不需要。😓

捕獲/建模這個(資料)子空間的機率分佈函式,或者更確切地說,機率密度函式(PDF)仍然未知,很可能太複雜而沒有意義。

這就是為什麼我們需要“生成模型”來計算我們的資料滿足的潛在似然函式。

PS:PDF是一個“機率函式”,表示連續隨機變數的密度(似然性),在這種情況下,這意味著一個函式表示影像位於函式引數定義的特定值範圍之間的似然性。

PPS:每個PDF都有一組引數,用於確定分佈的形狀和機率。分佈的形狀隨著引數值的變化而變化。例如,在正態分佈的情況下,我們有均值\(µ\)(mu)和方差\(σ^2\)(sigma)來控制分佈的中心點和擴散。

Effects of changing the parameters of the Gaussian distributions. DDPMEffect of parameters of the Gaussian Distribution

Source: https://magic-with-latents.github.io/latent/posts/ddpms/part2/

2. 什麼是擴散機率模型?

在我們之前的文章“影像生成擴散模型簡介”中,我們沒有討論這些模型背後的數學。我們只提供了擴散模型如何工作的概念性概述,並重點介紹了不同的知名模型及其應用。在本文中,我們將主要關注第一部分。

在本節中,我們將從邏輯和理論的角度解釋基於擴散的生成模型。接下來,我們將回顧從頭開始理解和實現去噪擴散機率模型所需的所有數學。

擴散模型是一類受非平衡統計物理學思想啟發的生成模型,該思想指出:

我們可以使用馬爾可夫鏈逐步將一種分佈轉換為另一種分佈

—— 使用非平衡熱力學的深度無監督學習,2015年

擴散生成模型由兩個相反的過程組成,即正向和反向擴散過程。

2.1 正向擴散過程

“破壞容易,創造難”

—— 賽珍珠

  1. 在“正向擴散”過程中,我們緩慢迭代地向訓練集中的影像新增噪聲(破壞),使它們“移出或遠離”現有的子空間。
  2. 我們在這裡所做的是將我們的訓練集所屬的未知和複雜的分佈轉換為一個易於我們取樣和理解的(資料)點。
  3. 在正向過程結束時,影像變得完全無法識別。複雜的資料分佈被完全轉化為(選定的)簡單分佈。每個影像都被對映到資料子空間之外的空間。

Slowly transforming the data distribution in the forward diffusion process as done in diffusion probabilistic models (DDPM).Source: https://ayandas.me/blog-tut/2021/12/04/diffusion-prob-models.html

2.2 反向擴散過程

透過將影像形成過程分解為去噪自編碼器的順序應用,擴散模型(DM)在影像資料及其他方面實現了最先進的合成結果。

——穩定擴散,2022年

A high-level conceptual overview of the entire image space. Denoising Diffusion Probabilistic Models

A high-level conceptual overview of the entire image space.

  1. 在“反向擴散過程”中,其思想是逆轉正向擴散過程。
  2. 我們緩慢而迭代地嘗試逆轉正向過程中對影像執行的損壞。
  3. 反向過程從正向過程結束的地方開始。
  4. 從一個簡單的空間開始的好處是,我們知道如何從這個簡單的分佈中獲取/取樣一個點(可以把它想象成資料子空間之外的任何點)。
  5. 我們的目標是找出如何返回資料子空間。
  6. 然而,問題是,我們可以從這個“簡單”空間中的一個點開始走無限的路徑,但只有其中的一小部分會把我們帶到“資料”子空間。
  7. 在擴散機率模型中,這是透過參考正向擴散過程中採取的小迭代步驟來實現的。
  8. 滿足正向過程中損壞影像的PDF在每一步都略有不同。
  9. 因此,在反向過程中,我們在每一步都使用深度學習模型來預測正向過程的PDF引數。
  10. 一旦我們訓練了模型,我們就可以從簡單空間中的任何一點開始,並使用模型迭代地採取步驟,將我們帶回資料子空間。
  11. 在反向擴散中,我們從有噪聲的影像開始,逐步迭代地執行“去噪”。
  12. 這種訓練和生成新樣本的方法比GAN更穩定,也比變分自編碼器(VAE)和歸一化流等以前的方法更好。

A gif illustrating the inference stage of diffusion probabilistic models.

自2020年推出以來,DDPM一直是尖端影像生成系統的基礎,包括DALL-E 2、Imagen、Stable Diffusion和Midjourney。

隨著當今人工智慧藝術生成工具的大量出現,很難為特定的用例找到合適的工具。在我們最近的文章中,我們探討了所有不同的人工智慧藝術生成工具,以便您可以做出明智的選擇來生成最好的藝術。

3. 去噪擴散機率模型背後的數學細節

由於這篇文章背後的動機是“從頭開始建立和訓練去噪擴散機率模型”,我們可能不得不介紹它們背後的數學魔法,而不是全部。

在本節中,我們將介紹所有必需的數學,同時確保它也易於理解。

讓我們開始…

Illustration of the forward and reverse diffusion process in denoising diffusion probabilistic models (DDPMs).

箭頭上提到了兩個術語:

  1. \(q(x_{t}|x_{t-1})\)

    1. 這個術語也被稱為前向擴散核(FDK)
    2. 它定義了給定影像xt-1的正向擴散過程xt中時間步長t處影像的PDF。
    3. 它表示正向擴散過程中每一步應用的“過渡函式”。
  2. \(p_{\theta}(x_{t-1}|x_{t})\)

    1. 與正向過程類似,它被稱為反向擴散核(RDK)
    2. 它代表\(x_{t-1}\)的PDF,其中\(x_t\)\(𝜭\)引數化。\(\theta\)表示使用神經網路學習反向過程分佈的引數。
    3. 這是反向擴散過程中每一步應用的“過渡函式”。

3.1 正向擴散過程的數學細節

正向擴散過程中的分佈\(q\)定義為馬爾可夫鏈,由下式給出:

\[q(x_{1},\ldots,x_{T}|x_{0}):=\prod_{t=1}^{T}q(x_{t}|x_{t-1})\ldots(1)\\q(x_{t}|x_{t-1}):=\mathcal{N}(x_{t};\sqrt{1-\beta_{t}}x_{t-1},\beta_{t}I)\quad\ldots(2) \]

  1. 我們首先從資料集中獲取一張影像:\(x_0\)。從數學上講,它被表述為從原始(但未知)資料分佈中取樣一個資料點:\(x_{0}\sim q(x_{0})\)
  2. 正向過程的PDF是從時間步\(1→T\)開始的個體分佈的產物
  3. 正向擴散過程是固定且已知的。
  4. 從時間步長\(1\)\(T\)的所有中間噪聲影像也稱為“延遲”。延遲的維度與原始影像相同。
  5. 用於定義FDK的PDF是“正態/高斯分佈”(方程式2)。
  6. 在每個時間步長\(t\),定義影像\(x_t\)分佈的引數設定為:
    • 平均值:\(\sqrt{1-\beta_{t}} x_{t-1}\)
    • 協方差:\(\beta_{t}I\)
  7. 術語\(β\)被稱為“擴散率”,並使用“方差排程器”預先計算。術語\(I\)是一個恆等矩陣。因此,每個時間步長的分佈稱為各向同性高斯分佈。
  8. 原始影像在每個時間步長都會因新增少量高斯噪聲(\(\epsilon\))而損壞。新增的噪聲量由排程器調節。
  9. 透過選擇足夠大的時間步長並定義一個行為良好的\(\beta_t\)排程,重複應用FDK逐漸將資料分佈轉換為近似各向同性高斯分佈。

A modified image of diffusion process illustration focusing on forward diffusion process.

我們如何從\(x_{t-1}\)中獲得影像\(x_t\),以及如何在每個時間步長新增噪聲?

透過在變分自編碼器中使用重引數化技巧,可以很容易地理解這一點。

參考第二個方程,我們可以很容易地從正態分佈中取樣影像\(x_t\),如下所示:

\[\begin{aligned}x_{t}&=\sqrt{1-\beta_{t}} x_{t-1}+\sqrt{\beta_{t}} \epsilon\quad\ldots(3)\\&;\mathrm{where}\ \epsilon\sim\mathcal{N}(0,I)\end{aligned} \]

  1. 這裡,\(\epsilon\)是從標準高斯分佈中隨機取樣的“噪聲”項,首先進行縮放,然後新增(縮放)\(x_{t-1}\)
  2. 這樣,從\(x_0\)開始,原始影像從\(t=1…T\)迭代地被破壞

在實踐中,DDPM的作者使用“線性方差排程器”,在\([0.001,\ldots,0.02]\)範圍內定義\(\beta\),並設定總時間步長\(T=1000\)

“擴散模型透過每個正向過程步驟(按因子)縮小資料,這樣在新增噪聲時方差就不會增加。”

—— 去噪擴散機率模型,2020年

A graph to show how the value of beta terms changes depending on the timesteps.Variance Scheduler vs timesteps

這裡有一個問題,導致正向擴散過程效率低下🐢.

每當我們需要時間步長\(t\)的潛在樣本\(x_t\)時,我們必須在馬爾可夫鏈中執行\(t-1\)步。

In the current formulation of forward diffusion kernel we have no choice but to traverse the Markov chain to get to timestep t.We have to follow through all \(t-1\)intermediate states in Markov Chain to get\(x_t\)

為了解決這個問題,DDPM的作者重新制定了核心,使其在過程中直接從時間步長\(0\)(即從原始影像)變為時間步長\(t\)

In the modified formulation of the forward diffusion kernel used in denoising diffusion probabilistic models (DDPMs), we can skip all the intermediate timesteps.Skipping intermediate steps

為此,定義了兩個附加術語:

\[\alpha_{t}:= 1-\beta_{t}\quad\ldots(4)\\\bar{\alpha}_{t}:=\prod_{s=1}^{t}\alpha_{s}\quad\ldots(5) \]

其中式(5) 是從\(1\)\(T\)\(𝛂\)的累積乘積。

然後,透過將\(𝝱'\)替換為\(𝛂'\),並利用高斯分佈的加法性質。正向擴散過程可以改寫為\(𝛂\)

\[\begin{aligned}q(x_{t}|x_{0})&:=\mathcal{N}(x_{t};\sqrt{\bar{\alpha}_{t}}x_{0},(1-\bar{\alpha}_{t})I)\quad\ldots(6)\\x_{t}&:=\sqrt{\bar{\alpha}_{t}}x_{0}+\sqrt{1-\bar{\alpha}_{t}}\epsilon\end{aligned} \]

🚀 使用上述公式,我們可以在馬爾可夫鏈中的任意時間步長\(t\)進行取樣。

這就是正向擴散過程。

3.2 反向擴散過程的數學細節

In the reverse diffusion process, we try and follow the same route as the forward diffusion process but in reverse.

Czech Hiking Markers System. Following the path to take in the return journey.

“在反向擴散過程中,任務是學習正向擴散過程的有限時間(在\(T\)個時間步長內)反轉。”

這基本上意味著我們必須“撤消”正向過程,即迭代地去除正向過程中新增的噪聲。這是使用神經網路模型完成的。

在正向過程中,轉換函式\(q\)是使用高斯函式定義的,那麼反向過程\(p\)應該使用什麼函式呢?神經網路應該學習什麼?

  1. 1949年,W.Feller證明,對於高斯(和二項式)分佈,擴散過程的反轉與正向過程具有相同的函式形式。
  2. 這意味著,與定義為正態分佈的FDK類似,我們可以使用相同的函式形式(高斯分佈)來定義反向擴散核。
  3. 反向過程也是馬爾可夫鏈,其中神經網路在每個時間步預測反向擴散核的引數。
  4. 在訓練過程中,學習到的(引數的)估計值應接近FDK在每個時間步的後驗引數。我們將在下一節中更多地討論FDK的後驗。
  5. 我們想要這樣做,因為如果我們反向遵循正向軌跡,我們可能會回到原始資料分佈。
  6. 在此過程中,我們還將學習如何從純高斯噪聲開始生成與底層資料分佈緊密匹配的新樣本(我們在推理過程中無法訪問正向過程)。

A modified illustration of diffusion process focusing on reverse diffusion process.

  1. 反向擴散的馬爾可夫鏈從正向過程結束的地方開始,即在時間步長\(T\)處,資料分佈已被轉換為(幾乎)各向同性高斯分佈。

    \[q(x_{T})\approx\mathcal{N}(x_{t};0,I)\\p(x_{T}):=\mathcal{N}(x_{l};0,I)\ldots(7) \]

  2. 反向擴散過程的PDF是我們從純噪聲\(x_T\)開始得到資料樣本(與原始分佈相同)的所有可能路徑的“積分”。

    \[p_\theta(x_0):=\int p_\theta(x_{0:T})dx_{1:T} \]

    \[p_{\theta}(\mathbf{x}_{0:T}):=p(\mathbf{x}_{T})\prod_{t=1}^{T}p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t}),\quad p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t}):=\mathcal{N}(\mathbf{x}_{t-1};\mu_{\theta}(\mathbf{x}_{t},t),\Sigma_{\theta}(\mathbf{x}_{t},t)) \]

All equations used to define the forward and reverse diffusion process in denoising diffusion probabilistic models (DDPMs). All equations related to the forward and reverse diffusion processes.

3.3 用於去噪擴散機率模型的訓練目標和損失函式

基於擴散的生成模型的訓練目標相當於“最大化(在反向過程結束時)生成的樣本(\(x\))屬於原始資料分佈的對數似然”

我們將擴散模型中的轉換函式定義為“高斯函式”。為了最大化高斯分佈的對數似然性,需要嘗試找到分佈的引數(\(𝞵\)\(𝝈^2\)),使(生成的)資料與原始資料屬於相同資料分佈的“似然性”最大化。

為了訓練我們的神經網路,我們將損失函式(\(L\))定義為目標函式的負值。因此,\(p_{\theta}(x_{0})\)的高值意味著低損失,反之亦然。

\[p_{\theta}(x_{0}):=\int p_{\theta}(x_{0:T})dx_{1:T}\\L=-log(p_{\theta}(x_{0})) \]

事實證明,這很難解決,因為我們需要在非常高的維度(畫素)空間上對\(T\)時間步長上的連續值進行積分。

相反,作者從VAE中汲取靈感,使用變分下限(VLB)重新制定訓練目標,也稱為“證據下限”(ELBO),這是一個看起來很可怕的方程👻

\[\mathbb{E}\left[-\log p_\theta(\mathbf{x}_0)\right]\leq\mathbb{E}_q\left[-\log\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}\right]=\mathbb{E}_q\left[-\log p(\mathbf{x}_T)-\sum_{t\geq1}\log\frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_t|\mathbf{x}_{t-1})}\right]=:L \]

denoising-diffusion-probabilistic-models-andrew_ng_meme – LearnOpenCV Prof. Andrew Ng to the rescue 🐱‍🏍

經過一些簡化,DDPM作者得出了這個最終的\(L_{vlb}\)——變分下限損失項:

\[\mathbb{E}_{q}\bigg[\underbrace{D_{\mathrm{KL}}(q(\mathbf{x}_{T}|\mathbf{x}_{0})\parallel p_{\theta}(\mathbf{x}_{T}))}_{L_{T}}+\sum_{t>1}\underbrace{D_{\mathrm{KL}}(q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})\parallel p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t}))}_{L_{t-1}}\underbrace{-\log p_{\theta}(\mathbf{x}_{\theta}|\mathbf{x}_{1})}_{L_{0}}\bigg] \]

我們可以將上述\(L_{vlb}\)損失項分解為單獨的時間步長,如下所示:

\[\begin{aligned} L_{\mathrm{vlb}}& :=L_0+L_1+...+L_{T-1}+L_T \\ L_{0}& :=-\log p_\theta(x_0|x_1) \\ L_{t-1}& :=D_{KL}(q(x_{i-1}|x_{i},x_{0})\parallel p_{\theta}(x_{i-1}|x_{i})) \\ L_{T}& :=D_{KL}(q(x_{T}|x_{0})\parallel p(x_{T})) \end{aligned} \]

你可能會注意到這個損失函式是巨大的!但DDPM的作者透過忽略簡化損失函式中的一些項進一步簡化了它。

被忽略的項包括:

  1. \(L_0\)——作者在沒有這個的情況下獲得了更好的結果。

  2. \(L_T\)——這是正向過程中最終潛分佈和反向過程中第一個潛分佈之間的“KL散度”。然而,這裡沒有涉及神經網路引數,所以我們除了定義一個好的方差排程器並使用大的時間步長外,什麼也做不了,這樣它們都表示各向同性高斯分佈。

因此,\(L_{t-1}\)是唯一剩下的損失項,它是正向過程(以\(x_t\)和初始樣本\(x_0\)為條件)的“後驗”與引數化反向擴散過程之間的KL散度。這兩個項也是高斯分佈

\[L_{vlb}:=L_{t-1}:=D_{KL}(q(x_{t-1}|x_{t},x_{0})||p_{\theta}(x_{t-1}|x_{t})) \]

術語\(\mathrm{q(x_{t-1}|x_{t},x_{0})}\)被稱為“前向過程後向分佈”

我們的深度學習模型在訓練過程中的工作是近似/估計這個(高斯)後驗的引數,使KL散度儘可能小。

Image to illustrate the point why we need to minimize the KL divergence between the forward posterior and the reverse process in denoising diffusion probabilistic models (DDPMs).

後驗分佈的引數如下:

\[\begin{aligned} q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})& =\mathcal{N}(\mathbf{x}_{t-1};\tilde{\boldsymbol{\mu}}_{t}(\mathbf{x}_{t},\mathbf{x}_{0}),\tilde{\boldsymbol{\beta}}_{t}\mathbf{I}), \\ \mathrm{where}\quad\tilde{\mu}_{t}(\mathbf{x}_{t},\mathbf{x}_{0})& :=\frac{\sqrt{\alpha_{t-1}}\beta_{t}}{1-\bar{\alpha}_{t}}\mathbf{x}_{0}+\frac{\sqrt{\alpha_{t}}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}\mathbf{x}_{t}\quad\mathrm{and}\quad\tilde{\beta}_{t}:=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}}\beta_{t} \end{aligned} \]

為了進一步簡化模型的任務,作者決定將方差固定為常數\(\beta_t\)

現在,模型只需要學習預測上述方程。反向擴散核被修改為:

\[p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})=\mathcal{N}(\mathbf{x}_{t-1};\mu_{\theta}(\mathbf{x}_{t},t),\Sigma_{\theta}(\mathbf{x}_{t},t)) \]

\[p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})={\mathcal{N}}(\mathbf{x}_{t-1};\mu_{\theta}(\mathbf{x}_{t},t),\sigma^{2}I) \]

由於我們保持方差恆定,最小化KL散度就像最小化兩個高斯分佈\(q\)\(p\)的均值(𝞵)之間的差(或距離)一樣簡單(例如,左影像中分佈均值之間的差),可以按如下方式完成:

\[L_{t-1}=\mathbb{E}_{q}\left[\frac{1}{2\sigma_{t}^{2}}\|\tilde{\mu}_{t}(\mathbf{x}_{t},\mathbf{x}_{0})-\mu_{\theta}(\mathbf{x}_{t},t)\|^{2}\right]+C \]

現在,我們可以採取三種方法:

  1. 直接預測\(x_0\)並在後驗函式中使用它進行查詢\(\tilde{\mu}\)
  2. 預測整個\(\tilde{\mu}\)
  3. 預測每個時間步的噪音。這是透過使用重新引數化技巧將\(\tilde{\mu}\)\(x_0\)寫成\(x_t\)來實現的。\(\mathbf{x}_{t}:=\sqrt{\bar{\alpha}_{t}}\mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}}\epsilon\)

透過使用第三種選擇,經過一些簡化,\(\tilde{\mu}\)可以表示為:

\[\tilde{\mu}(x_{t},x_{0})=\frac{1}{\sqrt{\bar{\alpha}_{t}}}\left(x_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}}\epsilon_{t}\right) \]

同樣,\(\mu_{\theta}(x_{t},t)\)的公式設定為:

\[\mu_\theta(x_t,x_0)=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t,t)\right) \]

在訓練和推理時,我們知道\(𝝱\)\(𝛂\)\(x_t\)。因此,我們的模型只需要預測每個時間步長的噪聲。去噪擴散機率模型中使用的簡化(忽略一些加權項後)損失函式如下:

\[L_{\mathrm{simple}}(\theta):=\mathbb{E}_{t,\mathbf{x}_{0},\epsilon}\left[\left\|\epsilon-\epsilon_{\theta}(\sqrt{\bar{\alpha}_{t}}\mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}}\boldsymbol{\epsilon},t)\right\|^{2}\right] \]

Comparing just the noise.

這基本上是:

\[L_{\mathrm{simple}}=E_{t,x_0,\epsilon}\left[||\epsilon-\epsilon_\theta(x_t,t)||^2\right] \]

這是我們用來訓練DDPM的最終損失函式,它只是正向過程中新增的噪聲與模型預測的噪聲之間的“均方誤差”。這是本文對擴散機率模型去噪的最有影響力的貢獻。

這太棒了,因為從那些看起來很可怕的ELBO術語開始,我們最終得到了整個機器學習領域中最簡單的損失函式。

4. 在PyTorch中從頭開始編寫DDPM

從本節開始,我們將在PyTorch中從頭開始編寫訓練去噪擴散機率模型所需的所有基本元件。我們使用Kaggle核心代替Colab,因為它提供了比Colab免費版本更好的GPU和更長的訓練時間(這對擴散模型至關重要)。

注意:經常使用的輔助函式的程式碼不會新增到帖子中。

💡 您可以透過訂閱部落格文章來訪問此文章和我們所有其他文章的整個程式碼庫,我們將向您傳送下載連結。

https://github.com/spmallick/learnopencv/tree/master/Guide-to-training-DDPMs-from-Scratch

首先,我們將定義配置類,這些類將包含用於載入資料集、建立日誌目錄和訓練模型的超引數。

from dataclasses import dataclass
 
@dataclass
class BaseConfig:
    DEVICE = get_default_device()
    DATASET = "Flowers"  #  "MNIST", "Cifar-10", "Flowers"
 
    # For logging inferece images and saving checkpoints.
    root_log_dir = os.path.join("Logs_Checkpoints", "Inference")
    root_checkpoint_dir = os.path.join("Logs_Checkpoints", "checkpoints")
 
    # Current log and checkpoint directory.
    log_dir = "version_0"
    checkpoint_dir = "version_0"
 
 
@dataclass
class TrainingConfig:
    TIMESTEPS = 1000  # Define number of diffusion timesteps
    IMG_SHAPE = (1, 32, 32) if BaseConfig.DATASET == "MNIST" else (3, 32, 32)
    NUM_EPOCHS = 800
    BATCH_SIZE = 32
    LR = 2e-4
    NUM_WORKERS = 2

5. 建立PyTorch資料集類物件

本文使用“Flowers”資料集,該資料集可以從Kaggle下載或快速載入到Kaggle核心環境中。但您可能已經注意到,在BaseConfig類中,我們還提供了載入MNIST、Cifare10和Cifare100資料集的選項。你可以選擇你喜歡的。

flowers資料集可以從這裡下載:Flowers Recognition | Kaggle

使用Kaggle核心時,只需單擊“新增資料”元件並選擇資料集即可。

在這裡,我們建立兩個函式:

  1. get_dataset(…):返回將傳遞給Dataloader的資料集類物件。對資料集中的每個影像應用三個預處理變換和一個增強。
    1. 預處理:
      1. 轉換[0, 255]→[0.0, 1.0]範圍內的畫素值
      2. 根據形狀調整影像大小(32x32)。
      3. 從[0.0, 1.0]→[-1.0, 1.0]範圍更改畫素值。這是由DDPM作者完成的,這樣輸入影像的值範圍與標準高斯影像大致相同。
    2. 增強:
      1. 隨機水平翻轉,如原始實現中使用的。如果你使用的是MNIST資料集,一定要註釋掉這一行。
  2. inverse_transfers(…):此函式用於反轉載入步驟中應用的變換,並將影像恢復到[0.0, 255.0]範圍。
import torchvision
import torchvision.transforms as TF
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader
 
 
def get_dataset(dataset_name='MNIST'):
    transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize((32, 32), 
                                          interpolation=torchvision.transforms.InterpolationMode.BICUBIC, 
                                          antialias=True),
            torchvision.transforms.RandomHorizontalFlip(),
#             torchvision.transforms.Normalize(MEAN, STD),
            torchvision.transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1] 
        ]
    )
     
    if dataset_name.upper() == "MNIST":
        dataset = datasets.MNIST(root="data", train=True, download=True, transform=transforms)
    elif dataset_name == "Cifar-10":    
        dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transforms)
    elif dataset_name == "Cifar-100":
        dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transforms)
    elif dataset_name == "Flowers":
        dataset = datasets.ImageFolder(root="/kaggle/input/flowers-recognition/flowers", transform=transforms)
 
    return dataset
 
def inverse_transform(tensors):
    """Convert tensors from [-1., 1.] to [0., 255.]"""
    return ((tensors.clamp(-1, 1) + 1.0) / 2.0) * 255.0

6. 建立PyTorch資料載入器類物件

接下來,我們定義get_dataloader(…)函式,該函式返回所選資料集的dataloader物件。

def get_dataloader(dataset_name='MNIST', 
                   batch_size=32, 
                   pin_memory=False, 
                   shuffle=True, 
                   num_workers=0, 
                   device="cpu"
                  ):
    dataset      = get_dataset(dataset_name=dataset_name)
    dataloader = DataLoader(dataset, batch_size=batch_size, 
                            pin_memory=pin_memory, 
                            num_workers=num_workers, 
                            shuffle=shuffle
                           )
    # Used for moving batch of data to the user-specified machine: cpu or gpu
    device_dataloader = DeviceDataLoader(dataloader, device)
    return device_dataloader

7. 視覺化資料集

首先,我們將透過呼叫get_dataloader(…)函式來建立“dataloader”物件。

loader = get_dataloader(
    dataset_name=BaseConfig.DATASET,
    batch_size=128,
    device=”cpu”,
)

然後,我們可以簡單地使用torchvision的make_grid(…)函式來繪製花朵影像的網格。

from torchvision.utils import make_grid
 
plt.figure(figsize=(10, 4), facecolor='white')
 
for b_image, _ in loader:
    b_image = inverse_transform(b_image)
    grid_img = make_grid(b_image / 255.0, nrow=16, padding=True, pad_value=1)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.axis("off")
    break

The flowers dataset used for training DDPMs from scratch.Flowers Dataset

8. DDPM中使用的模型架構

在DDPM中,作者使用了一個UNet形狀的深度神經網路,該網路將以下內容作為輸入:

  1. 在反向過程的任何階段輸入影像。

  2. 輸入影像的時間步長。

從通常的UNet架構開始,作者用ResNet模型中使用的“殘差塊”替換了每個級別的原始雙卷積。

該架構由5個元件組成:

  1. 編碼器塊
  2. 瓶頸塊
  3. 解碼器塊
  4. 自注意力模組
  5. 正弦位置編碼

結構細節:

  1. 編碼器和解碼器路徑中有四個級別,它們之間有瓶頸塊。
  2. 每個編碼器級包括兩個殘差塊,除了最後一級之外,其餘都進行了卷積下采樣。
  3. 每個相應的解碼器級包括三個殘差塊,並使用2x最近鄰卷積對前一級的輸入進行上取樣。
  4. 編碼器路徑中的每個階段都在跳過連線的幫助下連線到解碼器路徑。
  5. 該模型使用單一特徵圖解析度的“自我關注”模組。
  6. 模型中的每個殘差塊都從前一層(以及解碼器路徑中的其他層)獲得輸入,並嵌入當前時間步長。時間步長嵌入通知模型輸入在馬爾可夫鏈中的當前位置。

The architecture of the UNet model used in denoising diffusion probabilistic models (DDPMs).The U-Net architecture used in DDPMs

在本文中,我們正在研究(32×32)的影像大小。我們的模型和本文中使用的原始模型之間只存在兩個微小的變化。

  1. 我們使用64個基本通道,而不是128個。
  2. 編碼器和解碼器路徑都有四個級別。每個級別的特徵圖大小保持如下:32→16→8→8。我們在特徵圖大小為(16x16)和(8x8)時應用自我注意,而不是在原始情況下,它們在特徵圖尺寸為(16x16)時只應用一次。

請注意,我們沒有新增模型程式碼,因為UNet+的程式碼很容易修改,但因為所有不同的元件。它變得太大了,無法新增到帖子中。

9. 擴散類

在本節中,我們將建立一個名為SimpleDiffusion的類。此類包含:

  1. 執行正向和反向擴散過程所需的排程器常量。

  2. 定義DDPM中使用的線性方差排程器的方法。

  3. 一種使用更新的前向擴散核執行單個步驟的方法。

class SimpleDiffusion:
    def __init__(
        self,
        num_diffusion_timesteps=1000,
        img_shape=(3, 64, 64),
        device="cpu",
    ):
        self.num_diffusion_timesteps = num_diffusion_timesteps
        self.img_shape = img_shape
        self.device = device
        self.initialize()
 
    def initialize(self):
        # BETAs & ALPHAs required at different places in the Algorithm.
        self.beta  = self.get_betas()
        self.alpha = 1 - self.beta
         
        self_sqrt_beta                       = torch.sqrt(self.beta)
        self.alpha_cumulative                = torch.cumprod(self.alpha, dim=0)
        self.sqrt_alpha_cumulative           = torch.sqrt(self.alpha_cumulative)
        self.one_by_sqrt_alpha               = 1. / torch.sqrt(self.alpha)
        self.sqrt_one_minus_alpha_cumulative = torch.sqrt(1 - self.alpha_cumulative)
          
    def get_betas(self):
        """linear schedule, proposed in original ddpm paper"""
        scale = 1000 / self.num_diffusion_timesteps
        beta_start = scale * 1e-4
        beta_end = scale * 0.02
        return torch.linspace(
            beta_start,
            beta_end,
            self.num_diffusion_timesteps,
            dtype=torch.float32,
            device=self.device,
        )

10. 正向擴散過程的Python程式碼

在本節中,我們將編寫python程式碼,根據這裡提到的方程式在一個步驟中執行“正向擴散過程”。

forward_diffusion(...)函式接收一批影像和相應的時間步長,並使用更新的前向擴散核方程新增噪聲/破壞輸入影像。

def forward_diffusion(sd: SimpleDiffusion, x0: torch.Tensor, timesteps: torch.Tensor):
    eps = torch.randn_like(x0)  # Noise
    mean    = get(sd.sqrt_alpha_cumulative, t=timesteps) * x0  # Image scaled
    std_dev = get(sd.sqrt_one_minus_alpha_cumulative, t=timesteps) # Noise scaled
    sample  = mean + std_dev * eps # scaled inputs * scaled noise
 
    return sample, eps  # return ... , gt noise --> model predicts this

10.1 樣本影像正向擴散過程的視覺化

在本節中,我們將視覺化一些樣本影像的前向擴散過程,看看它們在\(T\)個時間步內透過馬爾可夫鏈時是如何被破壞的。

sd = SimpleDiffusion(num_diffusion_timesteps=TrainingConfig.TIMESTEPS, device="cpu")
 
loader = iter(  # converting dataloader into an iterator for now.
    get_dataloader(
        dataset_name=BaseConfig.DATASET,
        batch_size=6,
        device="cpu",
    )
)

對某些特定時間步執行正向處理,並儲存原始影像的噪聲版本。

x0s, _ = next(loader)
 
noisy_images = []
specific_timesteps = [0, 10, 50, 100, 150, 200, 250, 300, 400, 600, 800, 999]
 
for timestep in specific_timesteps:
    timestep = torch.as_tensor(timestep, dtype=torch.long)
 
    xts, _ = sd.forward_diffusion(x0s, timestep)
    xts    = inverse_transform(xts) / 255.0
    xts    = make_grid(xts, nrow=1, padding=1)
     
    noisy_images.append(xts)

繪製不同時間步的樣本損壞情況。

_, ax = plt.subplots(1, len(noisy_images), figsize=(10, 5), facecolor='white')
 
for i, (timestep, noisy_sample) in enumerate(zip(specific_timesteps, noisy_images)):
    ax[i].imshow(noisy_sample.squeeze(0).permute(1, 2, 0))
    ax[i].set_title(f"t={timestep}", fontsize=8)
    ax[i].axis("off")
    ax[i].grid(False)
 
plt.suptitle("Forward Diffusion Process", y=0.9)
plt.axis("off")
plt.show()

Images are corrupted in the forward process of diffusion probabilistic models.The original image gets increasingly corrupted as timesteps increase. At the end of the forward process, we are left with noise.

11. 用於去噪擴散機率模型的訓練和取樣演算法

The training and sampling algorithm as described in the DDPMs paper.

基於演算法1的訓練程式碼

這裡定義的第一個函式是train_one_epoch(…)。此函式用於執行“一個訓練週期”,即它透過在整個資料集上迭代一次來訓練模型,並將在我們的最終訓練迴圈中呼叫。

我們還使用混合精度訓練來更快地訓練模型並節省GPU記憶體。程式碼非常簡單,幾乎是演算法的一對一轉換。

# Algorithm 1: Training
 
def train_one_epoch(model, loader, sd, optimizer, scaler, loss_fn, epoch=800, 
                   base_config=BaseConfig(), training_config=TrainingConfig()):
     
    loss_record = MeanMetric()
    model.train()
 
    with tqdm(total=len(loader), dynamic_ncols=True) as tq:
        tq.set_description(f"Train :: Epoch: {epoch}/{training_config.NUM_EPOCHS}")
          
        for x0s, _ in loader: # line 1, 2
            tq.update(1)
             
            ts = torch.randint(low=1, high=training_config.TIMESTEPS, size=(x0s.shape[0],), device=base_config.DEVICE) # line 3
            xts, gt_noise = sd.forward_diffusion(x0s, ts) # line 4
 
            with amp.autocast():
                pred_noise = model(xts, ts)
                loss = loss_fn(gt_noise, pred_noise) # line 5
 
            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
 
            # scaler.unscale_(optimizer)
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
 
            scaler.step(optimizer)
            scaler.update()
 
            loss_value = loss.detach().item()
            loss_record.update(loss_value)
 
            tq.set_postfix_str(s=f"Loss: {loss_value:.4f}")
 
        mean_loss = loss_record.compute().item()
     
        tq.set_postfix_str(s=f"Epoch Loss: {mean_loss:.4f}")
     
    return mean_loss

基於演算法2的取樣或推理程式碼

我們定義的下一個函式是reverse_diffusion(...),它負責執行推理,即使用反向擴散過程生成影像。該函式接受一個訓練好的模型和擴散類,可以生成一個展示整個擴散過程的影片,也可以僅生成最終生成的影像。

# Algorithm 2: Sampling
     
@torch.no_grad()
def reverse_diffusion(model, sd, timesteps=1000, img_shape=(3, 64, 64), 
                      num_images=5, nrow=8, device="cpu", **kwargs):
 
    x = torch.randn((num_images, *img_shape), device=device)
    model.eval()
 
    if kwargs.get("generate_video", False):
        outs = []
 
    for time_step in tqdm(iterable=reversed(range(1, timesteps)), 
                          total=timesteps-1, dynamic_ncols=False, 
                          desc="Sampling :: ", position=0):
 
        ts = torch.ones(num_images, dtype=torch.long, device=device) * time_step
        z = torch.randn_like(x) if time_step > 1 else torch.zeros_like(x)
 
        predicted_noise = model(x, ts)
 
        beta_t                            = get(sd.beta, ts)
        one_by_sqrt_alpha_t               = get(sd.one_by_sqrt_alpha, ts)
        sqrt_one_minus_alpha_cumulative_t = get(sd.sqrt_one_minus_alpha_cumulative, ts) 
 
        x = (
            one_by_sqrt_alpha_t
            * (x - (beta_t / sqrt_one_minus_alpha_cumulative_t) * predicted_noise)
            + torch.sqrt(beta_t) * z
        )
 
        if kwargs.get("generate_video", False):
            x_inv = inverse_transform(x).type(torch.uint8)
            grid = make_grid(x_inv, nrow=nrow, pad_value=255.0).to("cpu")
            ndarr = torch.permute(grid, (1, 2, 0)).numpy()[:, :, ::-1]
            outs.append(ndarr)
 
    if kwargs.get("generate_video", False): # Generate and save video of the entire reverse process. 
        frames2vid(outs, kwargs['save_path'])
        display(Image.fromarray(outs[-1][:, :, ::-1])) # Display the image at the final timestep of the reverse process.
        return None
 
    else: # Display and save the image at the final timestep of the reverse process. 
        x = inverse_transform(x).type(torch.uint8)
        grid = make_grid(x, nrow=nrow, pad_value=255.0).to("cpu")
        pil_image = TF.functional.to_pil_image(grid)
        pil_image.save(kwargs['save_path'], format=save_path[-3:].upper())
        display(pil_image)
        return None

12. 從頭開始訓練DDPM

在前面的部分中,我們已經定義了訓練所需的所有必要類和函式。我們現在要做的就是組裝它們並開始訓練過程。

在我們開始訓練之前:

  • 我們將首先定義所有與模型相關的超引數。

  • 然後初始化UNet模型、AdamW最佳化器、MSE損失函式和其他必要的類。

@dataclass
class ModelConfig:
    BASE_CH = 64  # 64, 128, 256, 256
    BASE_CH_MULT = (1, 2, 4, 4) # 32, 16, 8, 8 
    APPLY_ATTENTION = (False, True, True, False)
    DROPOUT_RATE = 0.1
    TIME_EMB_MULT = 4 # 128
 
model = UNet(
    input_channels          = TrainingConfig.IMG_SHAPE[0],
    output_channels         = TrainingConfig.IMG_SHAPE[0],
    base_channels           = ModelConfig.BASE_CH,
    base_channels_multiples = ModelConfig.BASE_CH_MULT,
    apply_attention         = ModelConfig.APPLY_ATTENTION,
    dropout_rate            = ModelConfig.DROPOUT_RATE,
    time_multiple           = ModelConfig.TIME_EMB_MULT,
)
model.to(BaseConfig.DEVICE)
 
optimizer = torch.optim.AdamW(model.parameters(), lr=TrainingConfig.LR) # Original → Adam
 
dataloader = get_dataloader(
    dataset_name  = BaseConfig.DATASET,
    batch_size    = TrainingConfig.BATCH_SIZE,
    device        = BaseConfig.DEVICE,
    pin_memory    = True,
    num_workers   = TrainingConfig.NUM_WORKERS,
)
 
loss_fn = nn.MSELoss()
 
sd = SimpleDiffusion(
    num_diffusion_timesteps = TrainingConfig.TIMESTEPS,
    img_shape               = TrainingConfig.IMG_SHAPE,
    device                  = BaseConfig.DEVICE,
)
 
scaler = amp.GradScaler() # For mixed-precision training.

然後,我們將初始化日誌記錄和檢查點目錄,以儲存中間取樣結果和模型引數。

total_epochs = TrainingConfig.NUM_EPOCHS + 1
log_dir, checkpoint_dir = setup_log_directory(config=BaseConfig())
generate_video = False
ext = ".mp4" if generate_gif else ".png"

最後,我們可以編寫訓練迴圈。由於我們已經將所有程式碼劃分為簡單、易於除錯的函式和類,現在我們所要做的就是在epochs訓練迴圈中呼叫它們。具體來說,我們需要在迴圈中呼叫上一節中定義的“訓練”和“取樣”函式。

for epoch in range(1, total_epochs):
    torch.cuda.empty_cache()
    gc.collect()
     
    # Algorithm 1: Training
    train_one_epoch(model, sd, dataloader, optimizer, scaler, loss_fn, epoch=epoch)
 
    if epoch % 20 == 0:
        save_path = os.path.join(log_dir, f"{epoch}{ext}")
         
        # Algorithm 2: Sampling
        reverse_diffusion(model, sd, timesteps=TrainingConfig.TIMESTEPS, 
                          num_images=32, generate_video=generate_video, save_path=save_path, 
                          img_shape=TrainingConfig.IMG_SHAPE, device=BaseConfig.DEVICE, nrow=4,
        )
 
        # clear_output()
        checkpoint_dict = {
            "opt": optimizer.state_dict(),
            "scaler": scaler.state_dict(),
            "model": model.state_dict()
        }
        torch.save(checkpoint_dict, os.path.join(checkpoint_dir, "ckpt.pt"))
        del checkpoint_dict

如果一切順利,培訓程式應開始並列印培訓日誌,類似於:

DDPM training output logs.

13. 使用DDPM生成影像

如果你對每20個迭代生成的樣本感到滿意,你可以讓訓練完成800個迭代,也可以在其間中斷。

為了執行推理,我們只需重新載入儲存的模型,您可以使用相同或不同的日誌目錄來儲存結果。您也可以重新初始化SimpleDiffusion類,但這不是必需的。

# Reloading model from saved checkpoint
model = UNet(
    input_channels          = TrainingConfig.IMG_SHAPE[0],
    output_channels         = TrainingConfig.IMG_SHAPE[0],
    base_channels           = ModelConfig.BASE_CH,
    base_channels_multiples = ModelConfig.BASE_CH_MULT,
    apply_attention         = ModelConfig.APPLY_ATTENTION,
    dropout_rate            = ModelConfig.DROPOUT_RATE,
    time_multiple           = ModelConfig.TIME_EMB_MULT,
)
model.load_state_dict(torch.load(os.path.join(checkpoint_dir, "ckpt.tar"), map_location='cpu')['model'])
 
model.to(BaseConfig.DEVICE)
 
sd = SimpleDiffusion(
    num_diffusion_timesteps = TrainingConfig.TIMESTEPS,
    img_shape               = TrainingConfig.IMG_SHAPE,
    device                  = BaseConfig.DEVICE,
)
 
log_dir = "inference_results"

推理程式碼只是使用訓練好的模型呼叫reverse_didiffusion(...) 函式。

generate_video = False # Set it to True for generating video of the entire reverse diffusion proces or False to for saving only the final generated image.
 
ext = ".mp4" if generate_video else ".png"
filename = f"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}"
 
save_path = os.path.join(log_dir, filename)
 
reverse_diffusion(
    model,
    sd,
    num_images=256,
    generate_video=generate_video,
    save_path=save_path,
    timesteps=1000,
    img_shape=TrainingConfig.IMG_SHAPE,
    device=BaseConfig.DEVICE,
    nrow=32,
)
print(save_path)

我們得到的一些結果:


Example 1 of unconditional flower image generation using the trained model.

Example 2 of unconditional flower image generation using the trained model.

14. 總結

總之,擴散模型代表了一個快速增長的領域,為未來帶來了豐富的令人興奮的可能性。隨著這一領域的研究不斷髮展,我們可以期待出現更先進的技術和應用。我鼓勵讀者分享他們對這一主題的想法和問題,並就擴散模型的未來進行對話。

總結這篇文章📜, 我們涵蓋了一系列相關主題。

  1. 我們首先為為什麼我們需要生成模型這一基本問題提供了直觀的答案。

  2. 然後,我們繼續討論,從邏輯和理論的角度解釋基於擴散的生成模型。

  3. 在建立了理論基礎後,我們逐一介紹了DDPM推匯出的所有必要的數學方程,同時保持了流暢度,使其易於掌握。

  4. 最後,我們透過解釋從頭開始訓練DDPM和執行推理所需的所有不同程式碼來總結。我們還展示了實驗結果。

參考目錄

  1. What are Diffusion Models?
  2. DDPMs from scratch
  3. Diffusion Models | Paper Explanation | Math Explained
  4. Paper – Deep Unsupervised Learning using Nonequilibrium Thermodynamics
  5. Paper – Denoising Diffusion Probabilistic Models
  6. Paper – Improved Denoising Diffusion Probabilistic Models
  7. Paper – A Survey on Generative Diffusion Model
  8. An introduction to Diffusion Probabilistic Models – Ayan Das
  9. Denoising diffusion probabilistic models – Param Hanji

我們很樂意收到您的來信。請隨時在評論區提問;我們非常樂意與您交談。

🌟快樂學習!

相關文章