content
環境準備
資料獲取與處理
模型訓練
本專案使用fashion-MNIST資料集,模型採用keras方式進行訓練並最終部署在Android上。完整請參考Github:fashionMNIST-on-device
環境準備
1. Anaconda
- mac搭建Python開發環境
- 使用conda建立屬於深度學習的虛擬環境
注:安裝任何包請使用conda install xxx
命令
2. docker
更推薦使用docker方式搭建自己的開發環境
- 動手學Docker:docs.docker.knowledge-precipitation.site/
- 推薦大家使用此深度學習docker映象:deepo,此映象已預裝深度學習基本開發環境
>>> import tensorflow
>>> import sonnet
>>> import torch
>>> import keras
>>> import mxnet
>>> import cntk
>>> import chainer
>>> import theano
>>> import lasagne
>>> import caffe
>>> import caffe2
複製程式碼
資料獲取與處理
Fashion-MNIST是一個替代MNIST手寫數字集的影像資料集。 它是由Zalando(一家德國的時尚科技公司)旗下的研究部門提供。其涵蓋了來自10種類別的共7萬個不同商品的正面圖片。Fashion-MNIST的大小、格式和訓練集/測試集劃分與原始的MNIST完全一致。60000/10000的訓練測試資料劃分,28x28的灰度圖片。下載地址
1. 通過kaggle下載資料
2. 通過keras下載資料
from keras.datasets import fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
複製程式碼
3. 資料處理
# 首先看一下資料的形狀
print(train_images.shape)
print(test_images.shape)
#輸出結果
(60000, 28, 28)
(10000, 28, 28)
複製程式碼
可以看到訓練資料是60000張28*28的圖片,測試資料是10000張28*28的圖片。
我們來看一下圖片上都是什麼資料:
import matplotlib.pyplot as plt
plt.imshow(train_images[0])
plt.savefig("train_images_0.png")
plt.show()
複製程式碼
顯示的結果:
之後將資料做reshape使資料適合訓練,並將資料縮放到0-1之間。
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255
複製程式碼
對標籤做one-hot
編碼:
train_labels = to_categorical(train_labels, 10)
test_labels = to_categorical(test_labels, 10)
複製程式碼
將所有操作整合為一個函式:
def load_data():
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255
train_labels = to_categorical(train_labels, 10)
test_labels = to_categorical(test_labels, 10)
return (train_images, train_labels), (test_images, test_labels)
複製程式碼
模型訓練
1. CNN-v1
第一個模型使用使用三個卷積+pooling操作接兩個全連結層。
model summary如下:
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 28, 28, 16) 80
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 16) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 14, 14, 32) 2080
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 7, 7, 32) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 7, 7, 64) 8256
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 3, 3, 64) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 3, 3, 64) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 576) 0
_________________________________________________________________
dense_1 (Dense) (None, 500) 288500
_________________________________________________________________
dropout_2 (Dropout) (None, 500) 0
_________________________________________________________________
dense_2 (Dense) (None, 10) 5010
=================================================================
Total params: 303,926
Trainable params: 303,926
Non-trainable params: 0
_________________________________________________________________
複製程式碼
模型結構如圖所示:
最終在測試集上的準確率為:87%
2. CNN-v2
第二個模型比第一個模型更簡單,使用了一個卷積+pooling接兩個全連線層。
model summary如下:
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 28, 28, 32) 320
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 6272) 0
_________________________________________________________________
dense_1 (Dense) (None, 5128) 32167944
_________________________________________________________________
dense_2 (Dense) (None, 10) 51290
=================================================================
Total params: 32,219,554
Trainable params: 32,219,554
Non-trainable params: 0
_________________________________________________________________
複製程式碼
模型結構如圖所示:
最終在測試集上的準確率為:91.2%
並將預測結果進行顯示,圖片上名稱為紅色的為分類錯誤的圖片:
此篇文章為一個系列,未完待續。 歡迎大家關注我們的公眾號:知識沉澱部落。