【火爐煉AI】深度學習006-移花接木-用Keras遷移學習提升效能

zybing發表於2021-09-09

【火爐煉AI】深度學習006-移花接木-用Keras遷移學習提升效能

(本文所使用的Python庫和版本號: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2, Keras 2.1.6, Tensorflow 1.9.0)

上一篇文章我們用自己定義的模型來解決了二分類問題,在20個回合的訓練之後得到了大約74%的準確率,一方面是我們的epoch太小的原因,另外一方面也是由於模型太簡單,結構簡單,故而不能做太複雜的事情,那麼怎麼提升預測的準確率了?一個有效的方法就是遷移學習。

遷移學習其本質就是移花接木:將其他大資料集(比如ImageNet等)上得到的網路結構及其weights應用到一個新的專案中來,比如此處的貓狗二分類問題。當然,ImageNet中肯定有貓和狗這兩個類別,可以說此處的小資料集是ImageNet的一個子集,但是,對於和ImageNet完全沒有任何關係的其他資料集,遷移學習也有一定的效果,當然,對於兩個資料集的相關性比較差的資料集,使用遷移學習可能效果不太好。

具體做法是:使用一個成熟的網路結構(比如此處用VGG16)和引數,把它的全連線層全部都去掉,只保留卷積層,這些卷積層可以看成是圖片的特徵提取器(得到的特徵被稱為bottleneck features),而全連線層是分類器,對這些圖片的特徵進行有效分類。對於新專案,我們要分類的類別數目並不是ImageNet的1000類,而是比如此處的2類。故而分類器對我們毫無用處,我們需要建立和訓練自己的分類器。如下為VGG16網路的結構:

圖片描述

image

其中的Conv block 1-5 都是卷積層和池化層,組成了圖片的特徵提取器,而後面的Flatten和Dense組成了分類器。

此處我們將Conv block 1-5的結構和引數都移接過來,在組裝上自己的分類器即可。

在訓練時,我們可以先我上一篇博文一樣,建立圖片資料流,將圖片資料流匯入到VGG16模型中提取特徵,然後將這些特徵送入到自定義的分類器中訓練,最佳化自定義分類器的引數,但是這種方式訓練速度很慢,此處我們用VGG16的卷積層統一提取所有圖片的特徵,將這些特徵儲存,然後直接載入特徵來訓練,載入數字比載入圖片要快的多,故而訓練也快得多。

我這篇博文主要參考了:,這篇博文也是參考的Building powerful image classification models using very little data,但我發現這兩篇博文有很多地方的程式碼跑不起來,主要原因可能是Keras或Tensorflow升級造成的,所以我做了一些必要的修改。


1. 準備資料集

首先使用預訓練好的模型VGG16來提取train set和test set圖片的特徵,然後將這些特徵儲存,這些特徵實際上就是numpy.ndarray,故而可以儲存為數字,然後載入這些數字來訓練。

# 此處的訓練集和測試集並不是原始圖片的train set和test set,而是用VGG16對圖片提取的特徵,這些特徵組成新的train set和test setfrom keras.preprocessing.image import ImageDataGeneratorfrom keras.models import Sequentialfrom keras.layers import Dropout, Flatten, Densefrom keras import applicationsdef save_bottlebeck_features():
    datagen = ImageDataGenerator(rescale=1. / 255) # 不需圖片增強

    # build the VGG16 network
    model = applications.VGG16(include_top=False, weights='imagenet') 
    # 使用imagenet的weights作為VGG16的初始weights,由於只是特徵提取,故而只取前面的卷積層而不需要DenseLayer,故而include_top=False

    generator = datagen.flow_from_directory( # 產生train set
        train_data_dir,
        target_size=(IMG_W, IMG_H),
        batch_size=batch_size,
        class_mode=None, 
        shuffle=False) # 必須為False,否則順序打亂之後,和後面的label對應不上。
    bottleneck_features_train = model.predict_generator(
        generator, train_samples_num // batch_size) # 如果是32,這個除法得到的是62,拋棄了小數,故而得到1984個sample
    np.save('E:PyProjectsDataSetFireAIDeepLearningFireAI006/bottleneck_features_train.npy', bottleneck_features_train)
    print('bottleneck features of train set is saved.')

    generator = datagen.flow_from_directory(
        val_data_dir,
        target_size=(IMG_W, IMG_H),
        batch_size=batch_size,
        class_mode=None,
        shuffle=False)
    bottleneck_features_validation = model.predict_generator(
        generator, val_samples_num // batch_size)
    np.save('E:PyProjectsDataSetFireAIDeepLearningFireAI006/bottleneck_features_val.npy',bottleneck_features_validation)
    print('bottleneck features of test set is saved.')

經過上面的程式碼,trainset圖片集的特徵被儲存到E:PyProjectsDataSetFireAIDeepLearningFireAI006/bottleneck_features_train.npy檔案中,而test set的特徵也被儲存到../bottleneck_features_val.npy中。


2. 構建模型並訓練

很顯然,此處我們並不要提取圖片的各種特徵,前面的VGG16已經幫我們做完了這件事,所以我們只需要對這些特徵進行分類即可,所以相當於我們只建立一個分類器模型就可以。

用keras建立一個簡單的二分類模型,如下:

def my_model():
    '''
    自定義一個模型,該模型僅僅相當於一個分類器,只包含有全連線層,對提取的特徵進行分類即可
    :return:
    '''
    # 模型的結構
    model = Sequential()
    model.add(Flatten(input_shape=train_data.shape[1:])) # 將所有data進行flatten
    model.add(Dense(256, activation='relu')) # 256個全連線單元
    model.add(Dropout(0.5)) # dropout正則
    model.add(Dense(1, activation='sigmoid')) # 此處定義的模型只有後面的全連線層,由於是本專案特殊的,故而需要自定義

    # 模型的配置
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy', metrics=['accuracy']) # model的optimizer等

    return model

模型雖然建立好了,但我們要訓練裡面的引數,使用剛剛VGG16提取的特徵來進行訓練:

# 只需要訓練分類器模型即可,不需要訓練特徵提取器train_data = np.load('E:PyProjectsDataSetFireAIDeepLearningFireAI006/bottleneck_features_train.npy') # 載入訓練圖片集的所有圖片的VGG16-notop特徵train_labels = np.array(
    [0] * int((train_samples_num / 2)) + [1] * int((train_samples_num / 2)))# label是1000個cat,1000個dog,由於此處VGG16特徵提取時是按照順序,故而[0]表示cat,1表示dogvalidation_data = np.load('E:PyProjectsDataSetFireAIDeepLearningFireAI006/bottleneck_features_val.npy')
validation_labels = np.array(
    [0] * int((val_samples_num / 2)) + [1] * int((val_samples_num / 2)))# 構建分類器模型clf_model=my_model()
history_ft = clf_model.fit(train_data, train_labels,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(validation_data, validation_labels))

-------------------------------------輸---------出--------------------------------

Train on 2000 samples, validate on 800 samples
Epoch 1/20
2000/2000 [==============================] - 6s 3ms/step - loss: 0.8426 - acc: 0.7455 - val_loss: 0.4280 - val_acc: 0.8063
Epoch 2/20
2000/2000 [==============================] - 5s 3ms/step - loss: 0.3928 - acc: 0.8365 - val_loss: 0.3078 - val_acc: 0.8675
Epoch 3/20
2000/2000 [==============================] - 5s 3ms/step - loss: 0.3144 - acc: 0.8720 - val_loss: 0.4106 - val_acc: 0.8588

.......

Epoch 18/20
2000/2000 [==============================] - 5s 3ms/step - loss: 0.0479 - acc: 0.9830 - val_loss: 0.5380 - val_acc: 0.9025
Epoch 19/20
2000/2000 [==============================] - 5s 3ms/step - loss: 0.0600 - acc: 0.9775 - val_loss: 0.5357 - val_acc: 0.8988
Epoch 20/20
2000/2000 [==============================] - 5s 3ms/step - loss: 0.0551 - acc: 0.9810 - val_loss: 0.6057 - val_acc: 0.8825

--------------------------------------------完-------------------------------------

將訓練過程中的loss和acc繪圖如下:

圖片描述

image

很顯然,在第5個epoch之後,train set和test set出現了很明顯的分離,表明後面出現了比較強烈的過擬合,但是在test set上的準確率仍然有90%左右。

可以看出,相對上一篇文章我們自己定義的三層卷積層+兩層全連線層的網路結構,用VGG16網路結構的方法得到的準確率更高一些,而且訓練所需要的時間也更少。

注意一點:此處我們並沒有訓練VGG16中的任何引數,而僅僅訓練自己定義的分類器模型中的引數。

########################小**********結###############################

1,遷移學習就是使用已經存在的模型及其引數,使用該模型來提取圖片的特徵,然後構建自己的分類器,對這些特徵進行分類即可。

2,此處我們並沒有訓練已存在模型的結構和引數,僅僅是訓練自定義的分類器,如果要訓練已存在模型的引數,那就是微調(Fine-tune)的範疇了

#################################################################


注:本部分程式碼已經全部上傳到()上,歡迎下載。



作者:煉丹老頑童
連結:


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/855/viewspace-2817403/,如需轉載,請註明出處,否則將追究法律責任。

相關文章