簡單介紹Pytorch實現WGAN用於動漫頭像生成
WGAN與GAN的不同
去除sigmoid
使用具有動量的最佳化方法,比如使用RMSProp
要對Discriminator的權重做修整限制以確保lipschitz連續約
WGAN實戰卷積生成動漫頭像
import torch import torch.nn as nn import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.utils import save_image import os from anime_face_generator.dataset import ImageDataset batch_size = 32 num_epoch = 100 z_dimension = 100 dir_path = './wgan_img' # 建立資料夾 if not os.path.exists(dir_path): os.mkdir(dir_path) def to_img(x): """因為我們在生成器裡面用了tanh""" out = 0.5 * (x + 1) return out dataset = ImageDataset() dataloader = DataLoader(dataset, batch_size=32, shuffle=False) class Generator(nn.Module): def __init__(self): super().__init__() self.gen = nn.Sequential( # 輸入是一個nz維度的噪聲,我們可以認為它是一個1*1*nz的feature map nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), # 上一步的輸出形狀:(512) x 4 x 4 nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), # 上一步的輸出形狀: (256) x 8 x 8 nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), # 上一步的輸出形狀: (256) x 16 x 16 nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True), # 上一步的輸出形狀:(256) x 32 x 32 nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False), nn.Tanh() # 輸出範圍 -1~1 故而採用Tanh # nn.Sigmoid() # 輸出形狀:3 x 96 x 96 ) def forward(self, x): x = self.gen(x) return x def weight_init(m): # weight_initialization: important for wgan class_name = m.__class__.__name__ if class_name.find('Conv') != -1: m.weight.data.normal_(0, 0.02) elif class_name.find('Norm') != -1: m.weight.data.normal_(1.0, 0.02) class Discriminator(nn.Module): def __init__(self): super().__init__() self.dis = nn.Sequential( nn.Conv2d(3, 64, 5, 3, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), # 輸出 (64) x 32 x 32 nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True), # 輸出 (128) x 16 x 16 nn.Conv2d(128, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True), # 輸出 (256) x 8 x 8 nn.Conv2d(256, 512, 4, 2, 1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True), # 輸出 (512) x 4 x 4 nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Flatten(), # nn.Sigmoid() # 輸出一個數(機率) ) def forward(self, x): x = self.dis(x) return x def weight_init(m): # weight_initialization: important for wgan class_name = m.__class__.__name__ if class_name.find('Conv') != -1: m.weight.data.normal_(0, 0.02) elif class_name.find('Norm') != -1: m.weight.data.normal_(1.0, 0.02) def save(model, filename="model.pt", out_dir="out/"): if model is not None: if not os.path.exists(out_dir): os.mkdir(out_dir) torch.save({'model': model.state_dict()}, out_dir + filename) else: print("[ERROR]:Please build a model!!!") import QuickModelBuilder as builder if __name__ == '__main__': one = torch.FloatTensor([1]).cuda() mone = -1 * one is_print = True # 建立物件 D = Discriminator() G = Generator() D.weight_init() G.weight_init() if torch.cuda.is_available(): D = D.cuda() G = G.cuda() lr = 2e-4 d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, ) g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, ) d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99) g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99) fake_img = None # ##########################進入訓練##判別器的判斷過程##################### for epoch in range(num_epoch): # 進行多個epoch的訓練 pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader)) for i, img in enumerate(dataloader): num_img = img.size(0) real_img = img.cuda() # 將tensor變成Variable放入計算圖中 # 這裡的最佳化器是D的最佳化器 for param in D.parameters(): param.requires_grad = True # ########判別器訓練train##################### # 分為兩部分:1、真的影像判別為真;2、假的影像判別為假 # 計算真實圖片的損失 d_optimizer.zero_grad() # 在反向傳播之前,先將梯度歸0 real_out = D(real_img) # 將真實圖片放入判別器中 d_loss_real = real_out.mean(0).view(1) d_loss_real.backward(one) # 計算生成圖片的損失 z = torch.randn(num_img, z_dimension).cuda() # 隨機生成一些噪聲 z = z.reshape(num_img, z_dimension, 1, 1) fake_img = G(z).detach() # 隨機噪聲放入生成網路中,生成一張假的圖片。 # 避免梯度傳到G,因為G不用更新, detach分離 fake_out = D(fake_img) # 判別器判斷假的圖片, d_loss_fake = fake_out.mean(0).view(1) d_loss_fake.backward(mone) d_loss = d_loss_fake - d_loss_real d_optimizer.step() # 更新引數 # 每次更新判別器的引數之後把它們的絕對值截斷到不超過一個固定常數c=0.01 for parm in D.parameters(): parm.data.clamp_(-0.01, 0.01) # ==================訓練生成器============================ # ###############################生成網路的訓練############################### for param in D.parameters(): param.requires_grad = False # 這裡的最佳化器是G的最佳化器,所以不需要凍結D的梯度,因為不是D的最佳化器,不會更新D g_optimizer.zero_grad() # 梯度歸0 z = torch.randn(num_img, z_dimension).cuda() z = z.reshape(num_img, z_dimension, 1, 1) fake_img = G(z) # 隨機噪聲輸入到生成器中,得到一副假的圖片 output = D(fake_img) # 經過判別器得到的結果 # g_loss = criterion(output, real_label) # 得到的假的圖片與真實的圖片的label的loss g_loss = torch.mean(output).view(1) # bp and optimize g_loss.backward(one) # 進行反向傳播 g_optimizer.step() # .step()一般用在反向傳播後面,用於更新生成網路的引數 # 列印中間的損失 pbar.set_right_info(d_loss=d_loss.data.item(), g_loss=g_loss.data.item(), real_scores=real_out.data.mean().item(), fake_scores=fake_out.data.mean().item(), ) pbar.update() try: fake_images = to_img(fake_img.cpu()) save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1)) except: pass if is_print: is_print = False real_images = to_img(real_img.cpu()) save_image(real_images, dir_path + '/real_images.png') pbar.finish() d_scheduler.step() g_scheduler.step() save(D, "wgan_D.pt") save(G, "wgan_G.pt")
到此這篇關於Pytorch實現WGAN用於動漫頭像生成的文章就介紹到這了。
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69901823/viewspace-2767502/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 教你基於MindSpore用DCGAN生成漫畫頭像
- 送書 | AI插畫師:如何用基於PyTorch的生成對抗網路生成動漫頭像?AIPyTorch
- 簡單介紹pytorch中log_softmax的實現PyTorch
- 簡單介紹C#呼叫USB攝像頭的方法C#
- node簡單實現一個更改頭像功能
- 簡單介紹numpy實現RNN原理實現RNN
- 從頭開始瞭解PyTorch的簡單實現PyTorch
- GAN網路之入門教程(五)之基於條件cGAN動漫頭像生成
- 簡單介紹NMS的實現方法
- 簡單介紹C#獲取攝像頭拍照顯示影像的方法C#
- javascript實現繼承方式簡單介紹JavaScript繼承
- javascript實現鏈式呼叫簡單介紹JavaScript
- javascript實現二維陣列實現簡單介紹JavaScript陣列
- jQuery實現的生成隨機密碼程式碼例項簡單介紹jQuery隨機密碼
- CameraPath實現簡單漫遊
- js實現的動態載入css檔案簡單介紹JSCSS
- 實現微信搖一搖功能簡單介紹
- 使用CORS實現ajax跨域簡單介紹CORS跨域
- 一個純前端實現的頭像生成網站前端網站
- 基於 GD 庫生成圓形頭像
- 簡單介紹android實現可以滑動的平滑曲線圖Android
- Lucene介紹及簡單應用
- 用if條件語句來實現瀏覽器相容簡單介紹瀏覽器
- 簡單介紹5個python的實用技巧Python
- xheditor文字編輯器的簡單實用介紹
- RPC模式的介紹以及簡單的實現RPC模式
- 簡單介紹SpringMVC RESTFul實現列表功能SpringMVCREST
- 簡單介紹Go 字串比較的實現示例Go字串
- 實現跨域iframe介面方法呼叫 簡單介紹跨域
- javascript模擬實現私有屬性簡單介紹JavaScript
- jquery實現的元素居中外掛簡單介紹jQuery
- javascript如何實現模組程式設計簡單介紹JavaScript程式設計
- 執行緒池的介紹及簡單實現執行緒
- 簡單介紹基於Redis的List實現特價商品列表功能Redis
- 關於 React Hooks 的簡單介紹ReactHook
- javascript變數作用於簡單介紹JavaScript變數
- java關於事件的簡單介紹Java事件
- 帶貨直播系統,實現簡單的換頭像並儲存