步驟
- 讀取本地的圖片資料以及類別
- 模型的結構修改(新增我們自訂的分類層)
- freeze掉原始VGG模型
- 編譯以及訓練和儲存模型方式
- 輸入資料進行預測
讀取本地圖片
ImageDataGenerator:生成圖片的批次張量值,並且提供資料增強功能
引數:
- rescale=1.0 / 255:縮放因子,將像素值從 0–255 歸一化到 0–1
- zca_whitening=False:ZCA白化,以PCA為基礎去除像素之間的相關性,減少圖片的冗餘資訊
- rotation_range=20:預設0,旋轉角度,在這個角度範圍內隨機生成一個值
- width_shift_range=0.2:預設0,水平平移
- height_shift_range=0.2:預設0,垂直平移
- shear_range=0.2:剪切變換(shear)的強度
- zoom_range=0.2:隨機縮放的範圍
- horizontal_flip=True:隨機水平翻轉
使用flow
# Example: augment an in-memory dataset (CIFAR-10) with ImageDataGenerator.flow.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
gen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)
# featurewise_center / featurewise_std_normalization need the dataset-wide
# mean and std; fit() computes them from the training images. Without this
# call the featurewise normalization is not applied.
gen.fit(x_train)

for e in range(epochs):
    print('Epoch', e)
    batches = 0
    for x_batch, y_batch in gen.flow(x_train, y_train, batch_size=32):
        model.fit(x_batch, y_batch)
        batches += 1
        # flow() yields batches indefinitely, so stop by hand after one
        # pass over the training data; otherwise the epoch never ends.
        if batches >= len(x_train) / 32:
            break
使用flow_from_directory
- directory=path:讀取的目錄
- target_size=(h, w):目標圖片形狀
- batch_size=size:每批次的數量大小
- class_mode='binary':目標值格式,常用值為 "categorical"、"binary"、"sparse"
  - "categorical":2D one-hot 編碼的標籤
  - "binary":1D 二元標籤(適用於兩個類別)
  - "sparse":1D 整數類別標籤(適用於多類別,搭配 sparse_categorical_crossentropy)
- shuffle=True:是否打亂資料順序
這個API固定了讀取的目錄格式:每個類別一個子目錄(例如 data/train/cats/、data/train/dogs/),類別索引依子目錄名稱排序決定,參考:
# Only the training images are augmented; validation images are just rescaled.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
test_datagen = ImageDataGenerator(rescale=1. / 255)

# Both directories are read with identical settings.
flow_kwargs = dict(target_size=(150, 150), batch_size=32, class_mode='binary')
train_generator = train_datagen.flow_from_directory('data/train', **flow_kwargs)
validation_generator = test_datagen.flow_from_directory('data/validation', **flow_kwargs)

# Train with fit_generator. steps_per_epoch / validation_steps are counted in
# generator batches (see the Keras fit_generator documentation).
model.fit_generator(
    train_generator,
    steps_per_epoch=2000,
    epochs=50,
    validation_data=validation_generator,
    validation_steps=800,
)
VGG模型的修改
notop模型:不包含最後3個全連線層的VGG模型,專門為fine-tuning而提供。
# Add to __init__: VGG16 with ImageNet weights and without its fully connected top layers (include_top=False).
self.base_model = VGG16(weights='imagenet', include_top=False)
做法:一個GlobalAveragePooling2D + 兩個全連線層(完整程式碼中為Dense(1024)與Dense(5))。
下面先示範GlobalAveragePooling2D如何取代全連線層,如下:
from keras.layers import Dense, Input, Conv2D
from keras.layers import MaxPooling2D, GlobalAveragePooling2D
from keras.models import Model

# Assume the last CNN layer outputs (None, 8, 8, 2048).
# Keep the Input tensor in its own variable: reassigning `x` would lose the
# reference needed to build a Model from it.
inputs = Input(shape=[8, 8, 2048])
# Average each feature map into a single value, replacing flatten + dense layers.
x = GlobalAveragePooling2D(name='avg_pool')(inputs)  # shape=(?, 2048)
x = Dense(1000, activation='softmax', name='predictions')(x)  # shape=(?, 1000)
model = Model(inputs=inputs, outputs=x)
freeze 模型
讓VGG結構當中的權重引數不參與訓練,只訓練我們新增的最後兩層全連線網路的權重引數。
做法是將每一層的layer.trainable設為False;注意必須在compile之前設定才會生效。
def freeze_vgg_model(self):
    """Mark every layer of the VGG16 base model as non-trainable.

    Afterwards only the newly added head layers are updated during training.
    Call this before compiling the model so the change takes effect.
    """
    for vgg_layer in self.base_model.layers:
        vgg_layer.trainable = False
編譯和訓練
在遷移學習中,學習率應初始化為較小的值(例如0.001或0.0001):因為是在已訓練好的模型基礎之上更新,所以不需要太大的學習率。Adam優化器的預設學習率即為0.001。
def compile(self, model):
    """Compile *model* with the Adam optimizer (default settings).

    Uses sparse categorical cross-entropy, i.e. integer class labels, and
    reports accuracy.
    """
    optimizer = keras.optimizers.Adam()
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy'],
    )
使用ModelCheckpoint指定相關引數:
# Save only the weights, and only when val_acc improves; checked every epoch.
checkpoint_path = './snn_model/transfer-{epoch:02d}-{acc:.2f}.h5'
calls = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_acc',
    mode='auto',
    save_best_only=True,
    save_weights_only=True,
    period=1,
)
fine_model.fit_generator(
    train_g,
    epochs=3,
    validation_data=test_g,
    callbacks=[calls],
)
預測
載入我們訓練好的模型權重,讀取圖片並做與訓練時相同的預處理(訓練時使用rescale=1./255,預測時也必須同樣除以255),再送入模型預測
def predict(self, model):
    """Load trained weights into *model* and print the predicted class of one image.

    The image is scaled by 1/255, the same scaling the training generators
    (ImageDataGenerator(rescale=1./255)) applied. VGG16's preprocess_input
    performs a different normalization (channel mean subtraction on 0-255
    values), which would feed the new head inputs unlike those it was trained on.
    """
    # 1. Load the weights saved after training.
    model.load_weights("./Transfer.h5")
    # 2. Load the image at the network's input size and convert it to an array.
    image = load_img("./data/test/dinosaurs/402.jpg", target_size=(224, 224))
    image = img_to_array(image)
    # Add a batch axis: (224, 224, 3) -> (1, 224, 224, 3).
    img = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
    print("改變形狀結果:", img.shape)
    # 3. Normalize exactly like the training pipeline, then predict.
    img = img * (1.0 / 255)
    y_predict = model.predict(img)
    index = np.argmax(y_predict, axis=1)
    print(self.label_dict[str(index[0])])
完整程式碼
import os

import numpy as np
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras.applications.vgg16 import preprocess_input, decode_predictions
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.preprocessing.image import load_img, img_to_array
class Transfer:
    """Fine-tune a frozen VGG16 base with a new 5-class classification head.

    Images are read from one sub-directory per class under ``train_dir`` and
    ``test_dir``. VGG16 (ImageNet weights, no fully connected top) extracts
    features, and a GlobalAveragePooling2D + Dense head classifies them.
    """

    # Spatial size every image is resized to, for training and prediction.
    IMAGE_SIZE = (224, 224)

    def __init__(self):
        # Both pipelines only rescale pixel values by 1/255; predict() applies
        # the identical scaling to its input image.
        self.train_generator = ImageDataGenerator(rescale=1.0 / 255.0)
        self.test_generator = ImageDataGenerator(rescale=1.0 / 255.0)
        self.train_dir = "./data/train"
        self.test_dir = "./data/test"
        # VGG16 without its fully connected layers.
        self.base_model = VGG16(weights='imagenet', include_top=False)
        # Class index (as a string) -> class name. flow_from_directory numbers
        # classes in alphanumeric order of the sub-directory names; this
        # assumes the directories are named exactly like these labels.
        self.label_dict = {
            '0': 'bus',
            '1': 'dinosaurs',
            '2': 'elephants',
            '3': 'flowers',
            '4': 'horse',
        }

    def get_data(self):
        """Create the training and test batch iterators.

        Returns:
            (train_g, test_g): directory iterators yielding batches of 32
            images resized to IMAGE_SIZE, labelled with class indices.
        """
        # 'sparse' yields integer class indices, which is what
        # sparse_categorical_crossentropy in compile() expects for this
        # 5-class problem; 'binary' is intended for two-class labels.
        flow_args = dict(target_size=self.IMAGE_SIZE, class_mode='sparse', batch_size=32)
        train_g = self.train_generator.flow_from_directory(self.train_dir, **flow_args)
        test_g = self.test_generator.flow_from_directory(self.test_dir, **flow_args)
        return train_g, test_g

    def refine_model(self):
        """Stack a new classification head on top of the VGG16 base.

        Returns:
            A keras Model from VGG16's input to one softmax probability per
            entry in ``label_dict``.
        """
        # Output feature maps of the notop VGG16.
        x = self.base_model.outputs[0]
        # Average each feature map into one value per channel.
        x = keras.layers.GlobalAveragePooling2D()(x)
        x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
        y_p = keras.layers.Dense(len(self.label_dict), activation=tf.nn.softmax)(x)
        return keras.models.Model(inputs=self.base_model.inputs, outputs=y_p)

    def freeze_model(self):
        """Freeze every VGG16 layer so that only the new head is trained.

        Must be called before compile() for the change to take effect.
        """
        for layer in self.base_model.layers:
            layer.trainable = False

    def compile(self, model):
        """Compile *model* with Adam and sparse categorical cross-entropy."""
        model.compile(optimizer=keras.optimizers.Adam(),
                      loss=keras.losses.sparse_categorical_crossentropy,
                      metrics=['accuracy'])

    def fit(self, fine_model, train_g, test_g):
        """Train *fine_model* for 3 epochs, saving the best weights by val_acc."""
        # Make sure the checkpoint directory exists before Keras writes into it.
        os.makedirs('./snn_model', exist_ok=True)
        # NOTE(review): 'acc' / 'val_acc' are the metric names logged by older
        # Keras / TF 1.x for metrics=['accuracy']; TF 2.x logs 'accuracy' /
        # 'val_accuracy' instead. Confirm against the installed version.
        calls = keras.callbacks.ModelCheckpoint(
            filepath='./snn_model/transfer-{epoch:02d}-{acc:.2f}.h5',
            monitor='val_acc',
            save_best_only=True,
            save_weights_only=True,
            mode='auto',
            period=1,
        )
        fine_model.fit_generator(train_g, epochs=3, validation_data=test_g, callbacks=[calls])

    def predict(self, model, weights_path="./snn_model/transfer-03-0.98.h5",
                image_path="./data/test/bus/300.jpg"):
        """Load trained weights into *model* and classify a single image.

        Args:
            model: the model returned by refine_model().
            weights_path: weight file written by fit().
            image_path: image to classify.

        Returns:
            The predicted class name (it is also printed).
        """
        model.load_weights(weights_path)
        img = load_img(image_path, target_size=self.IMAGE_SIZE)
        # (224, 224, 3) -> (1, 224, 224, 3): add the batch axis.
        batch = np.expand_dims(img_to_array(img), axis=0)
        # Scale exactly like the training generators (rescale=1/255). VGG16's
        # preprocess_input normalizes differently and would feed the head
        # inputs unlike the ones it was trained on.
        batch = batch * (1.0 / 255.0)
        y_p = model.predict(batch)
        res = np.argmax(y_p, axis=1)
        label = self.label_dict[str(res[0])]
        print(f"預測了類別為:{label}")
        return label
def train(cnn):
    """Build, freeze, compile and fit the transfer model on the local dataset.

    Freezing happens before compiling so the frozen state is respected.
    """
    train_data, test_data = cnn.get_data()
    fine_model = cnn.refine_model()
    cnn.freeze_model()
    cnn.compile(fine_model)
    cnn.fit(fine_model, train_data, test_data)
def use_train(cnn):
    """Rebuild the transfer model and run a prediction with the saved weights."""
    cnn.predict(cnn.refine_model())
if __name__ == '__main__':
    cnn = Transfer()
    # To (re)train first, call train(cnn) here; it writes the checkpoint
    # weights that predict() loads.
    use_train(cnn)
選取一張汽車的圖片進行預測:
本作品採用《CC 協議》,轉載必須註明作者和本文連結