DENOISING DIFFUSION IMPLICIT MODELS (DDIM)
從DDPM中我們知道,其擴散過程(前向過程、或加噪過程)被定義為一個馬爾可夫過程,其去噪過程(也有叫逆向過程)也是一個馬爾可夫過程。對馬爾可夫假設的依賴,導致重建每一步都需要依賴上一步的狀態,所以推理需要較多的步長。
DDPM中對於其逆向分佈的建模使用馬爾可夫假設,這樣做的目的是將式子中的未知項 \(q(x_t|x_{t-1},x_0)\),轉化成了已知項 \(q(x_t|x_{t-1})\),最後求出 \(q(x_{t-1}|x_t,x_0)\) 的分佈也是一個高斯分佈 \(\mathcal{N}(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))\)。
從DDPM的結論出發,我們不妨直接假設 \(q(x_{t-1}|x_t,x_0)\) 的分佈為高斯分佈,在不使用馬爾可夫假設的情況下,嘗試求解 \(q(x_{t-1}|x_t,x_0)\) 。
由 DDPM 中 \(q(x_{t-1}|x_t,x_0)\) 的分佈 \(\mathcal{N}(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))\) 可知,均值為 一個關於 \(x_t,x_0\) 的函式,方差為一個關於 \(t\) 的函式。
我們可以把 \(q(x_{t-1}|x_t,x_0)\) 設計成如下分佈:
這樣,只要求解出 \(a,b,\sigma_t\) 這三個待定係數,即可確定 \(q(x_{t-1}|x_t,x_0)\) 的分佈。
重引數化 \(q(x_{t-1}|x_t,x_0)\) :
假設訓練模型時輸入噪聲圖片的加噪引數與DDPM完全一致
由 \(q(x_t|x_{0}) := \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_{0},(1-\bar{\alpha}_t)I)\) :
代入 \(x_t\) 有:
又:
觀察係數可以得到方程組:
三個未知數 兩個方程,可以用 \(\sigma_t\) 表示 \(a,b\):
\(a, b\) 代入 \(q(x_{t-1}|x_t,x_0) := \mathcal{N}(x_{t-1}; a x_0 + b x_t,\sigma_t^2 I)\)
又
代入 \(x_0\) 有:
透過觀察 \(x_{t-1}\) 的分佈,我們建模取樣分佈為高斯分佈:
並且均值和方差也採用相似的形式:
其中 \(\epsilon_\theta(x_t,t)\) 為預測的噪聲。
此時,確定最佳化目標只需要 \(q(x_{t-1}|x_t,x_0)\) 和 \(p_\theta(x_{t-1}|x_t)\) 兩個分佈儘可能相似,使用KL散度來度量,則有:
恰好與DDPM的最佳化目標一致,所以我們可以直接複用DDPM訓練好的模型。
\(p_{\theta}\) 的取樣步驟則為:
令 \(\sigma_t=\eta \sqrt{\dfrac{(1-{\alpha}_{t})(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}}\)
當 \(\eta =1\) 時,前向過程為 Markovian ,取樣過程變為 DDPM 。
當 \(\eta =0\) 時,取樣過程為確定過程,此時的模型 稱為 隱機率模型(implicit probabilstic model)。
DDIM如何加速取樣:
在 DDPM 中,基於馬爾可夫鏈 \(t\) 與 \(t-1\) 是相鄰關係,例如 \(t=100\) 則 \(t-1=99\);
在 DDIM 中,\(t\) 與 \(t-1\) 只表示前後關係,例如 \(t=100\) 時,\(t-1\) 可以是 90 也可以是 80、70,只需保證 \(t-1 < t\) 即可。
此時構建的取樣子序列 \(\tau=[\tau_i,\tau_{i-1},\cdots,\tau_{1}] \ll [t,t-1,\cdots,1]\) 。
例如,原序列 \(\Tau=[100,99,98,\cdots,1]\),取樣子序列為 \(\tau=[100,90,80,\cdots,1]\) 。
DDIM 取樣公式為:
當 \(\eta= 0\) 時,DDIM 取樣公式為:
程式碼實現
訓練過程與 DDPM 一致,程式碼參考上一篇文章。取樣程式碼如下:
device = 'cuda'
torch.cuda.empty_cache()
model = Unet().to(device)
model.load_state_dict(torch.load('ddpm_T1000_l2_epochs_300.pth'))
model.eval()
image_size=96
epochs = 500
batch_size = 128
T=1000
betas = torch.linspace(0.0001, 0.02, T).to('cuda') # torch.Size([1000])
# 每隔20取樣一次
tau_index = list(reversed(range(0, T, 20))) #[980, 960, ..., 20, 0]
eta = 0.003
# train
alphas = 1 - betas # 0.9999 -> 0.98
alphas_cumprod = torch.cumprod(alphas, axis=0) # 0.9999 -> 0.0000
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1-alphas_cumprod)
def get_val_by_index(val, t, x_shape):
batch_t = t.shape[0]
out = val.gather(-1, t)
return out.reshape(batch_t, *((1,) * (len(x_shape) - 1))) # torch.Size([batch_t, 1, 1, 1])
def p_sample_ddim(model):
def step_denoise(model, x_tau_i, tau_i, tau_i_1):
sqrt_alphas_bar_tau_i = get_val_by_index(sqrt_alphas_cumprod, tau_i, x_tau_i.shape)
sqrt_alphas_bar_tau_i_1 = get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)
denoise = model(x_tau_i, tau_i)
if eta == 0:
sqrt_1_minus_alphas_bar_tau_i = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape)
sqrt_1_minus_alphas_bar_tau_i_1 = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i_1, x_tau_i.shape)
x_tau_i_1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * x_tau_i \
+ (sqrt_1_minus_alphas_bar_tau_i_1 - sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * sqrt_1_minus_alphas_bar_tau_i) \
* denoise
return x_tau_i_1
sigma = eta * torch.sqrt((1-get_val_by_index(alphas, tau_i, x_tau_i.shape)) * \
(1-get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)) / get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape))
noise_z = torch.randn_like(x_tau_i, device=x_tau_i.device)
# 整個式子由三部分組成
c1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * (x_tau_i - get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape) * denoise)
c2 = torch.sqrt(1 - get_val_by_index(alphas_cumprod, tau_i_1, x_tau_i.shape) - sigma) * denoise
c3 = sigma * noise_z
x_tau_i_1 = c1 + c2 + c3
return x_tau_i_1
img_pred = torch.randn((4, 3, image_size, image_size), device=device)
for k in range(0, len(tau_index)):
# print(tau_index)
# 因為 tau_index 是倒序的,tau_i = k, tau_i_1 = k+1,這裡不能弄反
tau_i_1 = torch.tensor([tau_index[k+1]], device=device, dtype=torch.long)
tau_i = torch.tensor([tau_index[k]], device=device, dtype=torch.long)
img_pred = step_denoise(model, img_pred, tau_i, tau_i_1)
torch.cuda.empty_cache()
if tau_index[k+1] == 0: return img_pred
return img_pred
with torch.no_grad():
img = p_sample_ddim(model)
img = torch.clamp(img, -1.0, 1.0)
show_img_batch(img.detach().cpu())
DDIM
https://arxiv.org/pdf/2010.02502
https://github.com/ermongroup/ddim