Python深度學習入門之mnist-inception(Tensorflow2.0實現)
mnist手寫數字資料集深度學習最常用的資料集,本文以mnist資料集為例,利用Tensorflow2.0框架搭建inception網路,實現mnist資料集識別任務,並畫出各個曲線。
Demo完整程式碼如下:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
#載入mnist資料集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
#預處理
x_train, x_test = x_train.astype(np.float32)/255., x_test.astype(np.float32)/255.
x_train, x_test = np.expand_dims(x_train, axis=3), np.expand_dims(x_test, axis=3)
# 建立訓練集50000、驗證集10000以及測試集10000
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
#標籤轉為one-hot格式
y_train = tf.one_hot(y_train, depth=10).numpy()
y_val = tf.one_hot(y_val, depth=10).numpy()
y_test = tf.one_hot(y_test, depth=10).numpy()
# tf.data.Dataset 批處理
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(100).repeat()
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(100).repeat()
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(100).repeat()
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow import keras
class ConvBNRelu(keras.Model):
def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
super(ConvBNRelu, self).__init__()
self.model = keras.models.Sequential([
keras.layers.Conv2D(ch, kernelsz, strides=strides, padding=padding),
keras.layers.BatchNormalization(),
keras.layers.ReLU()
])
def call(self, x, training=None):
x = self.model(x, training=training)
return x
class InceptionBlk(keras.Model):
def __init__(self, ch, strides=1):
super(InceptionBlk, self).__init__()
self.ch = ch
self.strides = strides
self.conv1 = ConvBNRelu(ch, strides=strides)
self.conv2 = ConvBNRelu(ch, kernelsz=3, strides=strides)
self.conv3_1 = ConvBNRelu(ch, kernelsz=3, strides=strides)
self.conv3_2 = ConvBNRelu(ch, kernelsz=3, strides=1)
self.pool = keras.layers.MaxPooling2D(3, strides=1, padding='same')
self.pool_conv = ConvBNRelu(ch, strides=strides)
def call(self, x, training=None):
x1 = self.conv1(x, training=training)
x2 = self.conv2(x, training=training)
x3_1 = self.conv3_1(x, training=training)
x3_2 = self.conv3_2(x3_1, training=training)
x4 = self.pool(x)
x4 = self.pool_conv(x4, training=training)
# concat along axis=channel
x = tf.concat([x1, x2, x3_2, x4], axis=3)
return x
class Inception(keras.Model):
def __init__(self, num_layers, num_classes, init_ch=16, **kwargs):
super(Inception, self).__init__(**kwargs)
self.in_channels = init_ch
self.out_channels = init_ch
self.num_layers = num_layers
self.init_ch = init_ch
self.conv1 = ConvBNRelu(init_ch)
self.blocks = keras.models.Sequential(name='dynamic-blocks')
for block_id in range(num_layers):
for layer_id in range(2):
if layer_id == 0:
block = InceptionBlk(self.out_channels, strides=2)
else:
block = InceptionBlk(self.out_channels, strides=1)
self.blocks.add(block)
# enlarger out_channels per block
self.out_channels *= 2
self.avg_pool = keras.layers.GlobalAveragePooling2D()
self.fc = keras.layers.Dense(num_classes)
def call(self, x, training=None):
out = self.conv1(x, training=training)
out = self.blocks(out, training=training)
out = self.avg_pool(out)
out = self.fc(out)
return out
#網路引數設定
model_inception = Inception(2, 10)
model_inception.compile(optimizer=keras.optimizers.Adam(0.001),
loss=keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['acc'])
model_inception.build(input_shape=(None, 28, 28, 1))
#列印網路引數
model_inception.summary()
#開始訓練
history_inception = model_inception.fit(train_dataset, epochs=50, steps_per_epoch=30, validation_data=val_dataset, validation_steps=3)
#模型評估及儲存權重
model_inception.evaluate(test_dataset, steps=100)
model_inception.save_weights('save_model/inception_mnist/inception_mnist_weights.ckpt')
網路引數
評估結果
繪製曲線程式碼
import matplotlib.pyplot as plt
#輸入兩個曲線的資訊
plt.figure( figsize=(12,8), dpi=80 )
plt.plot(history_inception.epoch, history_inception.history.get('loss'), color='r', label = 'loss')
plt.plot(history_inception.epoch, history_inception.history.get('acc'), color='g', linestyle='-.', label = 'acc')
plt.plot(history_inception.epoch, history_inception.history.get('val_acc'), color='b', linestyle='--', label = 'val_acc')
#顯示圖例
plt.legend() #預設loc=Best
#新增網格資訊
plt.grid(True, linestyle='--', alpha=0.5) #預設是True,風格設定為虛線,alpha為透明度
#新增標題
plt.xlabel('epochs')
plt.ylabel('loss/acc')
plt.title('inception_Curve of loss/acc Change with epochs in Mnist')
plt.savefig('./save_png/inception.png')
plt.show()
網路曲線
相關文章
- 深度學習:TensorFlow入門實戰深度學習
- 《深度學習入門:基於Python的理論與實現》 Deep Learning from Scratch深度學習Python
- 《動手學深度學習》TensorFlow2.0版本深度學習
- 初入門Python學習之概念區分Python
- Python學習手冊(入門&爬蟲&資料分析&機器學習&深度學習)Python爬蟲機器學習深度學習
- 機器學習和深度學習概念入門機器學習深度學習
- Python入門之web2py框架學習!PythonWeb框架
- 學習Python之後,可以做哪些兼職?Python入門!Python
- 如何學習Python?Python學習入門路線Python
- 深度學習後門攻擊分析與實現(二)深度學習
- 深度學習後門攻擊分析與實現(一)深度學習
- 行業專家分享:深度學習筆記之Tensorflow入門!行業深度學習筆記
- 【深度學習】--GAN從入門到初始深度學習
- 【深度學習】神經網路入門深度學習神經網路
- Anaconda Pytorch 深度學習入門記錄PyTorch深度學習
- 學習Python需要考證嗎?Python學習入門!Python
- Python入門難嗎?如何順利入門Python學習?Python
- Python入門學習之異常處理機制Python
- 《深度學習入門:》學習基本第一章深度學習
- 【乾貨】機器學習和深度學習概念入門機器學習深度學習
- 深度學習入門:基於Python的理論與實現-第三章sys.path問題深度學習Python
- 快速入門——深度學習理論解析與實戰應用深度學習
- 深度學習入門筆記——Transform的使用深度學習筆記ORM
- 【Pytorch教程】迅速入門Pytorch深度學習框架PyTorch深度學習框架
- 深度學習入門筆記——DataLoader的使用深度學習筆記
- Python入門學習 之 永久儲存、異常處理Python
- 用Python和深度學習實現iPhone X的Face IDPython深度學習iPhone
- Python爬蟲入門學習實戰專案(一)Python爬蟲
- 《深度學習之TensorFlow:入門、原理與進階實戰》PDF+原始碼+李金洪深度學習原始碼
- Python深度學習Python深度學習
- 如何學習Python,新手如何入門Python
- 0基礎學習Python該如何入門?Python學習方法!Python
- 零基礎入門深度學習(一):用numpy實現神經網路訓練深度學習神經網路
- 如何學習python程式語言?python入門Python
- 深度學習之PyTorch實戰(4)——遷移學習深度學習PyTorch遷移學習
- 好程式設計師Python學習路線之python爬蟲入門程式設計師Python爬蟲
- 機器學習之小白入門機器學習
- 使用 C# 入門深度學習:Pytorch 基礎C#深度學習PyTorch