日月光華的gan小例子
來自課程GAN生成對抗網路精講 tensorflow2.0程式碼實戰 全網最簡潔易懂的GAN課程
的程式碼
tensorflow版本為2.0
升級版本看https://blog.csdn.net/qq_43620967/article/details/108835207
jupyter 檔案
連結:https://pan.baidu.com/s/1PR4phiKAEoK4sAAmTWR18A
提取碼:z9od
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
tf.__version__
Out[104]: ‘2.3.1’
輸入
(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()
train_images.shape
Out[106]: (60000, 28, 28)
60000張圖,都是28*28 畫素
train_images.dtype
Out[107]: dtype(‘uint8’)
資料預處理
train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
train_images.shape
Out[109]: (60000, 28, 28, 1)
train_images=(train_images-127.5)/127.5#歸一化 到【-1,1】
BATCH_SIZE=256
BUFFER_SIZE=600000
datasets=tf.data.Dataset.from_tensor_slices(train_images)
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
datasets
Out[114]: <BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
第一維度表示個數
生成器模型
def generator_model():
model=tf.keras.Sequential()
model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
#Dense全連線層,input_shape=(100,)長度100的隨機向量,use_bias=False,因為後面有BN層
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())#啟用
#第二層
model.add(layers.Dense(512,use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())#啟用
#輸出層
model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))
model.add(layers.BatchNormalization())
model.add(layers.Reshape((28,28,1)))#變成圖片 要以元組形式傳入
return model
辨別器模型
def discriminator_model():#輸入圖片
model=keras.Sequential()
model.add(layers.Flatten())#把輸入的三維圖片扁平化
model.add(layers.Dense(512,use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())#啟用
model.add(layers.Dense(256,use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())#啟用
model.add(layers.Dense(1))#輸出數字,>0.5真實圖片
return model
loss函式
cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)#from_logits=True因為最後的輸出沒有啟用
判別器損失函式
def discriminator_loss(real_out,fake_out):#辨別器的輸出 真實圖片判1,假的圖片判0
real_loss=cross_entropy(tf.ones_like(real_out),real_out)
fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
return real_loss+fake_loss
生成器損失函式
def generator_loss(fake_out):#希望fakeimage的判別輸出fake_out判別為真
return cross_entropy(tf.ones_like(fake_out),fake_out)
優化器
generator_opt=tf.keras.optimizers.Adam(1e-4)#學習速率
discriminator_opt=tf.keras.optimizers.Adam(1e-4)
EPOCHS=100
noise_dim=100 #長度為100的隨機向量生成手寫資料集
num_exp_to_generate=16 #每步生成16個樣本
seed=tf.random.normal([num_exp_to_generate,noise_dim]) #生成隨機向量觀察變化情況
訓練
generator=generator_model()
discriminator=discriminator_model()
定義批次訓練函式
def train_step(images):
noise=tf.random.normal([BATCH_SIZE,noise_dim]) #生成隨機數
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: #記錄訓練過程
real_out = discriminator(images,training=True)
gen_image = generator(noise,training=True)
fake_out=discriminator(gen_image,training=True)
gen_loss = generator_loss(fake_out)
disc_loss = discriminator_loss(real_out,fake_out)
#訓練過程
gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables) #生成器與生成器可訓練引數的梯度
gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables)
generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))
視覺化
def generator_plot_image(gen_model,test_noise): #觀察訓練出的圖象
pre_images = gen_model(test_noise,training=False)
fig = plt.figure(figsize=(4,4))
for i in range(pre_images.shape[0]):
plt.subplot(4,4,i+1) #從1開始排
plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray') #歸一化,灰色度
plt.axis('off') #不顯示座標軸
plt.show()
訓練
def train(dataset,epochs):
for epoch in range(epochs):
for image_batch in dataset:
train_step(image_batch)
#print('.',end='')
print('第'+str(epoch+1)+'次訓練結果')
generator_plot_image(generator,seed)
train(datasets,EPOCHS)
結果
第1次訓練結果
第100次訓練結果
相關文章
- 「GAN優化」GAN訓練的小技巧優化
- 【python小例子】小例子拾憶Python
- python字典的小例子Python
- mybatis小例子2MyBatis
- PHP中ZendCache用法的小例子PHP
- react-refetch的使用小例子React
- 小例子理解多型多型
- While True用法小例子While
- 一個小例子搞懂redux的套路Redux
- 非同步學習小例子非同步
- 有關mysql中ROW_COUNT()的小例子MySql
- 小例子 理解 Laravel 中的 控制反轉模式Laravel模式
- 有關程式碼執行效率提升的小例子
- if、else if、else判斷語句的幾個小例子
- 一個被寫爛的redux計數小例子Redux
- 正則中關於環視(lookaround)的小例子
- Fake許可權驗證小例子
- 幾個彙編入門小例子
- 能量視角下的GAN模型:GAN=“挖坑”+“跳坑”模型
- 能量視角下的GAN模型(二):GAN=“分析”+“取樣”模型
- 李巨集毅GAN學習(四)GAN的基本理論
- dubbo入門和springboot整合dubbo小例子Spring Boot
- Wasserstein GAN
- 「GAN優化」如何選好正則項讓你的GAN收斂優化
- popmenu的例子
- 一個小例子,給你講透典型的 Go 併發操作Go
- mysql返回一個結果集的儲存過程小例子MySql儲存過程
- 十幾個python小例子,從此愛上pythonPython
- GAN入門
- 從兩個小例子看js中的隱式型別轉換JS型別
- 教小師妹學多執行緒,一個有深度的例子!執行緒
- 一個有趣的小例子,帶你入門協程模組-asyncio
- GAN網路從入門教程(二)之GAN原理
- 李弘毅老師GAN筆記(三),Unsupervised Conditional GAN筆記
- pytorch訓練GAN時的detach()PyTorch
- 谷歌開源的 GAN 庫–TFGAN谷歌
- 關於GAN的個人理解
- (精華)2020年7月9日 微信小程式 小程式生命週期和全域性變數微信小程式變數