Denoising Diffusion Implicit Models (DDIM)

Published by 澳大利亚树袋熊 on 2024-10-18

DDPM has one very annoying problem: sampling requires many iterative steps and is therefore slow. Several methods have been proposed to address this, such as one-step diffusion models. One of the best-known and earliest is DDIM.

Original paper: https://arxiv.org/pdf/2010.02502

Reference blog post: https://zhuanlan.zhihu.com/p/666552214?utm_id=0

The training procedure is identical to DDPM's; only the inference procedure changes. This speeds up the diffusion (sampling) process and also makes the results somewhat more stable.

The DDIM assumption

The DM (DDPM) assumption:

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon,\qquad \varepsilon \sim \mathcal{N}(0, I)$$

DDIM proposes a new diffusion assumption of the form

$$x_{t-1} = \kappa\,x_0 + \lambda\,x_t + \sigma\,\varepsilon',\qquad \varepsilon' \sim \mathcal{N}(0, I)$$

Combining this with DDPM's original assumption and substituting $x_t$ directly into the new assumption gives:

$$x_{t-1} = \left(\kappa + \lambda\sqrt{\bar\alpha_t}\right)x_0 + \lambda\sqrt{1-\bar\alpha_t}\,\varepsilon + \sigma\,\varepsilon'$$

According to the original assumption, $x_{t-1}$ must also satisfy

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,x_0 + \sqrt{1-\bar\alpha_{t-1}}\,\bar\varepsilon$$

Matching the two expressions (the coefficient of $x_0$, and the total noise variance) gives:

$$\kappa + \lambda\sqrt{\bar\alpha_t} = \sqrt{\bar\alpha_{t-1}},\qquad \lambda^2\left(1-\bar\alpha_t\right) + \sigma^2 = 1-\bar\alpha_{t-1}$$

and solving:

$$\lambda = \sqrt{\frac{1-\bar\alpha_{t-1}-\sigma^2}{1-\bar\alpha_t}},\qquad \kappa = \sqrt{\bar\alpha_{t-1}} - \lambda\sqrt{\bar\alpha_t}$$

The DDIM assumption is thereby transformed into:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,x_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma^2}\cdot\frac{x_t - \sqrt{\bar\alpha_t}\,x_0}{\sqrt{1-\bar\alpha_t}} + \sigma\,\varepsilon'$$

$x_0$ can be eliminated using the diffusion-model assumption, $x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\varepsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}$, giving:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\cdot\frac{x_t - \sqrt{1-\bar\alpha_t}\,\varepsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}} + \sqrt{1-\bar\alpha_{t-1}-\sigma^2}\,\varepsilon_\theta(x_t, t) + \sigma\,\varepsilon'$$

Nothing in this derivation ties $t-1$ to the immediately preceding step, so of course you can jump across many steps at once:

$$x_{t'} = \sqrt{\bar\alpha_{t'}}\cdot\frac{x_t - \sqrt{1-\bar\alpha_t}\,\varepsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}} + \sqrt{1-\bar\alpha_{t'}-\sigma^2}\,\varepsilon_\theta(x_t, t) + \sigma\,\varepsilon',\qquad t' \ll t$$
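As a quick sanity check on the coefficient matching, the two constraints on $\kappa$ and $\lambda$ can be verified numerically. The schedule values below are made-up examples, not from any real model:

```python
import math

# Hypothetical schedule values, just to check the algebra
alpha_bar_t, alpha_bar_prev = 0.5, 0.7   # example values of ᾱ_t and ᾱ_{t-1}
sigma = 0.1                              # free standard deviation σ

# Coefficients derived above: x_{t-1} = κ·x0 + λ·x_t + σ·ε'
lam = math.sqrt((1 - alpha_bar_prev - sigma ** 2) / (1 - alpha_bar_t))
kappa = math.sqrt(alpha_bar_prev) - lam * math.sqrt(alpha_bar_t)

# Substituting x_t = √ᾱ_t·x0 + √(1-ᾱ_t)·ε must reproduce the DM assumption:
# the coefficient of x0 should be √ᾱ_{t-1} ...
assert abs(kappa + lam * math.sqrt(alpha_bar_t) - math.sqrt(alpha_bar_prev)) < 1e-12
# ... and the total noise variance should be 1-ᾱ_{t-1}
assert abs(lam ** 2 * (1 - alpha_bar_t) + sigma ** 2 - (1 - alpha_bar_prev)) < 1e-12
```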

The code is as follows (the DDPM sampling step is shown first for comparison, then the DDIM step):

#ddpm
def sample_backward_step(self, x_t, t, net, simple_var=True, isUnsqueeze=True):
    n = x_t.shape[0]
    # Build a batch of the current timestep t
    if isUnsqueeze:
        t_tensor = torch.tensor([t] * n, dtype=torch.long).to(x_t.device).unsqueeze(1)
    else:
        t_tensor = torch.tensor([t] * n, dtype=torch.long).to(x_t.device)
    eps = net(x_t, t_tensor)  # predicted noise ε_θ(x_t, t)

    # Posterior variance: either the simple choice β_t or the tilde-β_t form
    if simple_var:
        var = self.betas[t]
    else:
        var = (1 - self.alpha_bars[t - 1]) / (1 - self.alpha_bars[t]) * self.betas[t]
    noise = torch.randn_like(x_t) * torch.sqrt(var)

    # Posterior mean of x_{t-1} given x_t and the predicted noise
    mean = (x_t - (1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) * eps) \
           / torch.sqrt(self.alphas[t])
    return mean + noise

#ddim
def time_backward_step(self, x_t, t, net, sample_step=5, isUnsqueeze=True):
    n = x_t.shape[0]
    if isUnsqueeze:
        t_tensor = torch.tensor([t] * n, dtype=torch.long).to(x_t.device).unsqueeze(1)
    else:
        t_tensor = torch.tensor([t] * n, dtype=torch.long).to(x_t.device)
    eps = net(x_t, t_tensor)  # predicted noise ε_θ(x_t, t)

    # Estimate x_0 by inverting the forward assumption, then clamp to the data range
    x0_pred = torch.sqrt(1. / self.alpha_bars[t]) * x_t \
              - torch.sqrt(1. / self.alpha_bars[t] - 1) * eps
    x0_pred = torch.clamp(x0_pred, -1, 1)

    # Jump sample_step steps back in one update (deterministic DDIM, σ = 0)
    prev_t = max(t - sample_step, 0)
    dir_xt = torch.sqrt(1 - self.alpha_bars[prev_t]) * eps
    return torch.sqrt(self.alpha_bars[prev_t]) * x0_pred + dir_xt
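To show how such a step function is driven, here is a minimal, self-contained sampling-loop sketch. The linear β schedule, the toy parameters, and the function names below are illustrative assumptions rather than the article's actual training setup; a real trained noise-prediction network would replace `net`:

```python
import torch

# Minimal, self-contained DDIM sampling sketch (deterministic, σ = 0).
# The linear beta schedule below is an assumption for illustration.

def make_alpha_bars(n_steps=1000, beta_min=1e-4, beta_max=0.02):
    betas = torch.linspace(beta_min, beta_max, n_steps)
    return torch.cumprod(1 - betas, dim=0)  # ᾱ_t = ∏ (1 - β_i)

def ddim_sample(net, alpha_bars, shape, sample_step=20):
    n_steps = alpha_bars.shape[0]
    x = torch.randn(shape)  # start from pure Gaussian noise
    # Walk the timesteps with a stride of sample_step instead of one by one
    for t in range(n_steps - 1, 0, -sample_step):
        t_tensor = torch.full((shape[0],), t, dtype=torch.long)
        eps = net(x, t_tensor)
        # Estimate x_0, clamp it to the data range, then jump straight to prev_t
        x0 = torch.sqrt(1. / alpha_bars[t]) * x - torch.sqrt(1. / alpha_bars[t] - 1) * eps
        x0 = torch.clamp(x0, -1, 1)
        prev_t = max(t - sample_step, 0)
        x = torch.sqrt(alpha_bars[prev_t]) * x0 + torch.sqrt(1 - alpha_bars[prev_t]) * eps
    return x
```

With `sample_step=20` and 1000 training steps, the loop performs only 50 network evaluations instead of 1000, which is where DDIM's speedup comes from.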

DDIM result images

DDPM result images
