使用資料增強技術提升模型泛化能力

雲水木石發表於2019-01-17

在《提高模型效能,你可以嘗試這幾招...》一文中,我們給出了幾種提高模型效能的方法,但這篇文章是在訓練資料集不變的前提下提出的優化方案。其實對於深度學習而言,資料量的多寡通常對模型效能的影響更大,所以擴充資料規模一般情況是一個非常有效的方法。

對於Google、Facebook來說,收集幾百萬張圖片,訓練超大規模的深度學習模型,自然不在話下。但是對於個人或者小型企業而言,收集現實世界的資料,特別是帶標籤的資料,將是一件非常費時費力的事。本文探討一種技術,在現有資料集的基礎上,進行資料增強(data augmentation),增加參與模型訓練的資料量,從而提升模型的效能。

什麼是資料增強

所謂資料增強,就是採用在原有資料上隨機增加抖動和擾動,從而生成新的訓練樣本,新樣本的標籤和原始資料相同。這個也很好理解,對於一張標籤為“狗”的圖片,做一定的模糊、裁剪、變形等處理,並不會改變這張圖片的類別。資料增強也不僅侷限於圖片分類應用,比如有如下圖所示的資料,資料滿足正態分佈:

使用資料增強技術提升模型泛化能力

我們在資料集的基礎上,增加一些擾動處理,資料分佈如下:

使用資料增強技術提升模型泛化能力

資料就在原來的基礎上增加了幾倍,但整體上仍然滿足正態分佈。有人可能會說,這樣的出來的模型不是沒有原來精確了嗎?考慮到現實世界的複雜性,我們採集到的資料很難完全滿足正態分佈,所以這樣增加資料擾動,不僅不會降低模型的精確度,然而增強了泛化能力。

對於圖片資料而言,能夠做的資料增強的方法有很多,通常的方法是:

  • 平移
  • 旋轉
  • 縮放
  • 裁剪
  • 切變(shearing)
  • 水平/垂直翻轉
  • ...

上面幾種方法,可能切變(shearing)比較難以理解,看一張圖就明白了:

使用資料增強技術提升模型泛化能力

我們要親自編寫這些資料增強演算法嗎?通常不需要,比如keras就提供了批量處理圖片變形的方法。

keras中的資料增強方法

keras中提供了ImageDataGenerator類,其構造方法如下:

ImageDataGenerator(featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization = False,
    samplewise_std_normalization = False,
    zca_whitening = False,
    rotation_range = 0.,
    width_shift_range = 0.,
    height_shift_range = 0.,
    shear_range = 0.,
    zoom_range = 0.,
    channel_shift_range = 0.,
    fill_mode = 'nearest',
    cval = 0.0,
    horizontal_flip = False,
    vertical_flip = False,
    rescale = None,
    preprocessing_function = None,
    data_format = K.image_data_format(),
)
複製程式碼

引數很多,常用的引數有:

  • rotation_range: 控制隨機的度數範圍旋轉。
  • width_shift_range和height_shift_range: 分別用於水平和垂直移位。
  • zoom_range: 根據[1 - zoom_range,1 + zoom_range]範圍均勻將影象“放大”或“縮小”。
  • horizontal_flip:控制是否水平翻轉。

完整的引數說明請參考keras文件。

下面一段程式碼將1張給定的圖片擴充為10張,當然你還可以擴充更多:

image = load_img(args["image"])
image = img_to_array(image)
image = np.expand_dims(image, axis=0)

aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1, height_shift_range=0.1,
                         shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode="nearest")

aug.fit(image)

imageGen = aug.flow(image, batch_size=1, save_to_dir=args["output"], save_prefix=args["prefix"],
                    save_format="jpeg")

total = 0
for image in imageGen:
  # increment out counter
  total += 1

  if total == 10:
    break
複製程式碼

需要指出的是,上述程式碼的最後一個迭代是必須的,否在不會在output目錄下生成圖片,另外output目錄必須存在,否則會出現一下錯誤:

Traceback (most recent call last):
  File "augmentation_demo.py", line 35, in <module>
    for image in imageGen:
  File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1526, in __next__
    return self.next(*args, **kwargs)
  File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1704, in next
    return self._get_batches_of_transformed_samples(index_array)
  File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1681, in _get_batches_of_transformed_samples
    img.save(os.path.join(self.save_to_dir, fname))
  File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/PIL/Image.py", line 1947, in save
    fp = builtins.open(filename, "w+b")
FileNotFoundError: [Errno 2] No such file or directory: 'output/image_0_1091.jpeg'
複製程式碼

如下一張狗狗的圖片:

使用資料增強技術提升模型泛化能力

經過資料增強技術處理之後,可以得到如下10張形態稍微不同的狗狗的圖片,這相當於在原有資料集上增加了10倍的資料,其實我們還可以擴充得最多:

使用資料增強技術提升模型泛化能力

資料增強之後的比較

我們以MiniVGGNet模型為例,說明在其在17flowers資料集上進行訓練的效果。17flowers是一個非常小的資料集,包含17中品類的花卉圖案,每個品類包含80張圖片,這對於深度學習而言,資料量實在是太小了。一般而言,要讓深度學習模型有一定的精確度,每個類別的圖片至少需要1000~5000張。這樣的資料集可以很好的說明資料增強技術的必要性。

從網站上下載的17flowers資料,所有的圖片都放在一個目錄下,而我們通常訓練時的目錄結構為:

{類別名}/{圖片檔案}
複製程式碼

為此我寫了一個organize_flowers17.py指令碼。

在沒有使用資料增強的情況下,在訓練資料集和驗證資料集上精度、損失隨著訓練輪次的變化曲線圖:

使用資料增強技術提升模型泛化能力

可以看到,大約經過十幾輪的訓練,在訓練資料集上的準確率很快就達到了接近100%,然而在驗證資料集上的準確率卻無法再上升,只能達到60%左右。這個圖可以明顯的看出模型出現了非常嚴重的過擬合。

如果採用資料增強技術呢?曲線圖如下:

使用資料增強技術提升模型泛化能力

從圖中可以看到,雖然在訓練資料集上的準確率有所下降,但在驗證資料集上的準確率有比較明顯的提升,說明模型的泛化能力有所增強。

也許在我們看來,準確率從60%多增加到70%,只有10%的提升,並不是什麼了不得的成績。但要考慮到我們採用的資料集樣本數量實在是太少,能夠達到這樣的提升已經是非常難得,在實際專案中,有時為了提升1%的準確率,都會花費不少的功夫。

總結

資料增強技術在一定程度上能夠提高模型的泛化能力,減少過擬合,但在實際中,我們如果能夠收集到更多真實的資料,還是要儘量使用真實資料。另外,資料增強只需應用於訓練資料集,驗證集上則不需要,畢竟我們希望在驗證集上測試真實資料的準確。

以上例項均有完整的程式碼,點選閱讀原文,跳轉到我在github上建的示例程式碼。

另外,我在閱讀《Deep Learning for Computer Vision with Python》這本書,在微信公眾號後臺回覆“計算機視覺”關鍵字,可以免費下載這本書的電子版。

參考閱讀

提高模型效能,你可以嘗試這幾招...

計算機視覺與深度學習,看這本書就夠了

使用資料增強技術提升模型泛化能力

相關文章