TensorFlow 2.0 CycleGAN:官方教程程式碼的復現與整理。
官方資料集網址:https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/
官方教程(CycleGAN):https://tensorflow.google.cn/tutorials/generative/cyclegan
CycleGAN 程式碼(已修改為從本地 dataset/ 資料夾讀取影像,而非透過 tfds 下載,可直接執行):
import tensorflow as tf
import tensorflow_datasets as tfds
from pix2pix import pix2pix
import os
import time, argparse
import matplotlib.pyplot as plt
from IPython.display import clear_output
# Sentinel that lets tf.data pick the parallelism of each map() call.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Silence tensorflow_datasets progress bars. tfds is not otherwise used in the
# active code; it only matters if the reference loader below is re-enabled.
tfds.disable_progress_bar()
# Reference only: the original tutorial loaded horse2zebra through tfds
# instead of reading local image folders.
# dataset, metadata = tfds.load(name='horse2zebra',data_dir='E:\\Users\\CycleGAN-tf2.0-tourtial\\dataset',
# with_info=True, as_supervised=True)
# train_horses, train_zebras = dataset['trainA'], dataset['trainB']
# test_horses, test_zebras = dataset['testA'], dataset['testB']
def arg_parser(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``argparse`` reads ``sys.argv[1:]``, exactly as before. Passing an
            explicit list lets the parser be driven programmatically or from
            tests.

    Returns:
        argparse.Namespace with a ``DATASET`` attribute: the name of the
        sub-folder of ``./dataset/`` to load images from
        (default ``'apple2orange'``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--DATASET", default='apple2orange', type=str)
    return parser.parse_args(argv)
# Images are read from ./dataset/<DATASET>/{trainA,trainB,testA,testB}/,
# resolved against the current working directory.
args = arg_parser()
dataset_path = os.path.join(os.getcwd(), 'dataset/{}/'.format(args.DATASET))
# One file-name dataset per split. Domain A keeps the tutorial's "horse"
# naming and domain B the "zebra" naming, whatever DATASET actually holds.
train_horses, train_zebras, test_horses, test_zebras = (
    tf.data.Dataset.list_files(dataset_path + split + '/*')
    for split in ('trainA', 'trainB', 'testA', 'testB')
)
# Shuffle-buffer size and intended batch size for the tf.data pipelines.
BUFFER_SIZE, BATCH_SIZE = 1000, 1
# Height/width that random_crop cuts training images down to.
IMG_WIDTH = IMG_HEIGHT = 256
def random_crop(image):
    """Return a randomly positioned IMG_HEIGHT x IMG_WIDTH x 3 window of ``image``.

    ``image`` needs at least that many rows, columns and channels;
    ``tf.image.random_crop`` rejects smaller inputs.
    """
    crop_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
    return tf.image.random_crop(image, size=crop_shape)
def normalize(image):
    """Cast ``image`` to float32 and rescale pixel values [0, 255] to [-1, 1]."""
    as_float = tf.cast(image, tf.float32)
    return as_float / 127.5 - 1
def random_jitter(image):
    """Apply the training-time augmentation to a single image.

    The image is upscaled to 286 x 286 with nearest-neighbour sampling. It is
    then randomly cropped back to IMG_HEIGHT x IMG_WIDTH x 3 and flipped
    left/right with 50% probability.
    """
    enlarged = tf.image.resize(
        image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    cropped = random_crop(enlarged)
    return tf.image.random_flip_left_right(cropped)
def load(image_file):
    """Read a JPEG file from disk and return it as a float32 RGB tensor.

    Args:
        image_file: Scalar string tensor (or str) with the file path.

    Returns:
        float32 tensor of shape (height, width, 3) with values in [0, 255].
    """
    raw = tf.io.read_file(image_file)
    # Force three channels. With the default channels=0, decode_jpeg keeps the
    # file's own channel count, so a grayscale JPEG would come back as
    # (h, w, 1). That breaks random_crop's fixed [..., 3] crop size and the
    # 3-channel generators/discriminators.
    image = tf.image.decode_jpeg(raw, channels=3)
    return tf.cast(image, tf.float32)
def preprocess_image_train(image_file):
    """Load one training image, randomly jitter it, and scale it to [-1, 1]."""
    return normalize(random_jitter(load(image_file)))
def preprocess_image_test(image_file):
    """Load one test image and scale it to [-1, 1].

    No augmentation or resizing is applied; the image keeps its on-disk size.
    """
    return normalize(load(image_file))
def _build_train_pipeline(files):
    """Decode, cache, augment, shuffle and batch one training domain.

    Decoded images are cached *before* random_jitter so that every epoch draws
    a fresh random resize-crop and flip. Caching after the augmentation would
    replay the first epoch's crops forever and silently disable it.
    """
    return (files
            .map(load, num_parallel_calls=AUTOTUNE)
            .cache()
            .map(lambda image: normalize(random_jitter(image)),
                 num_parallel_calls=AUTOTUNE)
            .shuffle(BUFFER_SIZE)
            .batch(BATCH_SIZE))


def _build_test_pipeline(files):
    """Decode, normalize, cache, shuffle and batch one test domain."""
    # Test preprocessing is deterministic, so caching the final images is safe.
    return (files
            .map(preprocess_image_test, num_parallel_calls=AUTOTUNE)
            .cache()
            .shuffle(BUFFER_SIZE)
            .batch(BATCH_SIZE))


train_horses = _build_train_pipeline(train_horses)
train_zebras = _build_train_pipeline(train_zebras)
test_horses = _build_test_pipeline(test_horses)
test_zebras = _build_test_pipeline(test_zebras)
# Take the first batch of each training domain. These samples are reused
# below for visual sanity checks and for the per-epoch progress images.
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))

# Show each sample next to a freshly jittered copy, one figure per domain.
# Re-issuing plt.subplot(121)/(122) on a single figure would reuse the same
# axes, so the zebra panels would overwrite the horse panels.
for name, sample in (('Horse', sample_horse), ('Zebra', sample_zebra)):
    plt.figure()
    plt.subplot(121)
    plt.title(name)
    # Pixels are in [-1, 1]; map them back to [0, 1] for imshow.
    plt.imshow(sample[0] * 0.5 + 0.5)
    plt.subplot(122)
    plt.title(name + ' with random jitter')
    plt.imshow(random_jitter(sample[0]) * 0.5 + 0.5)
# Channel count of the generators' output images.
OUTPUT_CHANNELS = 3
# generator_g translates domain X (horse) -> Y (zebra); generator_f translates
# Y -> X. Both are U-Nets with instance normalization.
generator_g, generator_f = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2)
)
# discriminator_x judges domain X images, discriminator_y judges domain Y.
# target=False: the discriminators receive only the image under test.
discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2)
)
# Run the still-untrained networks once to check the wiring end to end.
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)

plt.figure(figsize=(8, 8))
# Generated panels are stretched by `contrast`; presumably the untrained
# outputs are small in magnitude and would otherwise look uniformly grey.
contrast = 8
imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']
for i, (img, caption) in enumerate(zip(imgs, title)):
    plt.subplot(2, 2, i + 1)
    plt.title(caption)
    # Even slots hold real inputs, odd slots hold generated outputs.
    scale = 0.5 if i % 2 == 0 else 0.5 * contrast
    plt.imshow(img[0] * scale + 0.5)
plt.show()

# Heat map of the last output channel of each discriminator for the first
# image in the batch.
plt.figure(figsize=(8, 8))
plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')
plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')
plt.show()
# Weight of the cycle-consistency loss; identity_loss uses half of it.
LAMBDA = 10
# Shared binary cross-entropy. from_logits=True: the discriminator outputs are
# treated as raw logits (no sigmoid applied beforehand).
loss_obj = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)
def discriminator_loss(real, generated):
    """Half the summed BCE of labelling ``real`` logits 1 and ``generated`` logits 0."""
    loss_on_real = loss_obj(tf.ones_like(real), real)
    loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
    return 0.5 * (loss_on_real + loss_on_fake)
def generator_loss(generated):
    """Adversarial loss: BCE between the discriminator's logits on fakes and all-ones."""
    target = tf.ones_like(generated)
    return loss_obj(target, generated)
def calc_cycle_loss(real_image, cycled_image):
    """Cycle-consistency loss: mean absolute error scaled by LAMBDA."""
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * mean_abs_error
def identity_loss(real_image, same_image):
    """Identity loss: mean absolute error scaled by LAMBDA * 0.5."""
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * mean_abs_error
# One independent Adam optimizer (learning rate 2e-4, beta_1 0.5) per network,
# created in the order G, F, D_X, D_Y.
(generator_g_optimizer,
 generator_f_optimizer,
 discriminator_x_optimizer,
 discriminator_y_optimizer) = (tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
                               for _ in range(4))
checkpoint_path = "./checkpoints/train"
# Track all four networks and their optimizers so training can resume.
ckpt = tf.train.Checkpoint(
    generator_g=generator_g,
    generator_f=generator_f,
    discriminator_x=discriminator_x,
    discriminator_y=discriminator_y,
    generator_g_optimizer=generator_g_optimizer,
    generator_f_optimizer=generator_f_optimizer,
    discriminator_x_optimizer=discriminator_x_optimizer,
    discriminator_y_optimizer=discriminator_y_optimizer,
)
# Only the five most recent checkpoints are kept on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the newest checkpoint when one exists.
latest = ckpt_manager.latest_checkpoint
if latest:
    ckpt.restore(latest)
    print('Latest checkpoint restored!!')

EPOCHS = 100
def generate_images(model, test_input, save_path, epoch):
    """Save a side-by-side figure of ``test_input[0]`` and the model's output.

    Args:
        model: Generator to run; it is called without ``training=True``.
        test_input: Image batch, presumably scaled to [-1, 1] like the
            pipelines produce. Only the first element is plotted.
        save_path: Directory the figure is written into. It is created if
            missing.
        epoch: Value used as the file-name prefix (``<epoch>_test.jpg``).
    """
    prediction = model(test_input)
    plt.figure(figsize=(12, 12))
    display_list = [test_input[0], prediction[0]]
    title = ['Input Image', 'Predicted Image']
    for i, (image, caption) in enumerate(zip(display_list, title)):
        plt.subplot(1, 2, i + 1)
        plt.title(caption)
        # Map pixel values from [-1, 1] back to [0, 1] for plotting.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    # savefig does not create directories; make sure the target exists.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, str(epoch) + '_test.jpg'))
    # This runs once per epoch and once per test image. Unclosed pyplot figures
    # stay alive, so memory grows and matplotlib warns past 20 open figures.
    plt.close()
@tf.function
def train_step(real_x, real_y):
    """Run one CycleGAN optimisation step on a pair of image batches.

    generator_g maps X -> Y and generator_f maps Y -> X. discriminator_x and
    discriminator_y judge domains X and Y. All four networks are updated from
    the same forward pass. Compiled with tf.function; running this step
    eagerly is typically much slower.

    Args:
        real_x: Image batch from domain X (the "horse" pipeline).
        real_y: Image batch from domain Y (the "zebra" pipeline).
    """
    # persistent=True because tape.gradient is called four times below.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y; generator F translates Y -> X.
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # same_x / same_y feed the identity loss: each generator applied to an
        # image already in its output domain.
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # Adversarial losses.
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        total_cycle_loss = (calc_cycle_loss(real_x, cycled_x)
                            + calc_cycle_loss(real_y, cycled_y))

        # Total generator loss = adversarial + cycle + identity.
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # Gradients of each loss with respect to its own network's weights.
    generator_g_gradients = tape.gradient(total_gen_g_loss,
                                          generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss,
                                          generator_f.trainable_variables)
    discriminator_x_gradients = tape.gradient(disc_x_loss,
                                              discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss,
                                              discriminator_y.trainable_variables)

    # Apply each network's gradients with its own optimizer.
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
                                              generator_g.trainable_variables))
    generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
                                              generator_f.trainable_variables))
    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                  discriminator_x.trainable_variables))
    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                  discriminator_y.trainable_variables))
# Directory for the per-epoch progress images (loop-invariant, so set once).
# NOTE(review): absolute Windows path from the original author's machine;
# consider deriving it from os.getcwd() the way dataset_path is.
save_path = 'E:\\Users\\CycleGAN-tf2.0-tourtial\\samples\\'
for epoch in range(EPOCHS):
    start = time.time()
    # The two domains are iterated in lockstep; zip stops at the shorter one.
    for n, (image_x, image_y) in enumerate(
            tf.data.Dataset.zip((train_horses, train_zebras))):
        train_step(image_x, image_y)
        if n % 10 == 0:
            print('step=%d' % n)
    clear_output(wait=True)
    # Use the same image (sample_horse) every epoch so progress is visible.
    generate_images(generator_g, sample_horse, save_path, epoch)
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                            ckpt_save_path))
    print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                       time.time() - start))
# Run the trained generator_g on five test images from domain X and save each
# result as <index>_test.jpg.
save_path_test = 'E:\\Users\\CycleGAN-tf2.0-tourtial\\test_save\\'
for test_num, inp in enumerate(test_horses.take(5)):
    generate_images(generator_g, inp, save_path_test, test_num)
相關文章
- 《動手學深度學習》TF2.0 實現深度學習TF2
- css程式碼整理CSS
- 程式碼規範整理
- Oracle 官方教材閱讀整理Oracle
- iptables官方手冊整理薦
- 什麼是程式碼整理?
- 頁面常用程式碼整理
- 微信隱藏程式碼整理
- python常用程式碼整理Python
- 國家密碼標準-商密SM2官方文件整理密碼
- 以太坊官方 Token 程式碼詳解
- 微信官方單檔案接入程式碼
- 常用的JScript程式碼整理JS
- 軟著整理程式碼快速生成
- OpenDoc - 專案程式碼整理指南
- js實現的返回並重新整理上一頁程式碼JS
- 自復制程式碼
- 30 個Python程式碼實現的常用功能,精心整理版Python
- javascript實現的重新整理當前頁面程式碼例項JavaScript
- Transformers2.0讓你三行程式碼呼叫語言模型,相容TF2.0和PyTorchORM行程模型TF2PyTorch
- 高危 Bug !Apache Log4j2 遠端程式碼執行漏洞:官方正加急修復中 !Apache
- 谷歌官方 Android MVP 模式程式碼解讀谷歌AndroidMVP模式
- 整理課程中將程式碼納入Git程式碼版本控制Git
- js實現的移動端下拉重新整理功能程式碼例項JS
- [VUE系列二]vue官方文件總結和整理Vue
- Oracle 獲取整數方式程式碼整理Oracle
- 可以重新整理頁面的javascript程式碼JavaScript
- 目標識別程式碼解讀整理
- 控制程式碼的本質(整理-收藏)
- 【python海龜畫圖】程式碼整理Python
- 突然發現hibernate釋出官方examples程式了
- ajax實現的無重新整理使用者登入例項程式碼
- DiskGenius(磁碟修復工具)官方版
- Google官方MVP示例程式碼閱讀筆記GoMVP筆記
- 程式碼審計之seacms v6.45 前臺Getshell 復現分析ACM
- Apache log4j2 遠端程式碼執行漏洞復現?Apache
- 肯特·貝克:改變人生的程式碼整理魔法
- Java基礎知識整理之程式碼塊Java