Implementing MobileNet in Keras

Published by 周霖 on 2020-11-27

This post implements MobileNet in Keras and uses the MNIST dataset as a small recognition example. Environment: tensorflow-gpu 2.0, Python 3.7, and a GTX-2070 GPU.

1. Importing the data

  • First, two IPython magic commands so that a cell displays every expression, not just the last one:
%config InteractiveShell.ast_node_interactivity="all"
%pprint
  • Load the MNIST dataset bundled with Keras:
import tensorflow as tf
import keras 

tf.debugging.set_log_device_placement(True)

mnist = keras.datasets.mnist

(x_train,y_train),(x_test,y_test) = mnist.load_data()

Note that tf.debugging.set_log_device_placement(True) above does not itself move the model onto the GPU; it logs which device each operation is assigned to, which lets us verify that training really does run on the GPU.
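As an additional sanity check (a minimal sketch, not part of the original post), TF 2.0 can list the GPUs it sees:

# List the physical GPUs visible to TensorFlow 2.0 (the non-experimental
# tf.config.list_physical_devices only arrived in TF 2.1)
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)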

  • Converting the data
    The MNIST images have a resolution of 28×28, while MobileNet is designed for ImageNet, whose images are 224×224, so we first resize each image to 224×224.
from PIL import Image
import numpy as np

def convert_mnist_224pix(X):
    """Resize a single 28x28 MNIST image to 224x224."""
    img = Image.fromarray(X)
    return np.array(img.resize((224,224)))

# Convert the whole training set, preallocating the target array in float32
iteration = iter(x_train)
new_train = np.zeros((len(x_train),224,224),dtype=np.float32)
for i in range(len(x_train)):
    data = next(iteration)
    new_train[i] = convert_mnist_224pix(data)

    # progress indicator
    if i%5000==0:
        print(i)

new_train.shape

Note that new_train must be created with dtype=np.float32; the NumPy default is float64, which doubles the memory footprint and may not fit in RAM. The final shape is (60000, 224, 224).
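To see why the dtype matters, a quick back-of-the-envelope calculation (bare expressions, displayed thanks to the magic command above):

60000 * 224 * 224 * 8 / 2**30   # float64: ~22.4 GiB
60000 * 224 * 224 * 4 / 2**30   # float32: ~11.2 GiB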

2. Building the model

  • Import all the required layers and utilities:
from keras.layers import Conv2D,DepthwiseConv2D,Dense,AveragePooling2D,BatchNormalization,Input
from keras import Model
from keras import Sequential
from keras.layers.advanced_activations import ReLU
from keras.utils import to_categorical
  • Define reusable building blocks for the repeated middle layers, to avoid duplicating code when assembling the network.
def depth_point_conv2d(x,s=[1,1,2,1],channel=[64,128]):
    """
    s:the strides of the conv
    channel: the depth of pointwiseconvolutions
    """
    
    dw1 = DepthwiseConv2D((3,3),strides=s[0],padding='same')(x)
    bn1 = BatchNormalization()(dw1)
    relu1 = ReLU()(bn1)
    pw1 = Conv2D(channel[0],(1,1),strides=s[1],padding='same')(relu1)
    bn2 = BatchNormalization()(pw1)
    relu2 = ReLU()(bn2)
    dw2 = DepthwiseConv2D((3,3),strides=s[2],padding='same')(relu2)
    bn3 = BatchNormalization()(dw2)
    relu3 = ReLU()(bn3)
    pw2 = Conv2D(channel[1],(1,1),strides=s[3],padding='same')(relu3)
    bn4 = BatchNormalization()(pw2)
    relu4 = ReLU()(bn4)
    
    return relu4
    
def repeat_conv(x,s=[1,1],channel=512):
    """One depthwise-separable block: DW conv -> BN -> ReLU -> 1x1 PW conv -> BN -> ReLU."""
    dw1 = DepthwiseConv2D((3,3),strides=s[0],padding='same')(x)
    bn1 = BatchNormalization()(dw1)
    relu1 = ReLU()(bn1)
    pw1 = Conv2D(channel,(1,1),strides=s[1],padding='same')(relu1)
    bn2 = BatchNormalization()(pw1)
    relu2 = ReLU()(bn2)
    
    return relu2
    

The model is built following the structure in the MobileNet paper.
In the fifth row from the bottom of the paper's table, the layer is listed as Conv dw / s2, yet the output resolution does not change. I have never understood how strides=2 could leave the feature-map size unchanged; I suspect it is a typo in the paper, so I set strides=1 there, which is what makes all of the subsequent output shapes consistent.
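A quick check of the arithmetic supports this: with padding='same', Keras produces ceil(input / stride) output cells, so a stride-2 layer on a 7×7 map would yield 4×4, not 7×7:

import math

# padding='same' output size = ceil(input / stride)
for size, stride in [(14, 2), (7, 2), (7, 1)]:
    print(size, stride, '->', math.ceil(size / stride))
# 14/s2 -> 7,  7/s2 -> 4 (not 7),  7/s1 -> 7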

  • Build the network:
h0 = Input(shape=(224,224,1))
h1 = Conv2D(32,(3,3),strides=2,padding="same")(h0)
h2 = BatchNormalization()(h1)
h3 = ReLU()(h2)
h4 = depth_point_conv2d(h3,s=[1,1,2,1],channel=[64,128])
h5 = depth_point_conv2d(h4,s=[1,1,2,1],channel=[128,256])
h6 = depth_point_conv2d(h5,s=[1,1,2,1],channel=[256,512])
h7 = repeat_conv(h6)
h8 = repeat_conv(h7)
h9 = repeat_conv(h8)
h10 = repeat_conv(h9)
h11 = depth_point_conv2d(h10,s=[1,1,2,1],channel=[512,1024])
h12 = repeat_conv(h11,channel=1024)
h13 = AveragePooling2D((7,7))(h12)
h14 = Dense(10,activation='softmax')(h13)
model = Model(inputs=h0,outputs=h14)
model.summary()
Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        (None, 224, 224, 1)       0         
_________________________________________________________________
conv2d_63 (Conv2D)           (None, 112, 112, 32)      320       
_________________________________________________________________
batch_normalization_120 (Bat (None, 112, 112, 32)      128       
_________________________________________________________________
re_lu_120 (ReLU)             (None, 112, 112, 32)      0         
_________________________________________________________________
depthwise_conv2d_58 (Depthwi (None, 112, 112, 32)      320       
_________________________________________________________________
batch_normalization_121 (Bat (None, 112, 112, 32)      128       
_________________________________________________________________
re_lu_121 (ReLU)             (None, 112, 112, 32)      0         
_________________________________________________________________
conv2d_64 (Conv2D)           (None, 112, 112, 64)      2112      
_________________________________________________________________
batch_normalization_122 (Bat (None, 112, 112, 64)      256       
_________________________________________________________________
re_lu_122 (ReLU)             (None, 112, 112, 64)      0         
_________________________________________________________________
depthwise_conv2d_59 (Depthwi (None, 56, 56, 64)        640       
_________________________________________________________________
batch_normalization_123 (Bat (None, 56, 56, 64)        256       
_________________________________________________________________
re_lu_123 (ReLU)             (None, 56, 56, 64)        0         
_________________________________________________________________
conv2d_65 (Conv2D)           (None, 56, 56, 128)       8320      
_________________________________________________________________
batch_normalization_124 (Bat (None, 56, 56, 128)       512       
_________________________________________________________________
re_lu_124 (ReLU)             (None, 56, 56, 128)       0         
_________________________________________________________________
depthwise_conv2d_60 (Depthwi (None, 56, 56, 128)       1280      
_________________________________________________________________
batch_normalization_125 (Bat (None, 56, 56, 128)       512       
_________________________________________________________________
re_lu_125 (ReLU)             (None, 56, 56, 128)       0         
_________________________________________________________________
conv2d_66 (Conv2D)           (None, 56, 56, 128)       16512     
_________________________________________________________________
batch_normalization_126 (Bat (None, 56, 56, 128)       512       
_________________________________________________________________
re_lu_126 (ReLU)             (None, 56, 56, 128)       0         
_________________________________________________________________
depthwise_conv2d_61 (Depthwi (None, 28, 28, 128)       1280      
_________________________________________________________________
batch_normalization_127 (Bat (None, 28, 28, 128)       512       
_________________________________________________________________
re_lu_127 (ReLU)             (None, 28, 28, 128)       0         
_________________________________________________________________
conv2d_67 (Conv2D)           (None, 28, 28, 256)       33024     
_________________________________________________________________
batch_normalization_128 (Bat (None, 28, 28, 256)       1024      
_________________________________________________________________
re_lu_128 (ReLU)             (None, 28, 28, 256)       0         
_________________________________________________________________
depthwise_conv2d_62 (Depthwi (None, 28, 28, 256)       2560      
_________________________________________________________________
batch_normalization_129 (Bat (None, 28, 28, 256)       1024      
_________________________________________________________________
re_lu_129 (ReLU)             (None, 28, 28, 256)       0         
_________________________________________________________________
conv2d_68 (Conv2D)           (None, 28, 28, 256)       65792     
_________________________________________________________________
batch_normalization_130 (Bat (None, 28, 28, 256)       1024      
_________________________________________________________________
re_lu_130 (ReLU)             (None, 28, 28, 256)       0         
_________________________________________________________________
depthwise_conv2d_63 (Depthwi (None, 14, 14, 256)       2560      
_________________________________________________________________
batch_normalization_131 (Bat (None, 14, 14, 256)       1024      
_________________________________________________________________
re_lu_131 (ReLU)             (None, 14, 14, 256)       0         
_________________________________________________________________
conv2d_69 (Conv2D)           (None, 14, 14, 512)       131584    
_________________________________________________________________
batch_normalization_132 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_132 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
depthwise_conv2d_64 (Depthwi (None, 14, 14, 512)       5120      
_________________________________________________________________
batch_normalization_133 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_133 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_70 (Conv2D)           (None, 14, 14, 512)       262656    
_________________________________________________________________
batch_normalization_134 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_134 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
depthwise_conv2d_65 (Depthwi (None, 14, 14, 512)       5120      
_________________________________________________________________
batch_normalization_135 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_135 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_71 (Conv2D)           (None, 14, 14, 512)       262656    
_________________________________________________________________
batch_normalization_136 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_136 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
depthwise_conv2d_66 (Depthwi (None, 14, 14, 512)       5120      
_________________________________________________________________
batch_normalization_137 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_137 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_72 (Conv2D)           (None, 14, 14, 512)       262656    
_________________________________________________________________
batch_normalization_138 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_138 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
depthwise_conv2d_67 (Depthwi (None, 14, 14, 512)       5120      
_________________________________________________________________
batch_normalization_139 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_139 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_73 (Conv2D)           (None, 14, 14, 512)       262656    
_________________________________________________________________
batch_normalization_140 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_140 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
depthwise_conv2d_68 (Depthwi (None, 14, 14, 512)       5120      
_________________________________________________________________
batch_normalization_141 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_141 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_74 (Conv2D)           (None, 14, 14, 512)       262656    
_________________________________________________________________
batch_normalization_142 (Bat (None, 14, 14, 512)       2048      
_________________________________________________________________
re_lu_142 (ReLU)             (None, 14, 14, 512)       0         
_________________________________________________________________
depthwise_conv2d_69 (Depthwi (None, 7, 7, 512)         5120      
_________________________________________________________________
batch_normalization_143 (Bat (None, 7, 7, 512)         2048      
_________________________________________________________________
re_lu_143 (ReLU)             (None, 7, 7, 512)         0         
_________________________________________________________________
conv2d_75 (Conv2D)           (None, 7, 7, 1024)        525312    
_________________________________________________________________
batch_normalization_144 (Bat (None, 7, 7, 1024)        4096      
_________________________________________________________________
re_lu_144 (ReLU)             (None, 7, 7, 1024)        0         
_________________________________________________________________
depthwise_conv2d_70 (Depthwi (None, 7, 7, 1024)        10240     
_________________________________________________________________
batch_normalization_145 (Bat (None, 7, 7, 1024)        4096      
_________________________________________________________________
re_lu_145 (ReLU)             (None, 7, 7, 1024)        0         
_________________________________________________________________
conv2d_76 (Conv2D)           (None, 7, 7, 1024)        1049600   
_________________________________________________________________
batch_normalization_146 (Bat (None, 7, 7, 1024)        4096      
_________________________________________________________________
re_lu_146 (ReLU)             (None, 7, 7, 1024)        0         
_________________________________________________________________
average_pooling2d_5 (Average (None, 1, 1, 1024)        0         
_________________________________________________________________
dense_4 (Dense)              (None, 1, 1, 10)          10250     
=================================================================
Total params: 3,249,482
Trainable params: 3,227,594
Non-trainable params: 21,888
_________________________________________________________________

Because there are only 10 classes here, the final layer has just 10 neurons; the original MobileNet classifies 1000 ImageNet categories, so its final layer has 1000 neurons.

model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])

The line above sets the optimization algorithm and the loss function. Keras computes categorical_crossentropy along the last axis, which is why it still works with the (None, 1, 1, 10) output shape here.

3. Preparing the training data and training

Reshape the training data, one-hot encode the labels, and expand the label dimensions so they match the model's output shape.

x_train = np.expand_dims(new_train,3)   # (60000, 224, 224) -> (60000, 224, 224, 1)

y_train = to_categorical(y_train)       # integer labels -> one-hot vectors, (60000, 10)

y = np.expand_dims(y_train,1)
y = np.expand_dims(y,1)                 # (60000, 10) -> (60000, 1, 1, 10), matching the Dense output
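A quick shape check (the expected shapes are read off the model summary above):

x_train.shape   # (60000, 224, 224, 1), matches Input(shape=(224,224,1))
y.shape         # (60000, 1, 1, 10), matches the (None, 1, 1, 10) output of dense_4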
  • Define the data generator:
def data_generate(x_train,y_train,batch_size,epochs):
    for i in range(epochs):
        batch_num = len(x_train)//batch_size
        # shuffle the order in which the batches are visited each epoch
        shuffle_index = np.arange(batch_num)
        np.random.shuffle(shuffle_index)
        for j in shuffle_index:
            begin = j*batch_size
            end = begin+batch_size
            x = x_train[begin:end]
            y = y_train[begin:end]

            # the dict keys must match the model's layer names (see note below)
            yield ({"input_11":x},{"dense_4":y})
            

The dictionary keys above must match the names of the model's first (input) layer and last layer, input_11 and dense_4 in this run's summary; otherwise Keras raises an error.
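Since these names are auto-generated and change whenever a model is rebuilt in the same session, it is safer to read them off the model itself (a small sketch):

print(model.layers[0].name, model.layers[-1].name)   # e.g. input_11 dense_4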

  • Start training:
# the generator is built with 11 epochs of data so it is not exhausted
# while fit_generator consumes 10 epochs x 600 steps
model.fit_generator(data_generate(x_train,y,100,11),steps_per_epoch=600,epochs=10)

The training log is shown below:

Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 1/10
Executing op __inference_keras_scratch_graph_22639 in device /job:localhost/replica:0/task:0/device:GPU:0
600/600 [==============================] - 411s 684ms/step - loss: 0.1469 - accuracy: 0.9529
Epoch 2/10
600/600 [==============================] - 398s 663ms/step - loss: 0.0375 - accuracy: 0.9884
Epoch 3/10
600/600 [==============================] - 401s 668ms/step - loss: 0.0283 - accuracy: 0.9909
Epoch 4/10
600/600 [==============================] - 399s 665ms/step - loss: 0.0211 - accuracy: 0.9936
Epoch 5/10
600/600 [==============================] - 400s 666ms/step - loss: 0.0216 - accuracy: 0.9932
Epoch 6/10
600/600 [==============================] - 401s 668ms/step - loss: 0.0208 - accuracy: 0.9935
Epoch 7/10
600/600 [==============================] - 401s 669ms/step - loss: 0.0174 - accuracy: 0.9945
Epoch 8/10
131/600 [=====>........................] - ETA: 5:13 - loss: 0.0091 - accuracy: 0.9973

The model has quite a few convolution layers, so each training step takes a while; but the parameter count is small, so updates are cheap and the network converges quickly.
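The test split loaded in section 1 is never used above; as a closing sketch (hypothetical, not part of the original run), it could be scored with the same preprocessing:

# Hypothetical evaluation sketch: preprocess the test set the same way,
# then evaluate the trained model. Variable names here are illustrative.
new_test = np.zeros((len(x_test),224,224),dtype=np.float32)
for i in range(len(x_test)):
    new_test[i] = convert_mnist_224pix(x_test[i])

x_test_in = np.expand_dims(new_test,3)                                   # (10000, 224, 224, 1)
y_test_in = np.expand_dims(np.expand_dims(to_categorical(y_test),1),1)   # (10000, 1, 1, 10)

model.evaluate(x_test_in,y_test_in,batch_size=100)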
