diffusion model(一):DDPM技術小結 (denoising diffusion probabilistic)

莫叶何竹發表於2024-05-24

釋出日期:2023/05/18
主頁地址:http://myhz0606.com/article/ddpm

1 從直覺上理解DDPM

在詳細推到公式之前,我們先從直覺上理解一下什麼是擴散

對於常規的生成模型,如GAN,VAE,它直接從噪聲資料生成影像,我們不妨記噪聲資料為\(z\),其生成的圖片為\(x\)

對於常規的生成模型

學習一個解碼函式(即我們需要學習的模型)p,實現 \(p(z)=x\)

\[z \stackrel{p} \longrightarrow x \tag{1} \]

常規方法只需要一次預測即能實現噪聲到目標的對映,雖然速度快,但是效果不穩定。

常規生成模型的訓練過程(以VAE為例)

\[x \stackrel{q} \longrightarrow z \stackrel{p} \longrightarrow \widehat{x} \tag{2} \]

對於diffusion model

它將噪聲到目標的過程進行了多步拆解。不妨假設一共有\(T+1\)個時間步,第\(T\)個時間步 \(x_T\)是噪聲資料,第0個時間步的輸出是目標圖片\(x_0\)。其過程可以表述為:

\[z = x_T \stackrel{p} \longrightarrow x_{T-1} \stackrel{p} \longrightarrow \cdots \stackrel{p} \longrightarrow x_{1} \stackrel{p} \longrightarrow x_0 \tag{3} \]

對於DDPM它採用的是一種自迴歸式的重建方法,每次的輸入是當前的時刻及當前時刻的噪聲圖片。也就是說它把噪聲到目標圖片的生成分成了T步,這樣每一次的預測相當於是對殘差的預測。優勢是重建效果穩定,但速度較慢。

訓練整體pipeline包含兩個過程

2 diffusion pipeline

2.1前置知識:

高斯分佈的一些性質

(1)如果\(X \sim \mathcal{N}(\mu, \sigma^2)\),且\(a\)\(b\)是實數,那麼\(aX+b \sim \mathcal{N}(a\mu+b, (a\sigma)^2)\)

(2)如果\(X \sim \mathcal{N}(\mu(x), \sigma^2(x)) ,Y \sim \mathcal{N}(\mu(y), \sigma^2(y)),\)\(X,Y\)是統計獨立的正態隨機變數,則它們的和也滿足高斯分佈(高斯分佈可加性).

\[X+Y \sim \mathcal{N}(\mu(x)+\mu{(y), \sigma^2(x) + \sigma^2(y)}) \\ X-Y \sim \mathcal{N}(\mu(x)-\mu{(y), \sigma^2(x) + \sigma^2(y)}) \tag{4} \]

均值為\(\mu\)方差為\(\sigma\)的高斯分佈的機率密度函式為

\[\begin{align*} f(x) &= \frac{1}{\sqrt{2\pi} \sigma } \exp \left ({- \frac{(x - \mu)^2}{2\sigma^2}} \right) \\ &= \frac{1}{\sqrt{2\pi} \sigma } \exp \left[ -\frac{1}{2} \left( \frac{1}{\sigma^2}x^2 - \frac{2\mu}{\sigma^2}x + \frac{\mu^2}{\sigma^2} \right ) \right] \tag{5} \end{align*} \]

2.2 加噪過程

1 前向過程:將圖片資料對映為噪聲

每一個時刻都要新增高斯噪聲,後一個時刻都是由前一個時刻加噪聲得到。(其實每一個時刻加的噪聲就是訓練所用的標籤)。即

\[x_0 \stackrel{q} \longrightarrow x_1 \stackrel{q} \longrightarrow x_{2} \stackrel{q} \longrightarrow \cdots \stackrel{q} \longrightarrow x_{T-1} \stackrel{q} \longrightarrow x_T=z \tag{6} \]

下面我們詳細來看

\(\beta_t = 1 - \alpha_t,\beta_t\)\(t\)的增加而增大(論文中[2]從0.0001 -> 0.02) (這是因為一開始加一點噪聲就很明顯,後面需要增大噪聲的量才明顯).DDPM將加噪聲過程建模為一個馬爾可夫過程\(q(x_{1:T}|x_0):= \prod \limits_{t=1}^Tq(x_t|x_{t-1}) ,\)其中\(q(x_t|x_{t-1}):=\mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, (1 - \alpha_t) \textbf{I})\)

\[\begin{align*} x_t &= \sqrt{\alpha_t}x_{t-1} + \sqrt{(1 - \alpha_t)}z_t \\ &= \sqrt{\alpha_t}x_{t-1} + \sqrt{\beta_t}z_t \tag{7} \end{align*} \]

\(x_t\)為在t時刻的圖片,當\(t=0\)時為原圖;\(z_t\)為在t時刻所加的噪聲,服從標準正態分佈\(z_t \sim \mathcal{N}(0, \textbf{I});\)\(\alpha_t\)是常數,是自己定義的變數;從上式可見,隨著\(T\)增大,\(x_t\)越來越接近純高斯分佈.

同理:

\[x_{t-1} = \sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1 - \alpha_{t-1}}z_{t-1} \tag{8} \]

將式(8)代入式(7)可得:

\[\begin{align*} x_t &= \sqrt{\alpha_t} (\sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1 - \alpha_{t-1}}z_{t-1}) + \sqrt{1 - \alpha_t}z_t \\ &= \sqrt{\alpha_t \alpha_{t-1}}x_{t-2} + (\sqrt{\alpha_t (1 - \alpha_{t-1})} z_{t-1} + \sqrt{1 - \alpha_t}z_t) \tag{9} \end{align*} \]

由於\(z_{t-1}\)服從均值為0,方差為1的高斯分佈(即標準正態分佈),根據定義\(\sqrt{\alpha_t (1 - \alpha_{t-1})} z_{t-1}\)服從的是均值為0,方差為\(\alpha_t (1 - \alpha_{t-1})\)的高斯分佈.即\(\sqrt{\alpha_t (1 - \alpha_{t-1})} z_{t-1} \sim \mathcal{N}(0, \alpha_t (1 - \alpha_{t-1})\textbf{I})\).同理可得\(\sqrt{1 - \alpha_t}z_t \sim \mathcal{N}(0, (1 - \alpha_t)\textbf{I})\).則(高斯分佈可加性,可以透過定義推得,不贅述)

\[ (\sqrt{\alpha_t (1 - \alpha_{t-1})} , z_{t-1} + \sqrt{1 - \alpha_t}z_t) \sim \mathcal{N}(0, \alpha_t (1 - \alpha_{t-1}) + 1 - \alpha_t) = \mathcal{N}(0, 1 - \alpha_t \alpha_{t-1}) \tag{10} \]

我們不妨記\(\overline{z}_{t-2} \sim \mathcal{N}(0, \textbf{I}),\)\(\sqrt{1 - \alpha_t \alpha_{t-1}} \overline{z}_{t-2} \sim \mathcal{N}(0, (1 - \alpha_t \alpha_{t-1})\textbf{I})\)則式(10)最終可改寫為

\[x_t = \sqrt{\alpha_t \alpha_{t-1}} x_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \overline{z}_{t-2} \tag{11} \]

透過遞推,容易得到

\[\begin{align*} x_t &= \sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1} x_0 + \sqrt{1 - \alpha_t \alpha_{t-1} \dots \alpha_1} \overline{z}_0 \\ &= \sqrt{\prod_{i=1}^{t} {\alpha_i}}x_0 + \sqrt{1 - \prod_{i=1}^{t} {\alpha_i}} \overline {z}_0 \\ &\stackrel{\mathrm{令} \overline{\alpha}_{t} = \prod_{i=1}^{t} {\alpha_i}} = \sqrt{\overline{\alpha}_{t}}x_0+\sqrt{1 - \overline{\alpha}_{t}}\overline{z}_{0} \tag{12} \end{align*} \]

其中\(\overline{z}_{0} \sim \mathcal{N}(0, \mathrm{I}),x_0\)為原圖.從式(13)可見,我們可以從\(x_0\)得到任意時刻的\(x_t\)的分佈,而無需按照時間順序遞推!這極大提升了計算效率.

\[\begin{align*} q(x_t|x_0) &= \mathcal{N}(x_t; \mu{(x_t, t)},\sigma^2{(x_t, t)}{}\textbf{I}) \\ &= \mathcal{N}(x_t; \sqrt{\overline{\alpha}_{t}}x_0,(1 - \overline{\alpha}_{t})\textbf{I}) \tag{13} \end{align*} \]

⚠️加噪過程是確定的,沒有模型的介入. 其目的是製作訓練時標籤

2.3 去噪過程

給定\(x_T\)如何求出\(x_0\)呢?直接求解是很難的,作者給出的方案是:我們可以一步一步求解.即學習一個解碼函式\(p\),這個\(p\)能夠知道\(x_{t}\)\(x_{t-1}\)的對映規則.如何定義這個\(p\)是問題的關鍵.有了\(p\),只需從\(x_{t}\)\(x_{t-1}\)逐步迭代,即可得出\(x_0\).

\[z = x_T \stackrel{p} \longrightarrow x_{T-1} \stackrel{p} \longrightarrow \cdots \stackrel{p} \longrightarrow x_{1} \stackrel{p} \longrightarrow x_0 \tag{14} \]

去噪過程是加噪過程的逆向.如果說加噪過程是求給定初始分佈\(x_0\)求任意時刻的分佈\(x_t\),即\(q(x_t|x_0)\)那麼去噪過程所求的分佈就是給定任意時刻的分佈\(x_t\)求其初始時刻的分佈\(x_0,\)\(p(x_0|x_t)\) ,透過馬爾可夫假設,可以對上述問題進行化簡

\[\begin{align*} p(x_0|x_t) &= p(x_0|x1)p(x1|x2)\cdots p(x_{t-1}| x_t) \\ &= \prod_{i=0}^{t-1}{p(x_i|x_{i+1})} \tag{15} \end{align*} \]

如何求\({p(x_{t-1}|x_{t})}\)呢?前面的加噪過程我們大力氣推到出了\({q(x_{t}|x_{t-1})},\)我們可以透過貝葉斯公式把它利用起來

\[p(x_{t-1}|x_t) = \frac{p(x_{t}|x_{t-1})p(x_{t-1})}{p(x_t)} \tag{16} \]

⚠️這裡的(去噪)\(p\)和上面的(加噪)\(q\)只是對分佈的一種符號記法。

有了式(17)還是一頭霧水,\(p(x_t)\)\(p(x_{t-1})\)都不知道啊!該怎麼辦呢?這就要藉助模型的威力了.下面來看如何構建我們的模型.

延續加噪過程的推導\(p(x_t|x_0)\)\(p(x_{t-1}|x_0)\)我們是可以知道的.因此若我們知道初始分佈\(x_0\),則

\[\begin{align*} p(x_{t-1}|x_t,x_0) &= \frac{p(x_{t}|x_{t-1}, x_0)p(x_{t-1}|x_0)}{p(x_t|x_0)} &(17) \\ &= \frac{\mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, (1 - \alpha_t) \textbf{I} ) \mathcal{N}(x_{t-1}; \sqrt{\overline{\alpha}_{t-1}}x_0,(1 - \overline{\alpha}_{t-1}) \textbf{I})} { \mathcal{N}(x_t; \sqrt{\overline{\alpha}_{t}}x_0,(1 - \overline{\alpha}_{t}) \textbf{I} )} &(18) \\ &\stackrel{將式(5)代入} \propto \frac{ \exp \left ({- \frac{(x_t - \sqrt{\alpha_t}x_{t-1} )^2}{2 (1 - \alpha_t)}} \right) \exp \left ({- \frac{(x_{t-1} - \sqrt{\overline{\alpha}_{t-1}}x_0 )^2}{2 (1 - \overline{\alpha}_{t-1})}} \right) } { \exp \left ({- \frac{(x_{t} - \sqrt{\overline{\alpha}_{t}}x_0 )^2}{2 (1 - \overline{\alpha}_{t})}} \right) } &(19) \\ &= \exp \left [-\frac{1}{2} \left ( \frac{(x_t - \sqrt{\alpha_t}x_{t-1} )^2}{1 - \alpha_t} + \frac{(x_{t-1} - \sqrt{\overline{\alpha}_{t-1}}x_0 )^2}{1 - \overline{\alpha}_{t-1}} - \frac{(x_{t} - \sqrt{\overline{\alpha}_{t}}x_0 )^2}{1 - \overline{\alpha}_{t}} \right) \right] &(20) \\ &= \exp \left [ -\frac{1}{2} \left( \left( \frac{\alpha_t}{1-\alpha_t} + \frac{1}{1 - \overline{\alpha}_{t-1}} \right)x^2_{t-1} - \left ( \frac{2\sqrt{\overline{\alpha_{t}}}}{1 - \alpha_t}x_t + \frac{2 \sqrt{\overline{\alpha}_{t-1}}} {1 - \overline{\alpha}_{t-1} }x_0 \right)x_{t-1} + C(x_t, x_0) \right) \right] &(21) \end{align*} \]

結合高斯分佈的定義(6)來看式(22),不難發現\(p(x_{t-1}|x_t,x_0)\)也是服從高斯分佈的.並且結合式(6)我們可以求出其方差和均值

⚠️式17做了一個近似\(p(x_t|x_{t-1}, x_0) =p(x_t| x_{t-1}),\)能做這個近似原因是一階馬爾科夫假設,當前時間點只依賴前一個時刻的時間點.

\[\begin{align*} \frac{1}{\sigma_2} &= \frac{\alpha_t}{1-\alpha_t} + \frac{1}{1 - \overline{\alpha}_{t-1}} &(22) \\ \frac{2\mu}{\sigma^2} &= \frac{2\sqrt{\overline{\alpha_{t}}}}{1 - \alpha_t}x_t + \frac{2 \sqrt{\overline{\alpha}_{t-1}}} {1 - \overline{\alpha}_{t-1} }x_0 &(23) \end{align*} \]

可以求得:

\[\begin{align*} \sigma^2 &= \frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_{t}} (1 - \alpha_t) \\ \mu &= \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{\sqrt{\overline{\alpha}_{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}_t}x_0 \tag{24} \end{align*} \]

透過上式,我們可得

\[p(x_{t-1}|x_t,x_0) = \mathcal{N}(x_{t-1}; \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{\sqrt{\overline{\alpha}_{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}_t}x_0 , (\frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_{t}} (1 - \alpha_t)) \textbf{I}) \tag{25} \]

該式是真實的條件分佈.我們目標是讓模型學到的條件分佈\(p_\theta(x_{t-1}|x_t)\)儘可能的接近真實的條件分佈\(p(x_{t-1}|x_t, x_0).\)從上式可以看到方差是個固定量,那麼我們要做的就是讓\(p(x_{t-1}|x_t, x_0)\)\(p_\theta(x_{t-1}|x_t)\)的均值儘可能的對齊,即

(這個結論也可以透過最小化上述兩個分佈的KL散度推得)

\[\mathrm{arg} \mathop{min}_\theta \parallel u(x_0, x_t), u_\theta(x_t, t) \parallel \tag{26} \]

下面的問題變為:如何構造\(u_\theta(x_t, t)\)來使我們的最佳化儘可能的簡單

我們注意到\(\mu(x_0, x_t)與\mu_\theta(x_t, t)\)都是關於\(x_t\)的函式,不妨讓他們的\(x_t\)保持一致,則可將\(\mu_\theta(x_t, t)\)寫成

\[\mu_\theta(x_t, t) = \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{\sqrt{\overline{\alpha}_{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}_t} f_\theta(x_t, t) \tag{27} \]

\(f_\theta(x_t, t)\)是我們需要訓練的模型.這樣對齊均值的問題就轉化成了: 給定\(x_t, t\)來預測原始圖片輸入\(x_0.\)根據上文的加噪過程,我們可以很容易製造訓練所需的資料對! (Dalle2的訓練採用的是這個方式).事情到這裡就結束了嗎?

DDPM作者表示直接從\(x_t\)\(x_0\)的預測資料跨度太大了,且效果一般.我們可以將式(12)做一下變形

\[\begin{align*} x_t &= \sqrt{\overline{\alpha}_{t}}x_0+\sqrt{1 - \overline{\alpha}_{t}}\overline{z}_{0} \\ x_0 &= \frac{1}{\sqrt{\overline{\alpha}_{t}}}(x_t - \sqrt{1 - \overline{\alpha}_{t}}\overline{z}_{0}) \tag{28} \end{align*} \]

代入到式(24)中

\[\begin{align*} \mu &= \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{\sqrt{\overline{\alpha}_{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}_t} \frac{1}{\sqrt{\overline{a}_{t}}}(x_t - \sqrt{1 - \overline{a}_{t}}\overline{z}_{0}) \\ &= \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{(1 - \alpha_t) }{1 - \overline{\alpha}_t} \frac{1}{\sqrt{\alpha}_{t}}(x_t - \sqrt{1 - \overline{\alpha}_{t}}\overline{z}_{0}) \\ &\stackrel{合併x_t} = \frac{\alpha_t(1 - \overline{\alpha}_{t-1}) + (1 - \alpha_t) }{\sqrt{\alpha}_t (1 - \overline{\alpha}_t)}x_t - \frac{\sqrt{1 - \overline{\alpha}_t}(1 - \alpha_t) }{\sqrt{\alpha_t}(1 - \overline{\alpha}_t)}\overline{z}_0 \\ &= \frac{1 - \overline{\alpha}_t}{\sqrt{\alpha}_t (1 - \overline{\alpha}_t)}x_t - \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}\overline{z}_0 \\ &= \frac{1}{\sqrt{\alpha}_t}x_t - \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}\overline{z}_0 \tag{29} \end{align*} \]

經過這次化簡,我們將\(\mu{(x_0, x_t)} \Rightarrow \mu{(x_t, \overline{z}_0)},\)其中\(\overline{z}_0 \sim \mathcal{N}(0, \textbf{I}),\)可以將式(29)轉變為

\[\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} x_t - \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}f_\theta(x_t, t) \tag{30} \]

此時對齊均值的問題就轉化成:給定\(x_t, t\)預測\(x_t\)加入的噪聲\(\overline{z}_0\), 也就是說我們的模型預測的是噪聲\(f_\theta{(x_t, t)} = \epsilon_{\theta}(x_t, t) \simeq \overline{z}_0\)

2.3.1 訓練與取樣過程

訓練的目標就是這所有時刻兩個噪聲的差異的期望越小越好(用MSE或L1-loss).

\[\mathbb{E}_{t \sim T } \parallel \epsilon - \epsilon_{\theta}(x_t, t)\parallel_2 ^2 \tag{31} \]

下圖為論文提供的訓練和取樣過程

image

2.3.2 取樣過程

透過以上討論,我們推匯出\(p_\theta(x_{t-1}|x_t)\)高斯分佈的均值和方差.\(p_\theta(x_{t-1}|x_t)=\mathcal{N}(x_{t-1}; \mu_{\theta}(x_t, t), \sigma^2(t) \textbf{I})\),根據文獻[1]從一個高斯分佈中取樣一個隨機變數可用一個重引數化技巧進行近似

\[\begin{align*} x_{t-1} &= \mu_{\theta}(x_t, t) + \sigma(t) \epsilon,其中 \epsilon \in \mathcal{N}(\epsilon; 0, \textbf{I}) \\ & = \frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1 - \alpha_t }{\sqrt{1 - \overline{\alpha}_t}}\epsilon_\theta(x_t, t)) + \sigma(t) \epsilon &\tag{32} \end{align*} \]

式(32)和論文給出的取樣遞推公式一致.

至此,已完成DDPM整體的pipeline.

還沒想明白的點,為什麼不能根據(7)的變形來進行取樣計算呢?

\[x_{t-1} = \frac{1}{\sqrt{\alpha_t}}x_t - \sqrt{\frac{1 - \alpha_t}{\alpha_t}} f_\theta(x_t, t) \tag{33} \]

3 從程式碼理解訓練&預測過程

3.1 訓練過程

參考程式碼倉庫: https://github.com/lucidrains/denoising-diffusion-pytorch/tree/main/denoising_diffusion_pytorch

已知項: 我們假定有一批N張圖片\(\{x_i |i=1, 2, \cdots, N\}\)

第一步: 隨機取樣K組成batch,如\(\mathrm{x\_start}= \{ x_k|k=1,2, \cdots, K \}, \mathrm{Shape}(\mathrm{x\_start}) = (K, C, H, W)\)

第二步: 隨機取樣一些時間步

t = torch.randint(0, self.num_timesteps, (b,), device=device).long()  # 隨機取樣時間步

第三步: 隨機取樣噪聲

noise = default(noise, lambda: torch.randn_like(x_start))  # 基於高斯分佈取樣噪聲

第四步: 計算\(\mathrm{x\_start}\)在所取樣的時間步的輸出\(x_T\)(即加噪聲).(根據公式12)

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)

betas = linear_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def q_sample(x_start, t, noise=None):
  """
  \begin{eqnarray}
    x_t &=& \sqrt{\alpha_t}x_{t-1} + \sqrt{(1 - \alpha_t)}z_t \nonumber \\
    &=&  \sqrt{\alpha_t}x_{t-1} + \sqrt{\beta_t}z_t
  \end{eqnarray}
  """
    return (
        extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )

x = q_sample(x_start = x_start, t = t, noise = noise)  # 這就是x0在時間步T的輸出

第五步: 預測噪聲.輸入\(x_T,t\)到噪聲預測模型,來預測此時的噪聲\(\hat{z}_t = \epsilon_\theta(x_T, t)\).論文用到的模型結構是Unet,與傳統Unet的輸入有所不同的是增加了一個時間步的輸入.

model_out = self.model(x, t, x_self_cond=None)  # 預測噪聲

這裡面有一個需要注意的點:模型是如何對時間步進行編碼並使用的

  • 首先會對時間步進行一個編碼,將其變為一個向量,以正弦編碼為例
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """
        Args:
          x (Tensor), shape like (B,)
        """
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

# 時間步的編碼pipeline如下,本質就是將一個常數對映為一個向量
self.time_mlp = nn.Sequential(
    SinusoidalPosEmb(dim),
    nn.Linear(fourier_dim, time_dim),
    nn.GELU(),
    nn.Linear(time_dim, time_dim)
)
  • 將時間步的embedding嵌入到Unet的block中,使模型能夠學習到時間步的資訊
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift  # 將時間向量一分為2,一份用於提升幅值,一份用於修改相位

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):

        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)

        h = self.block2(h)

        return h + self.res_conv(x)

第六步:計算損失,反向傳播.計算預測的噪聲與實際的噪聲的損失,損失函式可以是L1或mse

@property
    def loss_fn(self):
        if self.loss_type == 'l1':
            return F.l1_loss
        elif self.loss_type == 'l2':
            return F.mse_loss
        else:
            raise ValueError(f'invalid loss type {self.loss_type}')

透過不斷迭代上述6步即可完成模型的訓練

3.2取樣過程

第一步:隨機從高斯分佈取樣一張噪聲圖片,並給定取樣時間步

img = torch.randn(shape, device=device)

第二步: 根據預測的當前時間步的噪聲,透過公式計算當前時間步的均值和方差


  posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) # 式(24)x_0的係數
  posterior_mean_coef = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)  # 式(24) x_t的係數

  def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

  def q_posterior(self, x_start, x_t, t):
    posterior_mean = (
        extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
        extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )  # 求出此時的均值
    posterior_variance = extract(self.posterior_variance, t, x_t.shape)  # 求出此時的方差
    posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) # 對方差取對數,可能為了數值穩定性
    return posterior_mean, posterior_variance, posterior_log_variance_clipped

  def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
      preds = self.model_predictions(x, t, x_self_cond)  # 預測噪聲
      x_start = preds.pred_x_start  # 模型預測的是在x_t時間步噪聲,x_start是根據公式(12)求

      if clip_denoised:
          x_start.clamp_(-1., 1.)

      model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
      return model_mean, posterior_variance, posterior_log_variance, x_start

第三步: 根據公式(32)計算得到前一個時刻圖片\(x_{t-1}\)

  @torch.no_grad()
  def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
      b, *_, device = *x.shape, x.device
      batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
      model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised)  # 計算當前分佈的均值和方差
      noise = torch.randn_like(x) if t > 0 else 0. # 從高斯分佈取樣噪聲
      pred_img = model_mean + (0.5 * model_log_variance).exp() * noise  # 根據
      return pred_img, x_start

透過迭代以上三步,直至\(T=0\)完成取樣.

思考和討論

DDPM區別與傳統的VAE與GAN採用了一種新的正規化實現了更高質量的影像生成.但實踐發現,需要較大的取樣步數才能得到較好的生成結果.由於其取樣過程是一個馬爾可夫的推理過程,導致會有較大的耗時.後續工作如DDIM針對該特性做了最佳化,數十倍降低取樣所用時間。

參考文獻

[1] Understanding Diffusion Models: A Unified Perspective

[2] Denoising Diffusion Probabilistic Models

相關文章