Based on the VAE theory introduced above, here is a simple VAE implementation that generates faces. The code is as follows:
utils.py
import os
import glob

import cv2
from torch.utils.data import Dataset
from torchvision.transforms import transforms


class MyDataset(Dataset):
    def __init__(self, img_path, device):
        super(MyDataset, self).__init__()
        self.device = device
        self.fnames = glob.glob(os.path.join(img_path, "*.jpg"))
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        # OpenCV loads BGR; convert to RGB so saved samples have correct colors
        img = cv2.imread(fname, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)
        # moving to the device inside __getitem__ only works with num_workers=0
        img = img.to(self.device)
        return img

    def __len__(self):
        return len(self.fnames)
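The whole pipeline assumes 96×96 face crops; if your images have a different size, a transforms.Resize(96) would need to be added to the Compose above. A quick sanity-check sketch, assuming the same ../faces/ directory the training script below uses:

import torch
from utils import MyDataset

dataset = MyDataset(img_path="../faces/", device=torch.device("cpu"))
img = dataset[0]
# expect: dataset size, torch.Size([3, 96, 96]), values in [0.0, 1.0]
print(len(dataset), img.shape, img.min().item(), img.max().item())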
VAE.py
import torch
import torch.nn as nn


class VAE(nn.Module):
    def __init__(self, image_size: int, in_channels: int, latent_dim: int, hid_dims: list = None):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        if not hid_dims:
            hid_dims = [32, 64, 128, 256]
        # every stride-2 conv halves the feature map: 96 -> 48 -> 24 -> 12 -> 6
        feature_size = image_size // (2 ** len(hid_dims))
        self.feature_size = feature_size
        self.last_hid = hid_dims[-1]

        # encoder
        modules = []
        for h_d in hid_dims:
            modules.append(nn.Sequential(
                nn.Conv2d(in_channels, h_d, 3, 2, 1),
                nn.BatchNorm2d(h_d),
                nn.LeakyReLU()))
            in_channels = h_d
        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hid_dims[-1] * feature_size ** 2, latent_dim)
        self.fc_var = nn.Linear(hid_dims[-1] * feature_size ** 2, latent_dim)

        # decoder
        self.decoder_input = nn.Linear(latent_dim, hid_dims[-1] * feature_size ** 2)
        hid_dims.reverse()
        modules = []
        for i in range(len(hid_dims) - 1):
            modules.append(nn.Sequential(
                nn.ConvTranspose2d(hid_dims[i], hid_dims[i + 1], 3, 2, 1, 1),
                nn.BatchNorm2d(hid_dims[i + 1]),
                nn.LeakyReLU()))
        self.decoder = nn.Sequential(*modules)
        self.decoder_out = nn.Sequential(
            nn.ConvTranspose2d(hid_dims[-1], hid_dims[-1], 3, 2, 1, 1),
            nn.BatchNorm2d(hid_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hid_dims[-1], 3, 3, 1, 1),
            nn.Sigmoid())

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)  # interpreted as log(sigma^2)
        return mu, log_var

    def decode(self, x):
        x = self.decoder_input(x)
        # reshape to the encoder's final feature map instead of hard-coding (256, 6, 6)
        x = x.view(-1, self.last_hid, self.feature_size, self.feature_size)
        x = self.decoder(x)
        x = self.decoder_out(x)
        return x

    def re_parameterize(self, mu, log_var):
        # z = mu + sigma * eps, eps ~ N(0, I), keeps sampling differentiable
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.re_parameterize(mu, log_var)
        out = self.decode(z)
        return out, mu, log_var

    def sample(self, n_samples, device):
        z = torch.randn((n_samples, self.latent_dim)).to(device)
        samples = self.decode(z)
        return samples


if __name__ == '__main__':
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VAE(96, 3, 1024).to(DEVICE)
    fake_input = torch.ones((1, 3, 96, 96)).to(DEVICE)
    out, *_ = model(fake_input)
    print(out.shape)                       # torch.Size([1, 3, 96, 96])
    print(model.sample(10, DEVICE).shape)  # torch.Size([10, 3, 96, 96])
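re_parameterize implements the reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through mu and log_var while the randomness stays in eps. Because decode accepts any latent code, one can also interpolate between two codes; a minimal sketch with untrained weights, purely to illustrate the shapes:

import torch
from VAE import VAE

model = VAE(96, 3, 256)
model.eval()
z0, z1 = torch.randn(1, 256), torch.randn(1, 256)  # two random latent codes
with torch.no_grad():
    t = torch.linspace(0.0, 1.0, 8).view(-1, 1)    # 8 interpolation steps
    z = (1 - t) * z0 + t * z1                      # (8, 256)
    frames = model.decode(z)                       # (8, 3, 96, 96)
print(frames.shape)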
Loss.py
import torch
import torch.nn as nn


class Loss(nn.Module):
    def __init__(self, kld_weight=0.03):
        super(Loss, self).__init__()
        self.kld_weight = kld_weight
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, input, output, mu, log_var):
        # reconstruction term: pixel-wise MSE between input and reconstruction
        recon_loss = self.criterion(output, input)
        # KL divergence between N(mu, sigma^2) and the standard normal prior,
        # averaged over both batch and latent dimensions
        kld_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + self.kld_weight * kld_loss
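The kld_loss line is the closed form of KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian: -1/2 (1 + log sigma^2 - mu^2 - sigma^2) per latent dimension. Note that averaging over the latent dimension (rather than the more common sum) makes the effective kld_weight depend on latent_dim. A quick verification sketch against torch.distributions, not part of the original code:

import torch
from torch.distributions import Normal, kl_divergence

mu, log_var = torch.randn(4, 8), torch.randn(4, 8)
closed_form = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
p = Normal(mu, torch.exp(0.5 * log_var))               # encoder posterior
q = Normal(torch.zeros_like(mu), torch.ones_like(mu))  # standard normal prior
print(torch.allclose(closed_form, kl_divergence(p, q).mean(), atol=1e-6))  # True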
train_vae.py
import os
import argparse

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

from VAE import VAE
from utils import MyDataset
from Loss import Loss


def args_parser():
    parser = argparse.ArgumentParser(description="Parameters of training vae model")
    parser.add_argument("-b", "--batch_size", type=int, default=128)
    parser.add_argument("-i", "--in_channels", type=int, default=3)
    parser.add_argument("-d", "--latent_dim", type=int, default=256)
    parser.add_argument("-l", "--lr", type=float, default=1e-3)
    parser.add_argument("-w", "--weight_decay", type=float, default=1e-5)
    parser.add_argument("-e", "--epoch", type=int, default=500)
    parser.add_argument("-v", "--snap_epoch", type=int, default=1)
    parser.add_argument("-n", "--num_samples", type=int, default=64)
    parser.add_argument("-p", "--path", type=str, default="./results_linear")
    return parser.parse_args()


def train(model, input_data, loss_fn, optimizer):
    optimizer.zero_grad()
    out, mu, log_var = model(input_data)
    total_loss = loss_fn(input_data, out, mu, log_var)
    total_loss.backward()
    optimizer.step()
    return total_loss.item()


if __name__ == '__main__':
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    opt = args_parser()
    loss_fn = Loss(kld_weight=0.03)
    dataset = MyDataset(img_path="../faces/", device=DEVICE)
    train_loader = DataLoader(dataset=dataset, batch_size=opt.batch_size,
                              shuffle=True, num_workers=0)
    model = VAE(image_size=96, in_channels=opt.in_channels, latent_dim=opt.latent_dim)
    model.to(DEVICE)
    optimizer = Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)

    for epoch in range(opt.epoch):
        model.train()
        data_bar = tqdm(train_loader)
        for step, data in enumerate(data_bar):
            loss = train(model, data.to(DEVICE), loss_fn, optimizer)
            data_bar.set_postfix(loss=loss)  # show loss on the progress bar

        if epoch % opt.snap_epoch == 0 or epoch == opt.epoch - 1:
            model.eval()
            with torch.no_grad():
                images = model.sample(opt.num_samples, DEVICE)
            saved_image_path = os.path.join(opt.path, "images")
            os.makedirs(saved_image_path, exist_ok=True)
            # save the grid under results_linear/images/, not the working directory
            fname = os.path.join(saved_image_path, "my_generated-images-{0:0=4d}.png".format(epoch))
            save_image(images, fname, nrow=8)
            saved_model_path = os.path.join(opt.path, "models")
            os.makedirs(saved_model_path, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(saved_model_path, f"epoch_{epoch}.pth"))
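Once training has produced checkpoints, new faces can be generated without rerunning the training loop. A minimal sketch (the epoch number in the path is illustrative; point it at whichever epoch_*.pth you saved):

import torch
from VAE import VAE
from torchvision.utils import save_image

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(image_size=96, in_channels=3, latent_dim=256)
state = torch.load("./results_linear/models/epoch_333.pth", map_location=DEVICE)
model.load_state_dict(state)
model.to(DEVICE)
model.eval()
with torch.no_grad():
    samples = model.sample(64, DEVICE)  # decode 64 random latent codes
save_image(samples, "sampled_faces.png", nrow=8)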
Without any hyperparameter tuning, after 333 epochs of training the model generates the following results: