日月光華的gan小例子

qq_43620967發表於2020-09-28

來自課程GAN生成對抗網路精講 tensorflow2.0程式碼實戰 全網最簡潔易懂的GAN課程
的程式碼

tensorflow版本為2.0
升級版本看https://blog.csdn.net/qq_43620967/article/details/108835207

jupyter 檔案
連結:https://pan.baidu.com/s/1PR4phiKAEoK4sAAmTWR18A
提取碼:z9od

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
tf.__version__

Out[104]: ‘2.3.1’

輸入

(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()
train_images.shape

Out[106]: (60000, 28, 28)

60000張圖,都是28*28 畫素

train_images.dtype

Out[107]: dtype(‘uint8’)

資料預處理

train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
train_images.shape

Out[109]: (60000, 28, 28, 1)

train_images=(train_images-127.5)/127.5#歸一化 到【-1,1】
BATCH_SIZE=256
BUFFER_SIZE=600000
datasets=tf.data.Dataset.from_tensor_slices(train_images)
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
datasets

Out[114]: <BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

第一維度表示個數

生成器模型

def generator_model():
    model=tf.keras.Sequential()
    model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
    #Dense全連線層,input_shape=(100,)長度100的隨機向量,use_bias=False,因為後面有BN層
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())#啟用
    
    #第二層
    model.add(layers.Dense(512,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())#啟用
    
    #輸出層
    model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))
    model.add(layers.BatchNormalization())
    
    model.add(layers.Reshape((28,28,1)))#變成圖片 要以元組形式傳入
    
    return model

辨別器模型

def discriminator_model():#輸入圖片
    model=keras.Sequential()
    model.add(layers.Flatten())#把輸入的三維圖片扁平化
    
    model.add(layers.Dense(512,use_bias=False))    
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())#啟用    
    
    model.add(layers.Dense(256,use_bias=False))    
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())#啟用      
    
    model.add(layers.Dense(1))#輸出數字,>0.5真實圖片
    return model

loss函式

cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)#from_logits=True因為最後的輸出沒有啟用

判別器損失函式

def discriminator_loss(real_out,fake_out):#辨別器的輸出 真實圖片判1,假的圖片判0
    real_loss=cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss+fake_loss 

生成器損失函式

def generator_loss(fake_out):#希望fakeimage的判別輸出fake_out判別為真
    return cross_entropy(tf.ones_like(fake_out),fake_out)

優化器

generator_opt=tf.keras.optimizers.Adam(1e-4)#學習速率
discriminator_opt=tf.keras.optimizers.Adam(1e-4)
EPOCHS=100
noise_dim=100 #長度為100的隨機向量生成手寫資料集
​
num_exp_to_generate=16 #每步生成16個樣本
​
seed=tf.random.normal([num_exp_to_generate,noise_dim]) #生成隨機向量觀察變化情況

訓練

generator=generator_model()
discriminator=discriminator_model()

定義批次訓練函式

def train_step(images):
    noise=tf.random.normal([BATCH_SIZE,noise_dim]) #生成隨機數
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: #記錄訓練過程
        real_out = discriminator(images,training=True)
        gen_image = generator(noise,training=True)
        fake_out=discriminator(gen_image,training=True)
        
        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out,fake_out)
        
    #訓練過程
    gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables) #生成器與生成器可訓練引數的梯度
    gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables) 
    generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))

視覺化

def generator_plot_image(gen_model,test_noise): #觀察訓練出的圖象
    pre_images = gen_model(test_noise,training=False)
    fig = plt.figure(figsize=(4,4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1) #從1開始排
        plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray') #歸一化,灰色度
        plt.axis('off') #不顯示座標軸
    plt.show()

訓練

def train(dataset,epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
            #print('.',end='')
        print('第'+str(epoch+1)+'次訓練結果')
        generator_plot_image(generator,seed)
train(datasets,EPOCHS)

結果

第1次訓練結果
在這裡插入圖片描述

第100次訓練結果

在這裡插入圖片描述

相關文章