pytorch訓練GAN時的detach()

凌逆戰發表於2020-11-09

  我最近在學使用Pytorch寫GAN程式碼,發現有些程式碼在訓練部分細節有略微不同,其中有的人用到了detach()函式截斷梯度流,有的人沒用detch(),取而代之的是在損失函式在反向傳播過程中將backward(retain_graph=True),本文通過兩個 gan 的程式碼,介紹它們的作用,並分析,不同的更新策略對程式效率的影響。

  這兩個 GAN 的實現中,有兩種不同的訓練策略:

  • 先訓練判別器(discriminator),再訓練生成器(generator),這是原始論文Generative Adversarial Networks 中的演算法
  • 先訓練generator,再訓練discriminator

  為了減少網路垃圾,GAN的原理網上一大堆,我這裡就不重複贅述了,想要詳細瞭解GAN原理的朋友,可以參考我專題文章:神經網路結構:生成式對抗網路(GAN)

需要了解的知識:

  detach():截斷node反向傳播的梯度流,將某個node變成不需要梯度的Varibale,因此當反向傳播經過這個node時,梯度就不會從這個node往前面傳播

更新策略

  我們直接下面進入本文正題,即,在 pytorch 中,detach 和 retain_graph 是幹什麼用的?本文將藉助三段 GAN 的實現程式碼,來舉例介紹它們的作用。

先訓練判別器,再訓練生成器

策略一

我們分析迴圈中一個 step 的程式碼:

valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device)  # 真實標籤,都是1
fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device)  # 假標籤,都是0

# ########################
#        訓練判別器       #
# ########################
real_imgs = imgs.to(device)     # 真實圖片
z = torch.randn((imgs.shape[0], 100)).to(device)  # 噪聲

gen_imgs = generator(z)  # 從噪聲中生成假資料
pred_gen = discriminator(gen_imgs)  # 判別器對假資料的輸出
pred_real = discriminator(real_imgs)  # 判別器對真資料的輸出

optimizer_D.zero_grad()  # 把判別器中所有引數的梯度歸零
real_loss = adversarial_loss(pred_real, valid)  # 判別器對真實樣本的損失
fake_loss = adversarial_loss(pred_gen, fake)  # 判別器對假樣本的損失
d_loss = (real_loss + fake_loss) / 2  # 兩項損失相加取平均

# 下面這行程式碼十分重要,將在正文著重講解
d_loss.backward(retain_graph=True)  # retain_graph=True 十分重要,否則計算圖記憶體將會被釋放
optimizer_D.step()  # 判別器引數更新

# ########################
#        訓練生成器       #
# ########################
g_loss = adversarial_loss(pred_gen, valid)  # 生成器的損失函式
optimizer_G.zero_grad()  # 生成器引數梯度歸零
g_loss.backward()  # 生成器的損失函式梯度反向傳播
optimizer_G.step()  # 生成器引數更新

程式碼講解

  鑑別器的損失函式d_loss是由real_loss和fake_loss組成的,而fake_loss又是noise經過generator來的。這樣一來我們對d_loss進行反向傳播,不僅會計算discriminator 的梯度還會計算generator 的梯度(雖然這一步optimizer_D.step()只更新 discriminator 的引數),因此下面在更新generator引數時,要先將generator引數的梯度清零,避免受到discriminator loss 回傳過來的梯度影響。

  generator 的 損失在回傳時,同樣要經過 discriminator 網路才能傳遞迴自身(系統從輸入噪聲到 Discriminator 輸出,從頭到尾只有一次前向傳播,而有兩次反向傳播,故在第一次反向傳播時,鑑別器要設定 backward(retain graph=True),保持計算圖不被釋放。因為 pytorch 預設一個計算圖只計算一次反向傳播,反向傳播後,這個計算圖的記憶體就會被釋放,所以用這個引數控制計算圖不被釋放。因此,在回傳梯度時,同樣也計算了一遍 discriminator 的引數梯度,只不過這次 discriminator 的引數不更新,只更新 generator 的引數,即 optimizer_G.step()。同時,我們看到,下一個 step 首先將 discriminator 的梯度重置為 0,就是為了防止 generator loss 反向傳播時順帶計算的梯度對其造成影響(還有上一步 discriminator loss 回傳時累積的梯度)。

  綜上,我們看到,為了完成一步引數更新,我們進行了兩次反向傳播,第一次反向傳播為了更新 discriminator 的引數,但多餘計算了 generator 的梯度。第二次反向傳播為了更新 generator 的引數,但是計算了 discriminator 的梯度,因此在寫一個step,需要立即清零discriminator梯度。

  如果你實在看不懂,就照著這個形式寫程式碼就行了,反正形式都幫你們寫好了

策略二

  這種策略我遇到的比較多,也是先訓練鑑別器,再訓練生成器

  鑑別器訓練階段,noise 從 generator 輸入,輸出 fake data,然後 detach 一下,隨著 true data 一起輸入 discriminator,計算 discriminator 損失,並更新 discriminator 引數。生成器訓練階段,把沒經過 detach 的 fake data 輸入到discriminator 中,計算 generator loss,再反向傳播梯度,更新 generator 的引數。這種策略,計算了兩次 discriminator 梯度,一次 generator 梯度。感覺這種比較符合先更新 discriminator 的習慣。缺點是,之前的 generator 生成的計算圖得保留著,直到 discriminator 更新完,再釋放。

valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device)  # 真實標籤,都是1
fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device)  # 假標籤,都是0

# ########################
#        訓練判別器       #
# ########################
real_imgs = imgs.to(device)     # 真實圖片
z = torch.randn((imgs.shape[0], 100)).to(device)  # 噪聲

gen_imgs = generator(z)  # 從噪聲中生成假資料
pred_gen = discriminator(gen_imgs.detach())  # 假資料detach(),判別器對假資料的輸出
pred_real = discriminator(real_imgs)  # 判別器對真資料的輸出

optimizer_D.zero_grad()  # 把判別器中所有引數的梯度歸零
real_loss = adversarial_loss(pred_real, valid)  # 判別器對真實樣本的損失
fake_loss = adversarial_loss(pred_gen, fake)  # 判別器對假樣本的損失
d_loss = (real_loss + fake_loss) / 2  # 兩項損失相加取平均

# 下面這行程式碼十分重要,將在正文著重講解
d_loss.backward()  # retain_graph=True 十分重要,否則計算圖記憶體將會被釋放
optimizer_D.step()  # 判別器引數更新

# ########################
#        訓練生成器       #
# ########################
g_loss = adversarial_loss(pred_gen, valid)  # 生成器的損失函式
optimizer_G.zero_grad()  # 生成器引數梯度歸零
g_loss.backward()  # 生成器的損失函式梯度反向傳播
optimizer_G.step()  # 生成器引數更新

先訓練生成器,再訓練判別器

 我們分析迴圈中一個 step 的程式碼:

valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  # 真實樣本的標籤,都是 1
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)  # 生成樣本的標籤,都是 0
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))  # 噪聲
real_imgs = Variable(imgs.type(Tensor))     # 真實圖片

# ########################
#        訓練生成器       #
# ########################
optimizer_G.zero_grad()  # 生成器引數梯度歸零
gen_imgs = generator(z)  # 根據噪聲生成虛假樣本
g_loss = adversarial_loss(discriminator(gen_imgs), valid)  # 用真實的標籤+假樣本,計算生成器損失
g_loss.backward()  # 生成器梯度反向傳播,反向傳播經過了判別器,故此時判別器引數也有梯度
optimizer_G.step()  # 生成器引數更新,判別器引數雖然有梯度,但是這一步不能更新判別器

# ########################
#        訓練判別器       #
# ########################
optimizer_D.zero_grad()  # 把生成器損失函式梯度反向傳播時,順帶計算的判別器引數梯度清空
real_loss = adversarial_loss(discriminator(real_imgs), valid)  # 真樣本+真標籤:判別器損失
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)  # 假樣本+假標籤:判別器損失
d_loss = (real_loss + fake_loss) / 2  # 判別器總的損失函式
d_loss.backward()  # 判別器損失回傳
optimizer_D.step()  # 判別器引數更新

  為了更新生成器引數,用生成器的損失函式計算梯度,然後反向傳播,傳播圖中經過了判別器,根據鏈式法則,不得不順帶計算一下判別器的引數梯度,雖然在這一步不會更新判別器引數。反向傳播過後,noise 到 fake image 再到 discriminator 的輸出這個前向傳播的計算圖就被釋放掉了,後面也不會再用到。

  接著更新判別器引數,此時注意到,我們輸入判別器的是兩部分,一部分是真實資料,另一部分是生成器的輸出,也就是假資料。注意觀察細節,在判別器前向傳播過程,輸入的假資料被 detach 了,detach 的意思是,這個資料和生成它的計算圖“脫鉤”了,即梯度傳到它那個地方就停了,不再繼續往前傳播(實際上也不會再往前傳播了,因為 generator 的計算圖在第一次反向傳播過後就被釋放了)。因此,判別器梯度反向傳播,就到它自己身上為止。

  因此,比起第一種策略,這種策略要少計算一次 generator 的所有引數的梯度,同時,也不必刻意儲存一次計算圖,佔用不必要的記憶體。

  但需要注意的是,在第一種策略中,noise 從 generator 輸入,到 discriminator 輸出,只經歷了一次前向傳播,discriminator 端的輸出,被用了兩次,一次是計算 discriminator 的損失函式,另一次是計算 generator 的損失函式。

  而在第這種策略中,noise 從 generator 輸入,到discriminator 輸出,計算 generator 損失,回傳,這一步更新了 generator 的引數,並釋放了計算圖。下一步更新 discriminator 的引數時,generator 的輸出經過 detach 後,又通過了一遍 discriminator,相當於,generator 的輸出前後兩次通過了 discriminator ,得到相同的輸出。顯然,這也是冗餘的。

總結

綜上,這兩段程式碼各有利弊:

  第一段程式碼,好處是 noise 只進行了一次前向傳播,缺點是,更新 discriminator 引數時,多計算了一次 generator 的梯度,同時,第一次更新 discriminator 需要保留計算圖,保證算 generator loss 時計算圖不被銷燬。

  第三段程式碼,好處是通過先更新 generator ,使更新後的前向傳播計算圖可以放心被銷燬,因此不用保留計算圖佔用記憶體。同時,在更新 discriminator 的時候,也不會像上面的那段程式碼,計算冗餘的 generator 的梯度。缺點是,在 discriminator 上,對 generator 的輸出算了兩次前向傳播,第二次又產生了新的計算圖(但比第一次的小)。

一個多計算了一次 generator 梯度,一個多計算一次 discriminator 前向傳播。因此,兩者差別不大。如果 discriminator 比generator 複雜,那麼應該採取第一種策略,如果 discriminator 比 generator 簡單,那麼應該採取第三種策略,通常情況下,discriminator 要比 generator 簡單,故如果效果差不多儘量採取第三種策略。

  但是第三種先更新generator,再更新 discriminator 總是給人感覺怪怪得,因為 generator 的更新需要 discriminator 提供準確的 loss 和 gradient,否則豈不是在瞎更新?

  但是策略三,馬上用完馬上釋放。綜合來說,還是策略三最好,策略二其次,策略一最差(差在多計算一次 generator gradient 上,而通常多計算一次 generator gradient 的運算量比多計算一次 discriminator 前向傳播的運算量大),因此,detach 還是很有必要的。

參考

Pytorch: detach 和 retain_graph

使用PyTorch進行GAN訓練時對於梯度截斷的思考.detach()

相關文章