TPAMI 2024 | ProCo: 無限contrastive pairs的長尾對比學習

机器之心發表於2024-07-25
圖片
AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

本論文第一作者杜超群是清華大學自動化系 2020 級直博生。導師為黃高副教授。此前於清華大學物理系獲理學學士學位。研究興趣為不同資料分佈上的模型泛化和魯棒性研究,如長尾學習,半監督學習,遷移學習等。在 TPAMI、ICML 等國際一流期刊、會議上發表多篇論文。

個人主頁:https://andy-du20.github.io

本文介紹清華大學的一篇關於長尾視覺識別的論文: Probabilistic Contrastive Learning for Long-Tailed Visual Recognition. 該工作已被 TPAMI 2024 錄用,程式碼已開源。

該研究主要關注對比學習在長尾視覺識別任務中的應用,提出了一種新的長尾對比學習方法 ProCo,透過對 contrastive loss 的改進實現了無限數量 contrastive pairs 的對比學習,有效解決了監督對比學習 (supervised contrastive learning)[1] 對 batch (memory bank) size 大小的固有依賴問題。除了長尾視覺分類任務,該方法還在長尾半監督學習、長尾目標檢測和平衡資料集上進行了實驗,取得了顯著的效能提升。

圖片

  • 論文連結: https://arxiv.org/pdf/2403.06726

  • 專案連結: https://github.com/LeapLabTHU/ProCo

研究動機

對比學習在自監督學習中的成功表明了其在學習視覺特徵表示方面的有效性。影響對比學習效能的核心因素是 contrastive pairs 的數量,這使得模型能夠從更多的負樣本中學習,體現在兩個最具代表性的方法 SimCLR [2] 和 MoCo [3] 中分別為 batch size 和 memory bank 的大小。然而在長尾視覺識別任務中,由於類別不均衡,增加 contrastive pairs 的數量所帶來的增益會產生嚴重的邊際遞減效應,這是由於大部分的 contrastive pairs 都是由頭部類別的樣本構成的,難以覆蓋到尾部類別

例如,在長尾 Imagenet 資料集中,若 batch size (memory bank) 大小設為常見的 4096 和 8192,那麼每個 batch (memory bank) 中平均分別有 212 個和 89 個類別的樣本數量不足一個。

因此,ProCo 方法的核心 idea 是:在長尾資料集上,透過對每類資料的分佈進行建模、引數估計並從中取樣以構建 contrastive pairs,保證能夠覆蓋到所有的類別。進一步,當取樣數量趨於無窮時,可以從理論上嚴格推匯出 contrastive loss 期望的解析解,從而直接以此作為最佳化目標,避免了對 contrastive pairs 的低效取樣,實現無限數量 contrastive pairs 的對比學習。

然而,實現以上想法主要有以下幾個難點:

  • 如何對每類資料的分佈進行建模。

  • 如何高效地估計分佈的引數,尤其是對於樣本數量較少的尾部類別。

  • 如何保證 contrastive loss 的期望的解析解存在且可計算。

事實上,以上問題可以透過一個統一的機率模型來解決,即選擇一個簡單有效的機率分佈對特徵分佈進行建模,從而可以利用最大似然估計高效地估計分佈的引數,並計算期望 contrastive loss 的解析解。

由於對比學習的特徵是分佈在單位超球面上的,因此一個可行的方案是選擇球面上的 von Mises-Fisher (vMF) 分佈作為特徵的分佈(該分佈類似於球面上的正態分佈)。vMF 分佈引數的最大似然估計有近似解析解且僅依賴於特徵的一階矩統計量,因此可以高效地估計分佈的引數,並且嚴格推匯出 contrastive loss 的期望,從而實現無限數量 contrastive pairs 的對比學習。

圖片

圖 1 ProCo 演算法根據不同 batch 的特徵來估計樣本的分佈,透過取樣無限數量的樣本,可以得到期望 contrastive loss 的解析解,有效地消除了監督對比學習對 batch size (memory bank) 大小的固有依賴。

方法詳述

接下來將從分佈假設、引數估計、最佳化目標和理論分析四個方面詳細介紹 ProCo 方法。

分佈假設

如前所述,對比學習中的特徵被約束在單位超球面上。因此,可以假設這些特徵服從的分佈為 von Mises-Fisher (vMF) 分佈,其機率密度函式為:圖片

其中 z 是 p 維特徵的單位向量,I 是第一類修正貝塞爾函式,

圖片

μ 是分佈的均值方向,κ 是集中引數,控制分佈的集中程度,當 κ 越大時,樣本聚集在均值附近的程度越高;當 κ =0 時,vMF 分佈退化為球面上的均勻分佈。

引數估計

基於上述分佈假設,資料特徵的總體分佈為混合 vMF 分佈,其中每個類別對應一個 vMF 分佈。

圖片

其中引數 圖片表示每個類別的先驗機率,對應於訓練集中類別 y 的頻率。特徵分佈的均值向量圖片和集中引數圖片 透過最大似然估計來估計。

假設從類別 y 的 vMF 分佈中取樣 N 個獨立的單位向量,則均值方向和集中引數的最大似然估計 (近似)[4] 滿足以下方程:

圖片

其中圖片是樣本均值,圖片是樣本均值的模長。此外,為了利用歷史上的樣本,ProCo 採用了線上估計的方法,能夠有效地對尾部類別的引數進行估計。

最佳化目標

基於估計的引數,一種直接的方法是從混合 vMF 分佈中取樣以構建 contrastive pairs . 然而在每次訓練迭代中從 vMF 分佈中取樣大量的樣本是低效的。因此,該研究在理論上將樣本數量擴充套件到無窮大,並嚴格推匯出期望對比損失函式的解析解直接作為最佳化目標。

圖片

透過在訓練過程中引入一個額外的特徵分支 (基於該最佳化目標進行 representation learning),該分支可以與分類分支一起訓練,並且由於在推理過程中只需要分類分支,因此不會增加額外的計算成本。兩個分支 loss 的加權和作為最終的最佳化目標,

圖片

在實驗中均設定 α=1. 最終,ProCo 演算法的整體流程如下:

圖片

理論分析

為了進一步從理論上驗證 ProCo 方法的有效性,研究者們對其進行了泛化誤差界和超額風險界的分析。為了簡化分析,這裡假設只有兩個類別,即 y∈ {-1,+1}.

圖片

分析表明,泛化誤差界主要由訓練樣本數量和資料分佈的方差控制,這一發現與相關工作的理論分析 [6][7] 一致,保證了 ProCo loss 沒有引入額外因素,也沒有增大泛化誤差界,從理論上保證了該方法的有效性。

此外,該方法依賴於關於特徵分佈和引數估計的某些假設。為了評估這些引數對模型效能的影響,研究者們還分析了 ProCo loss 的超額風險界,其衡量了使用估計引數的期望風險與貝葉斯最優風險之間的偏差,後者是在真實分佈引數下的期望風險。

圖片

這表明 ProCo loss 的超額風險主要受引數估計誤差的一階項控制。

實驗結果

作為核心 motivation 的驗證,研究者們首先與不同對比學習方法在不同 batch size 下的效能進行了比較。Baseline 包括同樣基於 SCL 在長尾識別任務上的改進方法 Balanced Contrastive Learning [5](BCL)。具體的實驗 setting 遵循 Supervised Contrastive Learning (SCL) 的兩階段訓練策略,即首先只用 contrastive loss 進行 representation learning 的訓練,然後在 freeze backbone 的情況下訓練一個 linear classifier 進行測試。

下圖展示了在 CIFAR100-LT (IF100) 資料集上的實驗結果,BCL 和 SupCon 的效能明顯受限於 batch size,但 ProCo 透過引入每個類別的特徵分佈,有效消除了 SupCon 對 batch size 的依賴,從而在不同的 batch size 下都取得了最佳效能。

圖片

此外,研究者們還在長尾識別任務,長尾半監督學習,長尾目標檢測和平衡資料集上進行了實驗。這裡主要展示了在大規模長尾資料集 Imagenet-LT 和 iNaturalist2018 上的實驗結果。首先在 90 epochs 的訓練 schedule 下,相比於同類改進對比學習的方法,ProCo 在兩個資料集和兩個 backbone 上都有至少 1% 的效能提升。

圖片

下面的結果進一步表明了 ProCo 也能夠從更長的訓練 schedule 中受益,在 400 epochs schedule 下,ProCo 在 iNaturalist2018 資料集上取得了 SOTA 的效能,並且還驗證了其能夠與其它非對比學習方法相結合,包括 distillation (NCL) 等方法。

圖片

  1. P. Khosla, et al. “Supervised contrastive learning,” in NeurIPS, 2020.

  2. Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International conference on machine learning. PMLR, 2020.

  3. He, Kaiming, et al. "Momentum contrast for unsupervised visual representation learning." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2020.

  4. S. Sra, “A short note on parameter approximation for von mises-fisher distributions: and a fast implementation of is (x),” Computational Statistics, 2012.

  5. J. Zhu, et al. “Balanced contrastive learning for long-tailed visual recognition,” in CVPR, 2022.

  6. W. Jitkrittum, et al. “ELM: Embedding and logit margins for long-tail learning,” arXiv preprint, 2022.

  7. A. K. Menon, et al. “Long-tail learning via logit adjustment,” in ICLR, 2021.

相關文章