CVPR2017: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
Abstract
Despite the breakthroughs in accuracy and speed of single image super-resolution using faster and deeper convolutional neural networks, one central problem remains largely unsolved: how do we recover the finer texture details when we super-resolve at large upscaling factors? The behavior of optimization-based super-resolution methods is principally driven by the choice of the objective function. Recent work has largely focused on minimizing the mean squared reconstruction error. The resulting estimates have high peak signal-to-noise ratios, but they are often lacking high-frequency details and are perceptually unsatisfying in the sense that they fail to match the fidelity expected at the higher resolution. In this paper, we present SRGAN, a generative adversarial network (GAN) for image super-resolution (SR). To our knowledge, it is the first framework capable of inferring photo-realistic natural images for 4× upscaling factors. To achieve this, we propose a perceptual loss function which consists of an adversarial loss and a content loss. The adversarial loss pushes our solution to the natural image manifold using a discriminator network that is trained to differentiate between the super-resolved images and original photo-realistic images. In addition, we use a content loss motivated by perceptual similarity instead of similarity in pixel space. Our deep residual network is able to recover photo-realistic textures from heavily downsampled images on public benchmarks. An extensive mean-opinion-score (MOS) test shows hugely significant gains in perceptual quality using SRGAN. The MOS scores obtained with SRGAN are closer to those of the original high-resolution images than to those obtained with any state-of-the-art method.
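In short, the paper applies GANs to super-resolution. The authors propose an adversarial network called SRGAN that reconstructs photo-realistic high-resolution images from low-resolution inputs. In the MOS test, SRGAN scored higher than all of the other methods it was compared against.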
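For reference, the perceptual loss mentioned in the abstract is a weighted sum of a content loss and an adversarial loss. The formulas below follow the paper's notation as I read it (worth checking against the PDF): φ_{i,j} is the VGG19 feature map obtained by the j-th convolution (after activation) before the i-th max-pooling layer, W_{i,j} and H_{i,j} are its dimensions, G is the generator and D the discriminator.

```latex
l^{SR} = l^{SR}_{X} + 10^{-3}\, l^{SR}_{Gen}

l^{SR}_{VGG/i,j} = \frac{1}{W_{i,j} H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}}
  \left( \phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}\big(G_{\theta_G}(I^{LR})\big)_{x,y} \right)^2

l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D_{\theta_D}\big(G_{\theta_G}(I^{LR})\big)
```

The GitHub code used in this post keeps the 10^{-3} weight but changes both terms: its content loss is an L1 distance on VGG19 features up to the ReLU after conv3_4 (the first 18 layers of vgg19.features), and its adversarial loss is a least-squares (MSE) loss against all-ones targets rather than -log D.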
Full paper
https://arxiv.org/pdf/1609.04802v5.pdf
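Implementation
The main work for this paper was reproducing its method. GitHub has many open-source projects of this kind, but because of library version changes, code downloaded as-is often contains a number of bugs. This post records the whole debugging process.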
Dataset
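Dataset download: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
CelebA is a face dataset containing over 200,000 face images of different people. The images are already cropped and compressed, so they take up very little space, which makes the dataset well suited to training that needs a large number of face images.
Next, write dataset.py to load the dataset.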
dataset.py
import glob
import random
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

# Normalization parameters for pre-trained PyTorch models
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


class ImageDataset(Dataset):
    def __init__(self, root, hr_shape):
        hr_height, hr_width = hr_shape
        # Transforms for low resolution images and high resolution images.
        # The LR image is 4x smaller than the HR image in each dimension.
        # InterpolationMode replaces the PIL constant (Image.BICUBIC) in newer torchvision.
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        # Every file directly inside root (the pattern is not recursive)
        self.files = sorted(glob.glob(root + "/*.*"))

    def __getitem__(self, index):
        # convert("RGB") guards against grayscale or RGBA files
        img = Image.open(self.files[index % len(self.files)]).convert("RGB")
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)
        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.files)
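Note that the path pattern (root + "/*.*") matches only the files directly inside the folder; subfolders are not searched. So if you keep the data in data/img_align_celeba, all images must be extracted directly into data/img_align_celeba, not into a nested folder.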
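A quick way to see why the images must sit directly in the folder: the pattern root + "/*.*" is not recursive. The sketch below builds a throwaway folder layout (the file and folder names are made up for illustration) and counts how many files the same pattern finds in each case.

```python
import glob
import os
import tempfile


def list_images(root):
    # Same pattern as ImageDataset: files directly inside root whose names contain a dot
    return sorted(glob.glob(root + "/*.*"))


with tempfile.TemporaryDirectory() as tmp:
    flat = os.path.join(tmp, "flat")        # images extracted directly into the folder
    nested = os.path.join(tmp, "nested")    # images extracted into a subfolder
    os.makedirs(flat)
    os.makedirs(os.path.join(nested, "img_align_celeba"))
    for name in ("000001.jpg", "000002.jpg"):
        open(os.path.join(flat, name), "wb").close()
        open(os.path.join(nested, "img_align_celeba", name), "wb").close()
    flat_count = len(list_images(flat))
    nested_count = len(list_images(nested))

print(flat_count, nested_count)  # 2 0
```

The nested layout finds nothing, which is exactly the situation that later shows up as a "0 samples" error in the DataLoader.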
Model
The network architecture from the paper is taken directly from an existing GitHub implementation.
models.py
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.models import vgg19, VGG19_Weights
import math


class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # ImageNet-pretrained VGG19; weights= replaces the deprecated pretrained=True (torchvision >= 0.13)
        vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        # First 18 layers of vgg19.features, i.e. up to the ReLU after conv3_4
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])

    def forward(self, img):
        return self.feature_extractor(img)


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        # conv-BN-PReLU-conv-BN; the second positional argument of BatchNorm2d
        # is eps, so BatchNorm2d(in_features, 0.8) sets eps=0.8
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
        )

    def forward(self, x):
        # Identity skip connection
        return x + self.conv_block(x)


class GeneratorResNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()

        # First layer
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())

        # Residual blocks
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(ResidualBlock(64))
        self.res_blocks = nn.Sequential(*res_blocks)

        # Second conv layer post residual blocks
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64, 0.8))

        # Upsampling layers: two stages of conv (64 -> 256 channels) + PixelShuffle(2),
        # each doubling height and width, for 4x upscaling in total
        upsampling = []
        for _ in range(2):
            upsampling += [
                # nn.Upsample(scale_factor=2),
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)

        # Final output layer
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Skip connection around the whole residual trunk
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out


class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Four stride-2 convs halve H and W four times, giving a 1 x H/16 x W/16 map
        # of real/fake scores (this formula assumes H and W are multiples of 16)
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
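Two shape facts from models.py can be checked without PyTorch: the two PixelShuffle stages give the 4× upscaling, and the discriminator's output is an H/16 × W/16 patch map. The sketch below is my own NumPy re-implementation of PixelShuffle's channel-to-space rearrangement for a single image (for illustration, not PyTorch's code), plus the standard Conv2d output-size formula applied to the discriminator's four stride-2 convolutions.

```python
import numpy as np


def pixel_shuffle(x, r):
    """(C*r*r, H, W) -> (C, H*r, W*r) for one image, mirroring nn.PixelShuffle's layout."""
    c_r2, h, w = x.shape
    c = c_r2 // (r * r)
    # out[c, h*r + i, w*r + j] = x[c*r*r + i*r + j, h, w]
    return x.reshape(c, r, r, h, w).transpose(0, 3, 1, 4, 2).reshape(c, h * r, w * r)


def conv_out(size, kernel=3, stride=1, padding=1):
    """Output size of nn.Conv2d along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_out(size):
    """Spatial size after the four stride-2 convs (kernel 3, padding 1) in Discriminator."""
    for _ in range(4):
        size = conv_out(size, stride=2)
    return size


# Generator: after each conv (64 -> 256 channels), PixelShuffle(2) gives 64 channels at twice the size
stage1 = pixel_shuffle(np.zeros((256, 64, 64), dtype=np.float32), 2)
stage2 = pixel_shuffle(np.zeros((256, 128, 128), dtype=np.float32), 2)
print(stage1.shape, stage2.shape)  # (64, 128, 128) (64, 256, 256)

# Channel-to-space layout on a tiny example: 4 channels of 1x1 -> one 2x2 channel
print(pixel_shuffle(np.arange(4).reshape(4, 1, 1), 2).tolist())  # [[[0, 1], [2, 3]]]

# Real discriminator output size vs. the int(in_height / 2 ** 4) formula in Discriminator
for h in (256, 250):
    print(h, discriminator_out(h), int(h / 2 ** 4))
# 256 16 16
# 250 16 15
```

The last lines show why hr_height and hr_width should be multiples of 16: for 250 the convolutions produce a 16-pixel map while output_shape says 15, so the valid/fake targets would no longer match the discriminator output.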
SRGAN
Finally, write the training script.
srgan.py
import argparse
import os
import sys

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid

from models import *
from dataset import *  # the dataset module above is saved as dataset.py (the original imported "datasets")

os.makedirs("images", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="img_align_celeba_demo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--hr_height", type=int, default=256, help="high res. image height")
parser.add_argument("--hr_width", type=int, default=256, help="high res. image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=500, help="interval between saving image samples")
parser.add_argument("--checkpoint_interval", type=int, default=50, help="interval between model checkpoints")
opt = parser.parse_args()
print(opt)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hr_shape = (opt.hr_height, opt.hr_width)

# Initialize generator and discriminator
generator = GeneratorResNet().to(device)
discriminator = Discriminator(input_shape=(opt.channels, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# Set feature extractor to inference mode and freeze its weights
# (gradients still flow through it back to the generator)
feature_extractor.eval()
for p in feature_extractor.parameters():
    p.requires_grad = False

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_content = torch.nn.L1Loss()

if opt.epoch != 0:
    # Load the checkpoints saved after opt.epoch completed epochs
    generator.load_state_dict(torch.load("saved_models/generator_%d.pth" % opt.epoch, map_location=device))
    discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth" % opt.epoch, map_location=device))

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, hr_shape=hr_shape),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=3,
)

# ----------
#  Training
# ----------

for epoch in range(opt.epoch, opt.n_epochs):
    for i, imgs in enumerate(dataloader):

        # Configure model input
        imgs_lr = imgs["lr"].to(device)
        imgs_hr = imgs["hr"].to(device)

        # Adversarial ground truths: one target per discriminator output patch
        valid = torch.ones((imgs_lr.size(0), *discriminator.output_shape), device=device)
        fake = torch.zeros((imgs_lr.size(0), *discriminator.output_shape), device=device)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)

        # Adversarial loss
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)

        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())

        # Total loss
        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss of real and fake images
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            # Save image grid with upsampled inputs and SRGAN outputs
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
            img_grid = torch.cat((imgs_lr, gen_hr), -1)
            save_image(img_grid, "images/%d.png" % batches_done, normalize=False)

        sys.stdout.write(
            "[Epoch %d/%d] [D loss: %f] [G loss: %f]\n"
            % (epoch + 1, opt.n_epochs, loss_D.item(), loss_G.item())
        )

    if opt.checkpoint_interval != -1 and (epoch + 1) % opt.checkpoint_interval == 0:
        # Save model checkpoints, named by the number of completed epochs
        torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % (epoch + 1))
        torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" % (epoch + 1))
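The loss arithmetic in the training loop can be traced on toy numbers. criterion_GAN is MSELoss against all-ones (valid) or all-zeros (fake) targets, i.e. a least-squares GAN objective, and the generator adds it to the L1 content loss with weight 1e-3. In the sketch below, plain Python lists stand in for the flattened discriminator patch outputs and VGG features; the numbers are invented for illustration.

```python
def mse(pred, target):
    # Mean squared error, like nn.MSELoss with the default reduction="mean"
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1(pred, target):
    # Mean absolute error, like nn.L1Loss with the default reduction="mean"
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def generator_loss(d_fake, gen_feat, real_feat, adv_weight=1e-3):
    loss_gan = mse(d_fake, [1.0] * len(d_fake))  # push D(G(lr)) towards "valid" = 1
    loss_content = l1(gen_feat, real_feat)       # distance in VGG feature space
    return loss_content + adv_weight * loss_gan


def discriminator_loss(d_real, d_fake):
    loss_real = mse(d_real, [1.0] * len(d_real))  # real HR patches -> 1
    loss_fake = mse(d_fake, [0.0] * len(d_fake))  # generated patches -> 0
    return (loss_real + loss_fake) / 2


d_real = [0.9, 0.8, 1.0, 0.7]   # toy discriminator outputs on real HR patches
d_fake = [0.2, 0.4, 0.0, 0.2]   # toy discriminator outputs on generated patches
gen_feat = [0.5, -0.25, 1.0]    # toy VGG features of the generated image
real_feat = [0.25, 0.0, 1.5]    # toy VGG features of the real image

print(round(generator_loss(d_fake, gen_feat, real_feat), 6))  # 0.333993
print(round(discriminator_loss(d_real, d_fake), 6))           # 0.0475
```

With weight 1e-3 the adversarial term (0.66 here) contributes only 0.00066, so the content loss dominates the generator update unless the discriminator is very confident.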
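The original author's srgan.py is meant to be run from the console, which is convenient for use but not for debugging. For debugging, it is easier to open it in an IDE.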
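When running srgan.py inside an IDE instead of the console, one simple option is to pass parse_args an explicit list of arguments, which argparse accepts in place of sys.argv. The snippet below rebuilds a few of the script's own options to show the idea; the override values are only an example.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="img_align_celeba_demo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
parser.add_argument("--sample_interval", type=int, default=500, help="interval between saving image samples")

# Console equivalent: python srgan.py --n_epochs 50 --dataset_name img_align_celeba
# In an IDE, pass the same flags as a list instead of reading sys.argv:
opt = parser.parse_args(["--n_epochs", "50", "--dataset_name", "img_align_celeba"])
print(opt.n_epochs, opt.dataset_name, opt.batch_size)  # 50 img_align_celeba 4

# An empty list keeps every default
defaults = parser.parse_args([])
print(defaults.sample_interval)  # 500
```

Switching back to parser.parse_args() (no list) restores normal command-line behavior.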
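There are a few issues in the code to watch out for:
① First, versions: the code expects a recent PyTorch (and torchvision); with older versions, the utils imports may fail because of a wrong package path.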
② The DataLoader:
dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, hr_shape=hr_shape),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=3,
)
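If you get an error saying the number of samples is 0 (num_samples=0), changing shuffle=True to False makes the error go away. Be aware that this only hides the symptom: the error comes from the random sampler when the dataset is empty, and with shuffle=False the loop simply runs zero batches, so first check that the dataset path actually contains images.
If you run into memory errors, change num_workers=3 to 0 (the default). DataLoader workers are separate processes, and on Windows they are started by spawning a fresh interpreter that re-imports the script and its libraries, which adds memory overhead per worker and also requires the training code to sit under if __name__ == "__main__":. Setting num_workers to 0 loads data in the main process and is the safest option.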
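Both problems in ② can be caught before training starts. The helpers below (check_image_dir and pick_num_workers) are hypothetical names, not part of the original code: the first raises a readable error when the ImageDataset pattern finds no files, instead of the sampler's num_samples=0 message, and the second falls back to 0 workers on Windows. Note that srgan.py already defines an --n_cpu option that the DataLoader call never uses, so it is a natural input for the second helper.

```python
import glob
import os
import tempfile


def check_image_dir(root):
    """Hypothetical helper: raise a readable error if the ImageDataset pattern matches no files."""
    files = sorted(glob.glob(root + "/*.*"))  # same pattern as ImageDataset
    if not files:
        raise FileNotFoundError(
            "no images found in %r; extract them directly into this folder" % root
        )
    return len(files)


def pick_num_workers(requested):
    """Hypothetical helper: 0 on Windows (os.name == "nt"), else the request capped at the CPU count."""
    if os.name == "nt":
        return 0
    return max(0, min(requested, os.cpu_count() or 1))


# Usage sketch on a throwaway folder
with tempfile.TemporaryDirectory() as tmp:
    try:
        check_image_dir(tmp)
        message = ""
    except FileNotFoundError as err:
        message = str(err)
    open(os.path.join(tmp, "000001.jpg"), "wb").close()
    n_found = check_image_dir(tmp)

print(message)
print(n_found)  # 1
```

In srgan.py these would be called just before building the DataLoader, for example check_image_dir("../../data/%s" % opt.dataset_name) and num_workers=pick_num_workers(opt.n_cpu).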
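③ About the default arguments:
batch_size and lr work well at their default values; I have not tried other values yet.
Because the dataset is very large, n_epochs should be reduced as appropriate (training on only 1k images for 50 epochs already gave decent results).
checkpoint_interval and sample_interval are for monitoring training progress: they control how often the model is saved and how often a side-by-side comparison of the super-resolution output is written.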
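The two intervals use different units: sample_interval is compared with the global batch counter batches_done = epoch * len(dataloader) + i, while checkpoint_interval is compared with the number of completed epochs. The sketch below lists which steps produce output, assuming len(dataloader) = ceil(number of images / batch_size), which holds for DataLoader's default drop_last=False.

```python
import math


def sample_steps(n_images, batch_size, n_epochs, sample_interval):
    """Global batch indices at which srgan.py writes an image grid."""
    batches_per_epoch = math.ceil(n_images / batch_size)  # len(dataloader) with drop_last=False
    total = batches_per_epoch * n_epochs
    return [b for b in range(total) if b % sample_interval == 0]


def checkpoint_epochs(n_epochs, checkpoint_interval):
    """Completed-epoch counts (epoch + 1) after which checkpoints are saved."""
    if checkpoint_interval == -1:
        return []
    return [e + 1 for e in range(n_epochs) if (e + 1) % checkpoint_interval == 0]


# 100 images, batch_size 4 (the default), 200 epochs, sample_interval 500
steps = sample_steps(100, 4, 200, 500)
print(len(steps), steps[:4])       # 10 [0, 500, 1000, 1500]
print(checkpoint_epochs(200, 50))  # [50, 100, 150, 200]
```

With 25 batches per epoch, a sample every 500 batches means one image grid every 20 epochs.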
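Results
Settings: 100 images, 200 epochs, sample_interval = 500 (this actually counts how many batches have been run, via a modulo on the global batch count; a rather slick bit of arithmetic).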
[Figure: some of the saved results, at the start of training and after 500, 1000, 2000, and 3000 batches]