在《提高模型效能,你可以嘗試這幾招...》一文中,我們給出了幾種提高模型效能的方法,但這篇文章是在訓練資料集不變的前提下提出的優化方案。其實對於深度學習而言,資料量的多寡通常對模型效能的影響更大,所以擴充資料規模一般情況是一個非常有效的方法。
對於Google、Facebook來說,收集幾百萬張圖片,訓練超大規模的深度學習模型,自然不在話下。但是對於個人或者小型企業而言,收集現實世界的資料,特別是帶標籤的資料,將是一件非常費時費力的事。本文探討一種技術,在現有資料集的基礎上,進行資料增強(data augmentation),增加參與模型訓練的資料量,從而提升模型的效能。
什麼是資料增強
所謂資料增強,就是採用在原有資料上隨機增加抖動和擾動,從而生成新的訓練樣本,新樣本的標籤和原始資料相同。這個也很好理解,對於一張標籤為“狗”的圖片,做一定的模糊、裁剪、變形等處理,並不會改變這張圖片的類別。資料增強也不僅侷限於圖片分類應用,比如有如下圖所示的資料,資料滿足正態分佈:
我們在資料集的基礎上,增加一些擾動處理,資料分佈如下:
資料就在原來的基礎上增加了幾倍,但整體上仍然滿足正態分佈。有人可能會說,這樣的出來的模型不是沒有原來精確了嗎?考慮到現實世界的複雜性,我們採集到的資料很難完全滿足正態分佈,所以這樣增加資料擾動,不僅不會降低模型的精確度,然而增強了泛化能力。
對於圖片資料而言,能夠做的資料增強的方法有很多,通常的方法是:
- 平移
- 旋轉
- 縮放
- 裁剪
- 切變(shearing)
- 水平/垂直翻轉
- ...
上面幾種方法,可能切變(shearing)比較難以理解,看一張圖就明白了:
我們要親自編寫這些資料增強演算法嗎?通常不需要,比如keras就提供了批量處理圖片變形的方法。
keras中的資料增強方法
keras中提供了ImageDataGenerator類,其構造方法如下:
ImageDataGenerator(featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization = False,
samplewise_std_normalization = False,
zca_whitening = False,
rotation_range = 0.,
width_shift_range = 0.,
height_shift_range = 0.,
shear_range = 0.,
zoom_range = 0.,
channel_shift_range = 0.,
fill_mode = 'nearest',
cval = 0.0,
horizontal_flip = False,
vertical_flip = False,
rescale = None,
preprocessing_function = None,
data_format = K.image_data_format(),
)
複製程式碼
引數很多,常用的引數有:
- rotation_range: 控制隨機的度數範圍旋轉。
- width_shift_range和height_shift_range: 分別用於水平和垂直移位。
- zoom_range: 根據[1 - zoom_range,1 + zoom_range]範圍均勻將影象“放大”或“縮小”。
- horizontal_flip:控制是否水平翻轉。
完整的引數說明請參考keras文件。
下面一段程式碼將1張給定的圖片擴充為10張,當然你還可以擴充更多:
image = load_img(args["image"])
image = img_to_array(image)
image = np.expand_dims(image, axis=0)
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1, height_shift_range=0.1,
shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode="nearest")
aug.fit(image)
imageGen = aug.flow(image, batch_size=1, save_to_dir=args["output"], save_prefix=args["prefix"],
save_format="jpeg")
total = 0
for image in imageGen:
# increment out counter
total += 1
if total == 10:
break
複製程式碼
需要指出的是,上述程式碼的最後一個迭代是必須的,否在不會在output目錄下生成圖片,另外output目錄必須存在,否則會出現一下錯誤:
Traceback (most recent call last):
File "augmentation_demo.py", line 35, in <module>
for image in imageGen:
File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1526, in __next__
return self.next(*args, **kwargs)
File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1704, in next
return self._get_batches_of_transformed_samples(index_array)
File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1681, in _get_batches_of_transformed_samples
img.save(os.path.join(self.save_to_dir, fname))
File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/PIL/Image.py", line 1947, in save
fp = builtins.open(filename, "w+b")
FileNotFoundError: [Errno 2] No such file or directory: 'output/image_0_1091.jpeg'
複製程式碼
如下一張狗狗的圖片:
經過資料增強技術處理之後,可以得到如下10張形態稍微不同的狗狗的圖片,這相當於在原有資料集上增加了10倍的資料,其實我們還可以擴充得最多:
資料增強之後的比較
我們以MiniVGGNet模型為例,說明在其在17flowers資料集上進行訓練的效果。17flowers是一個非常小的資料集,包含17中品類的花卉圖案,每個品類包含80張圖片,這對於深度學習而言,資料量實在是太小了。一般而言,要讓深度學習模型有一定的精確度,每個類別的圖片至少需要1000~5000張。這樣的資料集可以很好的說明資料增強技術的必要性。
從網站上下載的17flowers資料,所有的圖片都放在一個目錄下,而我們通常訓練時的目錄結構為:
{類別名}/{圖片檔案}
複製程式碼
為此我寫了一個organize_flowers17.py指令碼。
在沒有使用資料增強的情況下,在訓練資料集和驗證資料集上精度、損失隨著訓練輪次的變化曲線圖:
可以看到,大約經過十幾輪的訓練,在訓練資料集上的準確率很快就達到了接近100%,然而在驗證資料集上的準確率卻無法再上升,只能達到60%左右。這個圖可以明顯的看出模型出現了非常嚴重的過擬合。
如果採用資料增強技術呢?曲線圖如下:
從圖中可以看到,雖然在訓練資料集上的準確率有所下降,但在驗證資料集上的準確率有比較明顯的提升,說明模型的泛化能力有所增強。
也許在我們看來,準確率從60%多增加到70%,只有10%的提升,並不是什麼了不得的成績。但要考慮到我們採用的資料集樣本數量實在是太少,能夠達到這樣的提升已經是非常難得,在實際專案中,有時為了提升1%的準確率,都會花費不少的功夫。
總結
資料增強技術在一定程度上能夠提高模型的泛化能力,減少過擬合,但在實際中,我們如果能夠收集到更多真實的資料,還是要儘量使用真實資料。另外,資料增強只需應用於訓練資料集,驗證集上則不需要,畢竟我們希望在驗證集上測試真實資料的準確。
以上例項均有完整的程式碼,點選閱讀原文,跳轉到我在github上建的示例程式碼。
另外,我在閱讀《Deep Learning for Computer Vision with Python》這本書,在微信公眾號後臺回覆“計算機視覺”關鍵字,可以免費下載這本書的電子版。