一文詳解擴散模型:DDPM

京東雲開發者發表於2023-03-29
作者:京東零售 劉巖

擴散模型講解

前沿

人工智慧生成內容(AI Generated Content,AIGC)近年來成為了非常前沿的一個研究方向,生成模型目前有四個流派,分別是生成對抗網路(Generative Adversarial Models,GAN),變分自編碼器(Variance Auto-Encoder,VAE),標準化流模型(Normalization Flow, NF)以及這裡要介紹的擴散模型(Diffusion Models,DM)。擴散模型是受到熱力學中的一個分支,它的思想來源是非平衡熱力學(Non-equilibrium thermodynamics)。擴散模型的演算法理論基礎是透過變分推斷(Variational Inference)訓練引數化的馬爾可夫鏈(Markov Chain),它在許多工上展現了超過GAN等其它生成模型的效果,例如最近非常火熱的OpenAI的DALL-E 2,Stability.ai的Stable Diffusion等。這些效果驚豔的模型擴散模型的理論基礎便是我們這裡要介紹的提出擴散模型的文章[1]和非常重要的DDPM[2],擴散模型的實現並不複雜,但其背後的數學原理卻非常豐富。在這裡我會介紹這些重要的數學原理,但省去了這些公式的推導計算,如果你對這些推導感興趣,可以學習參考文獻[4,5,11]的相關內容。我在這裡主要以一個相對簡單的角度來講解擴散模型,幫助你快速入門這個非常重要的生成演算法。

1. 背景知識: 生成模型

目前生成模型主要有圖1所示的四類。其中GAN的原理是透過判別器和生成器的互相博弈來讓生成器生成足以以假亂真的影像。VAE的原理是透過一個編碼器將輸入影像編碼成特徵向量,它用來學習高斯分佈的均值和方差,而解碼器則可以將特徵向量轉化為生成影像,它側重於學習生成能力。流模型是從一個簡單的分佈開始,透過一系列可逆的轉換函式將分佈轉化成目標分佈。擴散模型先透過正向過程將噪聲逐漸加入到資料中,然後透過反向過程預測每一步加入的噪聲,透過將噪聲去掉的方式逐漸還原得到無噪聲的影像,擴散模型本質上是一個馬爾可夫架構,只是其中訓練過程用到了深度學習的BP,但它更屬於數學層面的創新。這也就是為什麼很多計算機的同學看擴散模型相關的論文會如此費力。

DDPM_1.png

圖1:生成模型的四種型別 [4]

擴散模型中最重要的思想根基是馬爾可夫鏈,它的一個關鍵性質是平穩性。即如果一個機率隨時間變化,那麼再馬爾可夫鏈的作用下,它會趨向於某種平穩分佈,時間越長,分佈越平穩。如圖2所示,當你向一滴水中滴入一滴顏料時,無論你滴在什麼位置,只要時間足夠長,最終顏料都會均勻的分佈在水溶液中。這也就是擴散模型的前向過程。

DDPM_2.png

圖2:顏料分子在水溶液中的擴散過程

如果我們能夠在擴散的過程顏料分子的位置、移動速度、方向等移動屬性。那麼也可以根據正向過程的儲存的移動屬性從一杯被溶解了顏料的水中反推顏料的滴入位置。這邊是擴散模型的反向過程。記錄移動屬性的快照便是我們要訓練的模型。

2. 擴散模型

在這一部分我們將集中介紹擴散模型的數學原理以及推導的幾個重要性質,因為推導過程涉及大量的數學知識但是對理解擴散模型本身思想並無太大幫助,所以這裡我會省去推導的過程而直接給出結論。但是我也會給出推導過程的出處,對其中的推導過程比較感興趣的請自行檢視。

2.1 計算原理

擴散模型簡單的講就是透過神經網路學習從純噪聲資料逐漸對資料進行去噪的過程,它包含兩個步驟,如圖3:

DDPM_3.png

圖3:DDPM的前向加噪和後向去噪過程

2.1.1 前向過程

2.1.2 後向過程

2.1.3 目標函式

那麼問題來了,我們究竟使用什麼樣的最佳化目標才能比較好的預測高斯噪聲的分佈呢?一個比較複雜的方式是使用變分自編碼器的最大化證據下界(Evidence Lower Bound, ELBO)的思想來推導,如式(6),推導詳細過程見論文[11]的式(47)到式(58),這裡主要用到了貝葉斯定理和琴生不等式。

式(6)的推導細節並不重要,我們需要重點關注的是它的最終等式的三個組成部分,下面我們分別介紹它們:

DDPM_4.png

圖4:擴散模型的去噪匹配項在每一步都要擬合噪音的真實後驗分佈和估計分佈

真實後驗分佈可以使用貝葉斯定理進行推導,最終結果如式(8),推導過程見論文[11]的式(71)到式(84)。

\(p{\boldsymbol{\theta}}\left(\boldsymbol{x}{t-1} \mid \boldsymbol{x}t\right) = \mathcal N(\boldsymbol x{t-1}; \mu\theta(\boldsymbol x\_t, t), \Sigma\_q(t)) \tag9\)

2.1.4 模型訓練

雖然上面我們介紹了很多內容,並給出了大量公式,但得益於推匯出的幾個重要性質,擴散模型的訓練並不複雜,它的訓練虛擬碼見演算法1。

DDPM_ag1.png

2.1.5 樣本生成

DDPM_ag2.png

2.2 演算法實現

2.2.1模型結構

DDPM在預測施加的噪聲時,它的輸入是施加噪聲之後的影像,預測內容是和輸入影像相同尺寸的噪聲,所以它可以看做一個Img2Img的任務。DDPM選擇了U-Net[9]作為噪聲預測的模型結構。U-Net是一個U形的網路結構,它由編碼器,解碼器以及編碼器和解碼器之間的跨層連線(殘差連線)組成。其中編碼器將影像降取樣成一個特徵,解碼器將這個特徵上取樣為目標噪聲,跨層連線用於拼接編碼器和解碼器之間的特徵。

DDPM_5.png

圖5:U-Net的網路結構

下面我們介紹DDPM的模型結構的重要元件。首先在U-Net的卷積部分,DDPM使用了寬殘差網路(Wide Residual Network,WRN)[12]作為核心結構,WRN是一個比標準殘差網路層數更少,但是通道數更多的網路結構。也有作者復現發現ConvNeXT作為基礎結構會取得非常顯著的效果提升[13,14]。這裡我們可以根據訓練資源靈活的調整卷積結構以及具體的層數等超參。因為我們在擴散過程的整個流程中都共享同一套引數,為了區分不同的時間片,作者借鑑了Transformer [15]的位置編碼的思想,採用了正弦位置嵌入對時間$t$進行了編碼,這使得模型在預測噪聲時知道它預測的是批次中分別是哪個時間片新增的噪聲。在卷積層之間,DDPM新增了一個注意力層。這裡我們可以使用Transformer中提出的自注意力機制或是多頭自注意力機制。[13]則提出了一個線性注意力機制的模組,它的特點是消耗的時間以及佔用的記憶體和序列長度是線性相關的,對比傳統注意力機制的平方相關要高效很多。在進行歸一化時,DDPM選擇了組歸一化(Group Normalization,GN)[16]。最後,對於U-Net中的降取樣和上取樣操作,DDPM分別選擇了步長為2的卷積以及反摺積。

確定了這些元件,我們便可以搭建用於DDPM的U-Net的模型了。從第2.1節的介紹我們知道,模型的輸入為形狀為(batch\_size, num\_channels, height, width)的噪聲影像和形狀為(batch\_size,1)的噪聲水平,返回的是形狀為(batch\_size, num_channels, height, width)的預測噪聲,我們搭建的用於噪聲預測的模型結構如下:

  1. 首先在噪聲影像\( \boldsymbol x_0\)上應用卷積層,併為噪聲水平$t$計算時間嵌入;
  2. 接下來是降取樣階段。採用的模型結構依次是兩個卷積(WRNS或是ConvNeXT)+GN+Attention+降取樣層;
  3. 在網路的最中間,依次是卷積層+Attention+卷積層;
  4. 接下來是上取樣階段。它首先會使用Short-cut拼接來自降取樣中同樣尺寸的卷積,再之後是兩個卷積+GN+Attention+上取樣層。
  5. 最後是使用WRNS或是ConvNeXT作為輸出層的卷積。

U-Net類的forword函式如下面程式碼片段所示,完整的實現程式碼參照[3]。

def forward(self, x, time):
    x = self.init_conv(x)
    t = self.time_mlp(time) if exists(self.time_mlp) else None
    h = []
    # downsample
    for block1, block2, attn, downsample in self.downs:
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)
        h.append(x)
        x = downsample(x)
    # bottleneck
    x = self.mid_block1(x, t)
    x = self.mid_attn(x)
    x = self.mid_block2(x, t)
    # upsample
    for block1, block2, attn, upsample in self.ups:
        x = torch.cat((x, h.pop()), dim=1)
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)
        x = upsample(x)
    return self.final_conv(x)



2.2.2 前向加噪

DDPM_6.png

圖6:一張圖依次經過0次,50次,100次,150次以及199次加噪後的效果圖

根據式(14)我們知道,擴散模型的損失函式計算的是兩張影像的相似性,因此我們可以選擇使用迴歸演算法的所有損失函式,以MSE為例,前向過程的核心程式碼如下面程式碼片段。

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
       # 1. 根據時刻t計算隨機噪聲分佈,並對影像x_start進行加噪
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    # 2. 根據噪聲影像以及時刻t,預測新增的噪聲
    predicted_noise = denoise_model(x_noisy, t)
    # 3. 對比新增的噪聲和預測的噪聲的相似性
    loss = F.mse_loss(noise, predicted_noise)
    return loss



2.2.3 樣本生成

根據2.1.5節介紹的樣本生成流程,它的核心程式碼片段所示,關於這段程式碼的講解我透過註釋新增到了程式碼片段中。

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    # 使用式(13)計算模型的均值
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        return model_mean
    else:
          # 獲取儲存的方差
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # 演算法2的第4行
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# 演算法2的流程,但是我們儲存了所有中間樣本
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device
    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs



最後我們看下在人臉影像資料集下訓練的模型,一批隨機噪聲經過逐漸去噪變成人臉影像的示例。

DDPM_7.gif

圖7:擴散模型由隨機噪聲透過去噪逐漸生成人臉影像

3. 總結

這裡我們以DDPM為例介紹了另一個派系的生成演算法:擴散模型。擴散模型是一個基於馬爾可夫鏈的數學模型,它透過預測每個時間片新增的噪聲來進行模型的訓練。作為近日來引發熱烈討論的ControlNet, Stable Diffusion等模型的底層演算法,我們十分有必要對其有所瞭解。DDPM的實現並不複雜,這得益於大量數學界大佬透過大量的數學推導將整個擴散過程和反向去噪過程進行了精彩的化簡,這才有了DDPM的大道至簡的實現。DDPM作為一個擴散模型的基石演算法,它有著很多早期演算法的共同問題:

  1. 取樣速度慢:DDPM的去噪是從時刻$T$到時刻$1$的一個完整的馬爾可夫鏈的計算,尤其是DDPM還需要一個比較大的$T$才能保證比較好的效果,這就導致了DDPM的取樣過程註定是非常慢的;
  2. 生成效果差:DDPM的效果並不能說是非常好,尤其是對於高解析度影像的生成。這一方面是因為它的計算速度限制了它擴充套件到更大的模型;另一方面它的設計還有一些問題,例如逐畫素的計算損失並使用相同權值而忽略影像中的主體並不是非常好的策略。
  3. 內容不可控:我們可以看出,DDPM生成的內容完全還是取決於它的訓練集。它並沒有引入一些先驗條件,因此並不能透過控制影像中的細節來生成我們制定的內容。

我們現在已經知道,DDPM的這些問題已大幅得到改善,現在基於擴散模型生成的影像已經達到甚至超過人類多數的畫師的效果,我也會在之後逐漸給出這些最佳化方案的講解。

Reference

[1] Sohl-Dickstein, Jascha, et al. "Deep unsupervised learning using nonequilibrium thermodynamics." _International Conference on Machine Learning_. PMLR, 2015.

[2] Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.

[3] https://huggingface.co/blog/annotated-diffusion

[4] https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#simplification

[5] https://openai.com/blog/generative-models/

[6] Nichol, Alexander Quinn, and Prafulla Dhariwal. "Improved denoising diffusion probabilistic models." _International Conference on Machine Learning_. PMLR, 2021.

[7] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).

[8] Hinton, Geoffrey E., and Ruslan R. Salakhutdinov. "Reducing the dimensionality of data with neural networks." science 313.5786 (2006): 504-507.

[9] Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation[C]//International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015: 234-241.

[10] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for semantic segmentation." _Proceedings of the IEEE conference on computer vision and pattern recognition_. 2015.

[11] Luo, Calvin. "Understanding diffusion models: A unified perspective." arXiv preprint arXiv:2208.11970 (2022).

[12] Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv preprint arXiv:1605.07146 (2016).

[13] https://github.com/lucidrains/denoising-diffusion-pytorch

[14] Liu, Zhuang, et al. "A convnet for the 2020s." _Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition_. 2022.

[15] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).

[16] Wu, Yuxin, and Kaiming He. "Group normalization." _Proceedings of the European conference on computer vision (ECCV)_. 2018.

相關文章