VAE生成人臉程式碼

指间的执着發表於2024-06-30

基於VAE介紹的理論,簡單實現VAE生成人臉,程式碼如下:

utils.py

import os
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import glob
import cv2
import numpy as np
import torch


class MyDataset(Dataset):
    """Dataset that loads every ``*.jpg`` face image found in a directory.

    Args:
        img_path: Directory containing the training images.
        device: Torch device each image tensor is moved to on access.
    """

    def __init__(self, img_path, device):
        super(MyDataset, self).__init__()
        self.device = device
        # os.path.join inserts the separator itself, so this also works when
        # ``img_path`` lacks a trailing slash (the original
        # ``img_path + "*.jpg"`` silently matched nothing in that case).
        self.fnames = glob.glob(os.path.join(img_path, "*.jpg"))
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        """Return one image as a float tensor in [0, 1], shape (3, H, W)."""
        fname = self.fnames[idx]
        img = cv2.imread(fname, cv2.IMREAD_COLOR)
        # OpenCV decodes to BGR; convert to RGB so that torchvision's
        # save_image later writes the generated samples with correct colors.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)
        # NOTE: moving to device inside __getitem__ is only safe with
        # num_workers=0, which is how the training script builds its loader.
        img = img.to(self.device)
        return img

    def __len__(self):
        return len(self.fnames)

VAE.py

import torch
import torch.nn as nn


class VAE(nn.Module):
    """Convolutional variational autoencoder for small square images.

    Args:
        image_size: Height/width of the (square) inputs; must be divisible
            by ``2 ** len(hid_dims)`` since each encoder stage halves the
            spatial resolution.
        in_channels: Number of channels of the input images (also used for
            the reconstruction output, instead of a hard-coded 3).
        latent_dim: Dimension of the latent code z.
        hid_dims: Channel widths of the encoder stages; defaults to
            [32, 64, 128, 256].
    """

    def __init__(self, image_size: int, in_channels: int, latent_dim: int, hid_dims=None):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        if not hid_dims:
            hid_dims = [32, 64, 128, 256]
        else:
            # Copy so we never mutate the caller's list (the original
            # called hid_dims.reverse() in place on the argument).
            hid_dims = list(hid_dims)

        self.out_channels = in_channels
        # Spatial size of the encoder's final feature map.
        self.feature_size = image_size // (2 ** len(hid_dims))

        # --- encoder: stride-2 conv stages, each halving H and W ---
        modules = []
        ch = in_channels
        for h_d in hid_dims:
            modules.append(nn.Sequential(
                nn.Conv2d(ch, h_d, 3, 2, 1),
                nn.BatchNorm2d(h_d),
                nn.LeakyReLU()))
            ch = h_d
        self.encoder = nn.Sequential(*modules)

        flat_dim = hid_dims[-1] * self.feature_size ** 2
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_var = nn.Linear(flat_dim, latent_dim)  # predicts log-variance

        # --- decoder: mirror of the encoder ---
        self.decoder_input = nn.Linear(latent_dim, flat_dim)
        hid_dims = hid_dims[::-1]
        # Channel count decode() reshapes to — replaces the original's
        # hard-coded view(-1, 256, 6, 6), which broke for any non-default
        # image_size or hid_dims.
        self.decoder_channels = hid_dims[0]

        modules = []
        for i in range(len(hid_dims) - 1):
            modules.append(nn.Sequential(
                nn.ConvTranspose2d(hid_dims[i], hid_dims[i + 1], 3, 2, 1, 1),
                nn.BatchNorm2d(hid_dims[i + 1]),
                nn.LeakyReLU()))
        self.decoder = nn.Sequential(*modules)

        self.decoder_out = nn.Sequential(
            nn.ConvTranspose2d(hid_dims[-1], hid_dims[-1], 3, 2, 1, 1),
            nn.BatchNorm2d(hid_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hid_dims[-1], self.out_channels, 3, 1, 1, 1),
            nn.Sigmoid())

    def encode(self, x):
        """Map images to the mean and log-variance of q(z|x)."""
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def decode(self, z):
        """Map latent codes z back to images with values in [0, 1]."""
        x = self.decoder_input(z)
        x = x.view(-1, self.decoder_channels, self.feature_size, self.feature_size)
        x = self.decoder(x)
        x = self.decoder_out(x)
        return x

    def re_parameterize(self, mu, log_var):
        """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
        # torch.exp, not the in-place torch.exp_ on a temporary expression.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        """Return (reconstruction, mu, log_var) for a batch of images."""
        mu, log_var = self.encode(x)
        z = self.re_parameterize(mu, log_var)
        out = self.decode(z)
        return out, mu, log_var

    def sample(self, n_samples, device):
        """Decode n_samples latent vectors drawn from N(0, I) on device."""
        z = torch.randn((n_samples, self.latent_dim)).to(device)
        samples = self.decode(z)
        return samples


if __name__ == '__main__':
    # Quick smoke test: push a dummy batch through the model and draw samples.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = VAE(96, 3, 1024)
    dummy = torch.ones((1, 3, 96, 96))
    reconstruction, *_ = vae(dummy)
    print(reconstruction.shape)
    print(vae.sample(10, device).shape)

Loss.py

import torch
import torch.nn as nn


class Loss(nn.Module):
    """VAE objective: pixel-wise MSE plus a weighted KL-divergence term."""

    def __init__(self, kld_weight=0.03):
        super(Loss, self).__init__()
        # Relative weight of the KL term against the reconstruction term.
        self.kld_weight = kld_weight
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, input, output, mu, log_var):
        """Compute the total loss for a batch.

        Args:
            input: Original images fed to the encoder.
            output: Reconstructions produced by the decoder.
            mu: Posterior means of q(z|x).
            log_var: Posterior log-variances of q(z|x).
        """
        reconstruction = self.criterion(output, input)
        # KL(q(z|x) || N(0, I)), averaged over batch and latent dimensions.
        kld = torch.mean(1 + log_var - mu ** 2 - torch.exp(log_var)) * -0.5
        return reconstruction + self.kld_weight * kld

train_vae.py

import os
import numpy as np
import torch
from VAE import VAE
import argparse
from torch.utils.data import DataLoader
from PIL import Image
from torch.optim import Adam
from utils import MyDataset
from torchvision.utils import save_image
from Loss import Loss
from tqdm import tqdm


def args_parser():
    """Parse the command-line options that configure VAE training.

    Returns:
        argparse.Namespace with batch_size, in_channels, latent_dim, lr,
        weight_decay, epoch, snap_epoch, num_samples and path attributes.
    """
    parser = argparse.ArgumentParser(description="Parameters of training vae model")
    # (short flag, long flag, type, default) for the numeric options.
    for short, long, typ, default in (
        ("-b", "--batch_size", int, 128),
        ("-i", "--in_channels", int, 3),
        ("-d", "--latent_dim", int, 256),
        ("-l", "--lr", float, 1e-3),
        ("-w", "--weight_decay", float, 1e-5),
        ("-e", "--epoch", int, 500),
        ("-v", "--snap_epoch", int, 1),
        ("-n", "--num_samples", int, 64),
    ):
        parser.add_argument(short, long, type=typ, default=default)
    parser.add_argument("-p", "--path", type=str, default="./results_linear")
    return parser.parse_args()


def train(model, input_data, loss_fn, optimizer):
    """Run one optimization step on a single batch.

    Args:
        model: VAE whose forward returns (reconstruction, mu, log_var).
        input_data: Batch of images, already on the model's device.
        loss_fn: Callable (input, output, mu, log_var) -> scalar loss tensor.
        optimizer: Optimizer built over model.parameters().

    Returns:
        The batch loss as a Python float. The original returned None;
        returning the value lets callers log or plot the training curve.
    """
    optimizer.zero_grad()
    out, mu, log_var = model(input_data)
    total_loss = loss_fn(input_data, out, mu, log_var)
    total_loss.backward()
    optimizer.step()

    loss_value = total_loss.item()  # .item() once, reused for print and return
    print("loss:", loss_value)
    return loss_value


if __name__ == '__main__':
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    opt = args_parser()

    loss_fn = Loss(kld_weight=0.03)

    # num_workers must stay 0: MyDataset moves tensors to DEVICE in __getitem__.
    dataset = MyDataset(img_path="../faces/", device=DEVICE)
    train_loader = DataLoader(dataset=dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0)
    model = VAE(image_size=96, in_channels=opt.in_channels, latent_dim=opt.latent_dim)
    model.to(DEVICE)

    optimizer = Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)

    for epoch in range(opt.epoch):
        model.train()
        data_bar = tqdm(train_loader)
        for step, data in enumerate(data_bar):
            train(model, data.to(DEVICE), loss_fn, optimizer)

        # Periodically save a grid of generated samples and a checkpoint.
        if epoch % opt.snap_epoch == 0 or epoch == opt.epoch - 1:
            model.eval()
            with torch.no_grad():  # sampling needs no gradient graph
                images = model.sample(opt.num_samples, DEVICE)
            saved_image_path = os.path.join(opt.path, "images")
            os.makedirs(saved_image_path, exist_ok=True)
            # BUG FIX: the original created saved_image_path but then wrote
            # the grid to the current working directory; save it where the
            # directory was prepared. (Also dropped an unused .numpy() copy.)
            fname = os.path.join(saved_image_path, 'my_generated-images-{0:0=4d}.png'.format(epoch))
            save_image(images, fname, nrow=8)
            saved_model_path = os.path.join(opt.path, "models")
            os.makedirs(saved_model_path, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(saved_model_path, f"epoch_{epoch}.pth"))

沒有調參,訓練333個epoch,模型生成的結果如下:

相關文章