再談遷移學習:微調網路

雲水木石發表於2019-01-31

在《站在巨人的肩膀上:遷移學習》一文中,我們談到了一種遷移學習方法:將預訓練的卷積神經網路作為特徵提取器,然後使用一個標準的機器學習分類模型(比如Logistic迴歸),以所提取的特徵進行訓練,得到分類器,這個過程相當於用預訓練的網路取代上一代的手工特徵提取方法。這種遷移學習方法,在較小的資料集(比如17flowers)上也能取得不錯的準確率。

在那篇文章中,我還提到了另外一種遷移學習:微調網路,這篇文章就來談談微調網路。

所謂微調網路,相當於給預訓練的模型做一個“換頭術”,即“切掉”最後的全連線層(可以想象為卷積神經網路的“頭部”),然後接上一個新的引數隨機初始化的全連線層,接下來我們在這個動過“手術”的卷積神經網路上用我們比較小的資料集進行訓練。

在新模型上進行訓練,有幾點需要注意:

  1. 開始訓練時,“頭部”以下的層(也就是沒有被替換的網路層)的引數需要固定(frozen),也就是進行前向計算,但反向傳遞時不更新引數,訓練過程只更新新替換上的全連線層的引數。
  2. 使用一個非常小的學習率進行訓練,比如0.001
  3. 最後,作為可選,在全連線層的引數學習得差不多的時候,我們可以將“頭部”以下的層解凍(unfrozen),再整體訓練整個網路。

特徵提取和微調網路

對照一下上一篇文章中的特徵提取,我們以直觀的圖形來展現它們之間的不同:

如果我們在VGG16預訓練模型上進行特徵提取,其結構如下圖所示:

再談遷移學習:微調網路

對比原模型結構,從最後一個卷積池化層直接輸出,即特徵提取。而微調網路則如下圖所示:

再談遷移學習:微調網路

通常情況下,新替換的全連線層引數要比原來的全連線層引數要少,因為我們是在比較小的資料集上進行訓練。訓練過程通常分兩個階段,第一階段固定卷積層的引數,第二階段則全放開:

再談遷移學習:微調網路

相位元徵提取這種遷移學習方法,網路微調通常能得到更高的準確度。但記住,天下沒有免費的午餐這個原則,微調網路需要做更多的工作:

首先訓練時間很長,相位元徵提取只做前向運算,然後訓練一個簡單的Logisitic迴歸演算法,速度很快,微調網路因為是在很深的網路模型上訓練,特別是第二階段要進行全面的反向傳遞,耗時更長。在我的GTX 960顯示卡上用17flowers資料集,訓練了幾個小時,還沒有訓練完,結果我睡著了:(,也不知道最終花了多長時間。

其次需要調整的超引數比較多,比如選擇多大的學習率,全連線網路設計多少個節點比較合適,這個依賴經驗,但在某些特別的資料集上,可能需要多嘗試幾次才能得到比較好的結果。

網路層及索引

在“動手術”之前,我們需要了解模型的結構,最起碼我們需要知道層的名稱及索引,沒有這些資訊,就如同盲人拿起手術刀。在keras中,要了解層的資訊非常簡單:

print("[INFO] loading network ...")
model = VGG16(weights="imagenet", include_top=args["include_top"] > 0)

print("[INFO] showing layers ...")
for (i, layer) in enumerate(model.layers):
  print("[INFO] {}\t{}".format(i, layer.__class__.__name__))
複製程式碼

VGG16的模型結構如下:

再談遷移學習:微調網路

可以看到第20 ~ 22層為全連線層,這也是微調網路要替換的層。

網路“換頭術”

首先,我們定義一組全連線層:INPUT => FC => RELU => DO => FC => SOFTMAX。相比VGG16中的全連線層,這個更加簡單,引數更少。然後將基本模型的輸出作為模型的輸入,完成拼接。這個拼接在keras中也相當簡單。

class FCHeadNet:
  @staticmethod
  def build(base_model, classes, D):
    # initialize the head model that will be placed on top of
    # the base, then add a FC layer
    head_model = base_model.output
    head_model = Flatten(name="flatten")(head_model)
    head_model = Dense(D, activation="relu")(head_model)
    head_model = Dropout(0.5)(head_model)

    # add a softmax layer
    head_model = Dense(classes, activation="softmax")(head_model)

    return head_model
複製程式碼

因為在VGG16的建構函式中有一個include_top引數,可以決定是否包含頭部的全連線層,所以這個“換頭”步驟相當簡單:

base_model = VGG16(weights="imagenet", include_top=False, input_tensor=Input(shape=(224, 224, 3)))

# initialize the new head of the network, a set of FC layers followed by a softmax classifier
head_model = FCHeadNet.build(base_model, len(class_names), 256)

# place the head FC model on top of the base model -- this will become the actual model we will train
model = Model(inputs=base_model.input, outputs=head_model)
複製程式碼

這時得到的model就是經過“換頭術”的網路模型。

訓練

微調網路的訓練和之前談到的模型訓練過程差不多,只是多了一個freeze層的動作,實際上是進行兩個訓練過程。如何固定層的引數呢?一句話就可以搞定:

for layer in base_model.layers:
  layer.trainable = False
複製程式碼

“解凍”類似,只是layer.trainable值設為True。

為了更快的收斂,儘快的學習到全連線層的引數,在第一階段建議採用RMSprop優化器。但學習率需要選擇一個比較小的值,例如0.001。

在經過一個相當長時間的訓練之後,新模型在17flowers資料集上的結果如下:

[INFO] evaluating after fine-tuning ...
              precision    recall  f1-score   support

    bluebell       0.95      0.95      0.95        20
   buttercup       1.00      0.90      0.95        20
   coltsfoot       0.95      0.91      0.93        22
     cowslip       0.93      0.87      0.90        15
      crocus       1.00      1.00      1.00        23
    daffodil       0.92      1.00      0.96        23
       daisy       1.00      0.94      0.97        16
   dandelion       0.94      0.94      0.94        16
  fritillary       1.00      0.95      0.98        21
        iris       0.96      0.96      0.96        27
  lilyvalley       0.94      0.89      0.91        18
       pansy       0.90      0.95      0.92        19
    snowdrop       0.86      0.95      0.90        20
   sunflower       0.95      1.00      0.98        20
   tigerlily       0.96      0.96      0.96        23
       tulip       0.70      0.78      0.74        18
  windflower       1.00      0.95      0.97        19

   micro avg       0.94      0.94      0.94       340
   macro avg       0.94      0.93      0.94       340
weighted avg       0.94      0.94      0.94       340
複製程式碼

相位元徵提取這種遷移學習,準確率有了相當可觀的提升。

小結

網路微調是一項非常強大的技術,我們無需從頭開始訓練整個網路。相反,我們可以利用預先存在的網路架構,例如在ImageNet資料集上訓練的最先進模型,該模型由豐富的過濾器組成。使用這些過濾器,我們可以“快速啟動”我們的學習,使我們能夠進行網路手術,最終得到更高精度的遷移學習模型,而不是從頭開始訓練,而且工作量少。

以上例項均有完整的程式碼,點選閱讀原文,跳轉到我在github上建的示例程式碼。

另外,我在閱讀《Deep Learning for Computer Vision with Python》這本書,在微信公眾號後臺回覆“計算機視覺”關鍵字,可以免費下載這本書的電子版。

往期回顧

  1. 站在巨人的肩膀上:遷移學習
  2. 聊一聊rank-1和rank-5準確度
  3. 使用資料增強技術提升模型泛化能力

image

相關文章