如何解決迴歸任務資料不均衡的問題?

華為雲開發者社群發表於2021-06-11

摘要:現有的處理不平衡資料/長尾分佈的方法絕大多數都是針對分類問題,而回歸問題中出現的資料不均衡問題確極少被研究。

本文分享自華為雲社群《如何解決迴歸任務資料不均衡的問題?》,原文作者:PG13。

現有的處理不平衡資料/長尾分佈的方法絕大多數都是針對分類問題,而回歸問題中出現的資料不均衡問題確極少被研究。但是,現實很多的工業預測場景都是需要解決迴歸的問題,也就是涉及到連續的,甚至是無限多的目標值,如何解決迴歸問題中出現的資料不均衡問題呢?ICML2021一篇被接收為Long oral presentation的論文:Delving into Deep Imbalanced Regression,推廣了傳統不均衡分類問題的正規化,將資料不平衡問題從離散值域推廣到了連續值域,並提出了兩種解決深度不均衡迴歸問題的方法。

主要的貢獻是三個方面:1)提出了一個深度不均衡迴歸(Deep Imbalanced Regression, DIR)任務,定義為從具有連續目標的不平衡資料中學習,並能泛化到整個目標範圍;2)提出了兩種解決DIR的新方法,標籤分佈平滑(label distribution smoothing, LDS)和特徵分佈平滑(feature distribution smoothing, FDS),來解決具有連續目標的不平衡資料的學習問題;3)建立了5個新的DIR資料集,包括了CV、NLP、healthcare上的不平衡迴歸任務,致力於幫助未來在不平衡資料上的研究。

資料不平衡問題背景

現實世界的資料通常不會每個類別都具有理想的均勻分佈,而是呈現出長尾的偏斜分佈,其中某些目標值的觀測值明顯較少,這對於深度學習模型有較大的挑戰。傳統的解決辦法可以分為基於資料基於模型兩種:基於資料的解決方案無非對少數群體進行過取樣和對多數群體進行下采樣,比如SMOTE演算法;基於模型的解決方案包括對損失函式的重加權(re-weighting)或利用相關的學習技巧,如遷移學習、元學習、兩階段訓練等。

但是現有的資料不平衡解決方案,主要是針對具有categorical index的目標值,也就是離散的類別標籤資料。其目標值屬於不同的類別,並且具有嚴格的硬邊界,不同類別之間沒有重疊。現實世界很多的預測場景可能涉及到連續目標值的標籤資料。比如,根據人臉視覺圖片預測年齡,年齡便是一個連續的目標值,並且在目標範圍內可能會高度失衡。在工業領域中,也會發生類似的問題,比如在水泥領域,水泥熟料的質量,一般都是連續的目標值;在配煤領域,焦炭的熱強指標也是連續的目標值。這些應用中需要預測的目標變數往往存在許多稀有和極端值。在連續域的不平衡問題線上性模型和深度模型中都是存在的,在深度模型中甚至更為嚴重,這是因為深度學習模型的預測往往都是over-confident的,會導致這種不平衡問題被嚴重的放大。

因此,這篇文章定義了深度不平衡迴歸問題(DIR),即從具有連續目標值的不平衡資料中學習,同時需要處理某些目標區域的潛在確實資料,並使最終模型能夠泛化到整個支援所有目標值的範圍上。

https://i.iter01.com/images/ba66ec2cc35cb33322af1ada44b3ecc2b63c747560cdca5231a9b65e867f0ca4.png

不平衡迴歸問題的挑戰

解決DIR問題的三個挑戰如下:

  1.  對於連續的目標值(標籤),不同目標值之間的硬邊界不再存在,無法直接採用不平衡分類的處理方法。
  2.  連續標籤本質上說明在不同的目標值之間的距離是有意義的。這些目標值直接告訴了哪些資料之間相隔更近,指導我們該如何理解這個連續區間上的資料不均衡的程度。
  3.  對於DIR,某些目標值可能根本沒有資料,這為對目標值做extrapolation和interpolation提供了需求。

解決方法一:標籤分佈平滑(LDS)

首先通過一個例子展示一下當資料出現不均衡的時候,分類和迴歸問題之間的區別。作者在兩個不同的資料集:(1)CIFAR-100,一個100類的影像分類資料集;(2)IMDB-WIKI,一個用於根據人像估算年齡(迴歸)的影像資料集,進行了比較。通過取樣處理來模擬資料不平衡,保證兩個資料集具有完全相同的標籤密度分佈,如下圖所示:

https://i.iter01.com/images/3f40f876742a9d06574719a3b6234c8742811da9fec5ad1c9f01bbb8c53433e2.png

然後,分別在兩個資料集上訓練一個ResNet-50模型,並畫出它們的測試誤差的分佈。從圖中可以看出,在不平衡的分類資料集CIFAR-100上,測試誤差的分佈與標籤密度的分佈是高度負相關的,這很好理解,因為擁有更多樣本的類別更容易學好。但是,連續標籤空間的IMDB-WIKI的測試誤差分佈更加平滑,且不再與標籤密度分佈很好地相關。這說明了對於連續標籤,其經驗標籤密度並不能準確地反映模型所看到的不均衡。這是因為相臨標籤的資料樣本之間是相關的,相互依賴的。

標籤分佈平滑:基於這些發現,作者提出了一種在統計學習領域中的核密度估計(LDS)方法,給定連續的經驗標籤密度分佈,LDS使用了一個對稱核函式k,用經驗密度分佈與之卷積,得到一個kernel-smoothed的有效標籤密度分佈,用來直觀體現臨近標籤的資料樣本具有的資訊重疊問題,通過LDS計算出的有效標籤密度分佈結果與誤差分佈的相關性明顯增強。有了LDS估計出的有效標籤密度,就可以用解決類別不平衡問題的方法,直接應用於解決DIR問題。比如,最簡單地一種make sence方式是利用重加權的方法,通過將損失函式乘以每個目標值的LDS估計標籤密度的倒數來對其進行加權。

https://i.iter01.com/images/bc8b196f105957d8c40cbbbdce6772a2940c54352c7d6d04eeec1145e3d0f664.png

解決方法二:特徵分佈平滑(FDS)

如果模型預測正常且資料是均衡的,那麼label相近的samples,它們對應的feature的統計資訊應該也是彼此接近的。這裡作者也舉了一個例項驗證了這個直覺。作者同樣使用對IMDB-WIKI上訓練的ResNet-50模型。主要focus在模型學習到的特徵空間,不是標籤空間。我們關注的最小年齡差是1歲,因此我們將標籤空間分為了等間隔的區間,將具有相同目標區間的要素分到同一組。然後,針對每個區間中的資料計算其相應的特徵統計量(均值、方差)。特徵的統計量之間的相似性視覺化為如下圖:

https://i.iter01.com/images/a4f77e9bb061901bef48353a031c9c04b554db6872ad82d98b693a77ae6c5569.png
紅色區間代表anchor區間,計算這個anchor label與其他所有label的特徵統計量(即均值、方差)的餘弦相似度。此外,不同顏色區域(紫色,黃色,粉紅色)表示不同的資料密度。從圖中可以得到兩個結論:

  1.  anchor label和其臨近的區間的特徵統計量是高度相似的。而anchor label = 30 剛好是在訓練資料量非常多的區域。這說明了,當有足夠多的資料時,特徵的統計量在臨近點是相似的。
  2.  此外,在資料量很少的區域,如0-6歲的年齡範圍,與30歲年齡段的特徵統計量高度相似。這種不合理的相似性是由於資料不均衡造成的。因為,0-6歲的資料很少,該範圍的特徵會從具有最大資料量的範圍繼承其先驗。

特徵分佈平滑:受到這些啟發,作者提出了特徵分佈平滑(FDS)。FDS是對特徵空間進行分佈的平滑,本質上是在臨近的區間之間傳遞特徵的統計資訊。此過程的主要作用是去校準特徵分佈的潛在的有偏差的估計,尤其是對那些樣本很少的目標值而言。

https://i.iter01.com/images/71527192dc5cdfd8b78b4da2f72ddc609ac3886d067b4c642e86019cebcec6bf.png
具體來說,有一個模型,f代表一個encoder將輸入資料對映到隱層的特徵,g作為一個predictor來輸出連續的預測目標值。FDS會首先估計每個區間特徵的統計資訊。這裡用特徵的協方差代替方差,來反映特徵z內部元素之間的關係。給定特徵統計量,再次使用對稱核函式k來smooth特徵均值和協方差的分佈,這樣可以拿到統計資訊的平滑版本。利用估計和平滑統計量,遵循標準的whitening and re-coloring過程來校準每個輸入樣本的特徵表示。那麼整個FDS過程可以通過在最終特徵圖之後插入一個特徵的校準層,實現將FDS整合到深度網路中。最後,在每個epoch採用了動量更新,來獲得對訓練過程中特徵統計資訊的一個更穩定和更準確的估計。

基準DIR資料集

  1.  IMDB-WIKI-DIR(vision, age):基於IMDB-WIKI資料集,從包含人面部的影像來推斷估計相應的年齡。
  2.  AgeDB-DIR(vision, age):基於AgeDB資料集,同樣是根據輸入影像進行年齡估計。
  3.  NYUD2-DIR(vision, depth):基於NYU2資料集,用於構建depth estimation的DIR任務。
  4.  STS-B-DIR(NLP, test similarity score):基於STS-B資料集,任務是推斷兩個輸入句子之間的語義文字的相似度得分。
  5.  SHHS-DIR(Healthcare, health condition score):基於SHHS資料集,該任務是推斷一個人的總體健康評分。

具體的實驗可以檢視該論文,這裡附上論文原文以及程式碼地址:

[論文]:https://arxiv.org/abs/2102.09554

[程式碼]:https://github.com/YyzHarry/imbalanced-regression

 

點選關注,第一時間瞭解華為雲新鮮技術~

相關文章