深度學習 | 訓練網路trick——mixup

Cindy's發表於2020-08-09

1.mixup原理介紹

mixup 論文地址
mixup是一種非常規的資料增強方法,一個和資料無關的簡單資料增強原則,其以線性插值的方式來構建新的訓練樣本和標籤。最終對標籤的處理如下公式所示,這很簡單但對於增強策略來說又很不一般。


,兩個資料對是原始資料集中的訓練樣本對(訓練樣本和其對應的標籤)。其中是一個服從B分佈的引數,
。Beta分佈的概率密度函式如下圖所示,其中

因此,α 是一個超引數,隨著α的增大,網路的訓練誤差就會增加,而其泛化能力會隨之增強。而當 α→∞ 時,模型就會退化成最原始的訓練策略。

2.mixup的程式碼實現

如下程式碼所示,實現mixup資料增強很簡單,其實我個人認為這就是一種抑制過擬合的策略,增加了一些擾動,從而提升了模型的泛化能力。

def get_batch(x, y, step, batch_size, alpha=0.2):
    """
    get batch data
    :param x: training data
    :param y: one-hot label
    :param step: step
    :param batch_size: batch size
    :param alpha: hyper-parameter α, default as 0.2
    :return:
    """
    candidates_data, candidates_label = x, y
    offset = (step * batch_size) % (candidates_data.shape[0] - batch_size)

    # get batch data
    train_features_batch = candidates_data[offset:(offset + batch_size)]
    train_labels_batch = candidates_label[offset:(offset + batch_size)]

    # 最原始的訓練方式
    if alpha == 0:
        return train_features_batch, train_labels_batch
    # mixup增強後的訓練方式
    if alpha > 0:
        weight = np.random.beta(alpha, alpha, batch_size)
        x_weight = weight.reshape(batch_size, 1, 1, 1)
        y_weight = weight.reshape(batch_size, 1)
        index = np.random.permutation(batch_size)
        x1, x2 = train_features_batch, train_features_batch[index]
        x = x1 * x_weight + x2 * (1 - x_weight)
        y1, y2 = train_labels_batch, train_labels_batch[index]
        y = y1 * y_weight + y2 * (1 - y_weight)
        return x, y

3.mixup增強效果展示

import matplotlib.pyplot as plt
import matplotlib.image as Image
import numpy as np

im1 = Image.imread(r"C:\Users\Daisy\Desktop\1\xyjy.png")
im2 = Image.imread(r"C:\Users\Daisy\Desktop\1\xyjy2.png")
for i in range(1,10):
    lam= i*0.1
    im_mixup = (im1*lam+im2*(1-lam))
    plt.subplot(3,3,i)
    plt.imshow(im_mixup)
plt.show()

————————————————————
後來又發現一篇好文:https://www.zhihu.com/question/308572298?sort=created

相關文章