【半監督學習】MixMatch、UDA、ReMixMatch、FixMatch

wuliytTaotao發表於2020-04-18

半監督學習(Semi-Supervised Learning,SSL)的 SOTA 一次次被 Google 重新整理,從 MixMatch 開始,到同期的 UDA、ReMixMatch,再到 2020 年的 FixMatch。

這四篇深度半監督學習方面的工作,都是從 consistency regularization 和 entropy minimization 兩方面入手:

  • consistency regularization:一致性約束,給輸入圖片或者中間層注入 noise,模型的輸出應該儘可能保持不變或者近似。
  • entropy minimization:最小化熵,模型在 unlabeled data 上的熵應該儘可能最小化。Pseudo label 也隱含地用到了 entropy minimization。

Consistency Regularization

對於每一個 unlabeled instance,consistency regularization 要求兩次隨機注入 noise 的輸出近似。背後的思想是,如果一個模型是魯棒的,那麼即使輸入有擾動,輸出也應該近似。

對於 consistency regularization 來說,如何注入 noise 以及如何計算近似,就是每個方法的不同之處。注入 noise 可以通過模型本身的隨機性(如 dropout)或者直接加入噪聲(如 Gaussian noise),也可以通過 data augmentation;計算一致性的方法,可以使用 L2,也可以使用 KL divergency、cross entropy。

Entropy Minimization

MixMatch、UDA 和 ReMixMatch 通過 temperature sharpening 來間接利用 entropy minimization,而 FixMatch 通過 Pseudo label 來間接利用 entropy minimization。可以認為,只要通過得到 unlabeled data 的人工標籤然後按照監督學習的方法(如 cross entropy loss)來訓練的,都間接用到了 entropy minimization。因為人工標籤都是 one-hot 或者近似 one-hot 的,如果 unlabeled data 的 prediction 近似人工標籤,那麼此時無標籤資料的熵肯定也是較小的。

為什麼這裡叫做人工標籤而不是偽標籤?一般而言,在半監督中,偽標籤(pseudo label)特指 hard label,即 one-hot 型別的或者通過 argmax 得到的。[4] 而 MixMatch、UDA、ReMixMatch 得到的人工標籤並不是 hard label。

Entropy minimization 可以在計算 unlabeled data 部分的 loss 和 consistency regularization 一起實現。

temperature sharpening 和 pseudo label 都得到了 unlabeled data 的人工標籤,當前者 temperature=0 時,兩者相等。pseudo label 要比 temperature sharpening 要簡單,因為少了一個 temperature 超引數。

如果不考慮 entropy minimization,那麼 temperature sharpening 和 pseudo label 其實都是不需要的,只需要兩次隨機注入 noise 的 unlabeled instance 輸出近似,就可以保證 consistency regularization。

或者說,得到 unlabeled data 的人工標籤,可以使得 entropy minimization 和 consistency regularization 通過一項 loss 來完成。

結合 Consistency Regularization 和 Entropy Minimization

一般來說,半監督學習中的 unlabeled data 會使用全部訓練資料集,即有標籤的樣本也會作為無標籤樣本來使用。

半監督學習中,labeled data 的標籤都是給定的,而 unlabeled data 的標籤都是不知道的。那麼如何獲得 unlabeled data 的人工標籤(artificial label),MixMatch、UDA、ReMixMatch 和 FixMatch 的做法或多或少都不相同:

  • MixMatch:平均 K 次 weak augmentation(如 shifting 和 flipping)的 predictions ,然後經過 temperature sharpening;
  • UDA:一次 weak augmentation 的 prediction,然後經過 temperature sharpening;
  • ReMixMatch:一次 weak augmentation 的 prediction,然後經過 distribution alignment,最後經過 temperature sharpening;
  • FixMatch:一次 weak augmentation 的 prediction,然後 argmax 得到 hard label(pseudo label)。
【半監督學習】MixMatch、UDA、ReMixMatch、FixMatch
Fig.1 MixMatch 人工標籤 (soft label)

得到了人工標籤,我們就可以按照監督學習的方式來訓練,這種思考方式就利用了 entropy minimization。而從 unlabeled data 的 consistency regularization 角度思考,我們需要注入不同的 noise,使得 unlabeled data 的 predictions 和它們的人工標籤一致。

MixMatch、UDA、ReMixMatch 和 FixMatch 都利用 data augmentation 改變輸入樣本來注入 noise,不同的是 data augmentation 的具體方式和強度:

  • MixMatch:一次 weak augmentation 得到 prediction,這就和正常的監督訓練一樣,只是 unlabeled loss 用的是 L2 而已;
  • UDA:一次 strong augmentation(RandAugment) 得到 prediction;
  • ReMixMatch:多次 strong augmentation(CTAugment)得到 predictions,然後同時參與 unlabeled loss 的計算,即一個 unlabeled instance 一個 step 多次增強後計算多次 loss;
  • FixMatch:一次 strong augmentation(RandAugment 或 CTAugment)得到 prediction。
【半監督學習】MixMatch、UDA、ReMixMatch、FixMatch
Fig.2 FixMatch 流程圖

從 UDA 和 ReMixMatch 開始,strong augmentation 引入了半監督訓練。UDA 使用了作者之前提出的 RandAugment 的 strong augmentation 方式,而 ReMixMatch 提出了一種 CTAugment。FixMatch 就把 UDA 和 ReMixMatch 中用到的 strong augmentation 都拿來用了一遍。

【半監督學習】MixMatch、UDA、ReMixMatch、FixMatch
Fig.3 weak augmentaion、strong augmentation 及 temperature sharpening 使用情況

對於 unlabeled data 部分的 loss:

  • MixMatch:L2 loss;
  • UDA:KL divergency;
  • ReMixMatch:cross entropy(包括自監督的 rotation loss 和沒有使用 mixup 的 pre-mixup unlabeled loss);
  • FixMatch:帶閾值的 cross entropy。

FixMatch: Simplifying SSL with Consistency and Confidence

FixMatch 簡化了 MixMatch、UDA 和 ReMixMatch,然後獲得了更好的效果:

  • 首先,temperature sharpening 換成 pseudo label,這是一個簡化;
  • 其次,FixMatch 通過設定一個閾值,在計算 unlabeled loss 時,對 prediction 的 confidence 超過閾值的 unlabeled instance 才算入 unlabeled loss,這樣使得 unlabeled loss 的權重可以固定,這是第二個簡化。
【半監督學習】MixMatch、UDA、ReMixMatch、FixMatch
Fig.4 Error rates for CIFAR-10, CIFAR-100 and SVHN on 5 different folds.

References

[1] Berthelot, D., Carlini, N., Goodfellow, I., Papernot, N., Oliver, A., Raffel, C. (2019). MixMatch: A Holistic Approach to Semi-Supervised Learning arXiv https://arxiv.org/abs/1905.02249
[2] Berthelot, D., Carlini, N., Cubuk, E., Kurakin, A., Sohn, K., Zhang, H., Raffel, C. (2019). ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring arXiv https://arxiv.org/abs/1911.09785
[3] Xie, Q., Dai, Z., Hovy, E., Luong, M., Le, Q. (2019). Unsupervised Data Augmentation for Consistency Training arXiv https://arxiv.org/abs/1904.12848
[4] Sohn, K., Berthelot, D., Li, C., Zhang, Z., Carlini, N., Cubuk, E., Kurakin, A., Zhang, H., Raffel, C. (2020). FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence arXiv https://arxiv.org/abs/2001.07685

相關文章