焦點損失函式 Focal Loss 與 GHM

忽逢桃林 發表於 2020-08-01

文章來自公眾號【機器學習煉丹術】

1 focal loss的概述

焦點損失函式 Focal Loss(2017年何凱明大佬的論文)被提出用於密集物體檢測任務。

當然,在目標檢測中,可能待檢測物體有1000個類別,然而你想要識別出來的物體,只是其中的某一個類別,這樣其實就是一個樣本非常不均衡的一個分類問題。

而Focal Loss簡單的說,就是解決樣本數量極度不平衡的問題的。

說到樣本不平衡的解決方案,相比大家是知道一個混淆矩陣的f1-score的,但是這個好像不能用在訓練中當成損失。而Focal loss可以在訓練中,讓小數量的目標類別增加權重,讓分類錯誤的樣本增加權重

先來看一下簡單的二值交叉熵的損失:
焦點損失函式 Focal Loss 與 GHM

  • y’是模型給出的預測類別概率,y是真實樣本。就是說,如果一個樣本的真實類別是1,預測概率是0.9,那麼\(-log(0.9)\)就是這個損失。
  • 講道理,一般我不喜歡用二值交叉熵做例子,用多分類交叉熵做例子會更舒服。

【然後看focal loss的改進】:
焦點損失函式 Focal Loss 與 GHM
這個增加了一個\((1-y')^\gamma\)的權重值,怎麼理解呢?就是如果給出的正確類別的概率越大,那麼\((1-y')^\gamma\)就會越小,說明分類正確的樣本的損失權重小,反之,分類錯誤的樣本的損權重大


【focal loss的進一步改進】:
焦點損失函式 Focal Loss 與 GHM
這裡增加了一個\(\alpha\),這個alpha在論文中給出的是0.25,這個就是單純的降低正樣本或者負樣本的權重,來解決樣本不均衡的問題

兩者結合起來,就是一個可以解決樣本不平衡問題的損失focal loss。


【總結】:

  1. \(\alpha\)解決了樣本的不平衡問題;
  2. \(\beta\)解決了難易樣本不平衡的問題。讓樣本更重視難樣本,忽視易樣本。
  3. 總之,Focal loss會的關注順序為:樣本少的、難分類的;樣本多的、難分類的;樣本少的,易分類的;樣本多的,易分類的。

2 GHM

  • GHM是Gradient Harmonizing Mechanism。

這個GHM是為了解決Focal loss存在的一些問題。

【Focal Loss的弊端1】
讓模型過多的關注特別難分類的樣本是會有問題的。樣本中有一些異常點、離群點(outliers)。所以模型為了擬合這些非常難擬合的離群點,就會存在過擬合的風險。

2.1 GHM的辦法

Focal Loss是從置信度p的角度入手衰減loss的。而GHM是一定範圍內建信度p的樣本數量來衰減loss的。

首先定義了一個變數g,叫做梯度模長(gradient norm)
焦點損失函式 Focal Loss 與 GHM
可以看出這個梯度模長,其實就是模型給出的置信度\(p^*\)與這個樣本真實的標籤之間的差值(距離)。g越小,說明預測越準,說明樣本越容易分類。

下圖中展示了g與樣本數量的關係:
焦點損失函式 Focal Loss 與 GHM

【從圖中可以看到】

  • 梯度模長接近於0的樣本多,也就是易分類樣本是非常多的
  • 然後樣本數量隨著梯度模長的增加迅速減少
  • 然後當梯度模長接近1的時候,樣本的數量又開始增加。

GHM是這樣想的,對於梯度模長小的易分類樣本,我們忽視他們;但是focal loss過於關注難分類樣本了。關鍵是難分類樣本其實也有很多!,如果模型一直學習難分類樣本,那麼可能模型的精確度就會下降。所以GHM對於難分類樣本也有一個衰減。

那麼,GHM對易分類樣本和難分類樣本都衰減,那麼真正被關注的樣本,就是那些不難不易的樣本。而抑制的程度,可以根據樣本的數量來決定。

這裡定義一個GD,梯度密度

\[GD(g)=\frac{1}{l(g)}\sum_{k=1}^N{\delta(g_k,g)} \]
  • \(GD(g)\)是計算在梯度g位置的梯度密度;
  • \(\delta(g_k,g)\)就是樣本k的梯度\(g_k\)是否在\([g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]\)這個區間內。
  • \(l(g)\)就是\([g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]\)這個區間的長度,也就是\(\epsilon\)

總之,\(GD(g)\)就是梯度模長在\([g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]\)內的樣本總數除以\(\epsilon\).

然後把每一個樣本的交叉熵損失除以他們對應的梯度密度就行了。

\[L_{GHM}=\sum^N_{i=1}{\frac{CE(p_i,p_i^*)}{GD(g_i)}} \]
  • \(CE(p_i,p_i^*)\)表示第i個樣本的交叉熵損失;
  • \(GD(g_i)\)表示第i個樣本的梯度密度;

2.2 論文中的GHM

論文中呢,是把梯度模長劃分成了10個區域,因為置信度p是從0~1的,所以梯度密度的區域長度就是0.1,比如是0~0.1為一個區域。

下圖是論文中給出的對比圖:
焦點損失函式 Focal Loss 與 GHM

【從圖中可以得到】

  • 綠色的表示交叉熵損失;
  • 藍色的是focal loss的損失,發現梯度模長小的損失衰減很有效;
  • 紅色是GHM的交叉熵損失,發現梯度模長在0附近和1附近存在明顯的衰減。

當然可以想到的是,GHM看起來是需要整個樣本的模型估計值,才能計算出梯度密度,才能進行更新。也就是說mini-batch看起來似乎不能用GHM。

在GHM原文中也提到了這個問題,如果光使用mini-batch的話,那麼很可能出現不均衡的情況。

【我個人覺得的處理方法】

  1. 可以使用上一個epoch的梯度密度,來作為這一個epoch來使用;
  2. 或者一開始先使用mini-batch計算梯度密度,然後模型收斂速度下降之後,再使用第一種方式進行更新。

3 python實現

上面講述的關鍵在於focal loss實現的功能:

  1. 分類正確的樣本的損失權重小,分類錯誤的樣本的損權重大
  2. 樣本過多的類別的權重較小

在CenterNet中預測中心點位置的時候,也是使用了Focal Loss,但是稍有改動。

3.1 概述

焦點損失函式 Focal Loss 與 GHM
這裡面和上面講的比較類似,我們忽視腳標。

  • 假設\(Y=1\),那麼預測的\(\hat{Y}\)越靠近1,說明預測的約正確,然後\((1-\hat{Y})^\alpha\)就會越小,從而體現分類正確的樣本的損失權重小;otherwize的情況也是這樣。
  • 但是這裡的otherwize中多了一個\((1-Y)^\beta\),這個是用來平衡樣本不均衡問題的,在後面的程式碼部分會提到CenterNet的熱力圖。就會明白這個了。

3.2 程式碼講解

下面通過程式碼來理解:

class FocalLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.neg_loss = _neg_loss

    def forward(self, output, target, mask):
        output = torch.sigmoid(output)
        loss = self.neg_loss(output, target, mask)
        return loss

這裡面的output可以理解為是一個1通道的特徵圖,每一個pixel的值都是模型給出的置信度,然後通過sigmoid函式轉換成0~1區間的置信度。

而target是CenterNet的熱力圖,這一點可能比較難理解。打個比方,一個10*10的全都是0的特徵圖,然後這個特徵圖中只有一個pixel是1,那麼這個pixel的位置就是一個目標檢測物體的中心點。有幾個1就說明這個圖中有幾個要檢測的目標物體。

然後,如果一個特徵圖上,全都是0,只有幾個孤零零的1,未免顯得過於稀疏了,直觀上也非常的不平滑。所以CenterNet的熱力圖還需要對這些1為中心做一個高斯
焦點損失函式 Focal Loss 與 GHM
可以看作是一種平滑:
焦點損失函式 Focal Loss 與 GHM
可以看到,數字1的四周是同樣的數字。這是一個以1為中心的高斯平滑。


這裡我們回到上面說到的\((1-Y)^\beta\)
焦點損失函式 Focal Loss 與 GHM
對於數字1來說,我們計算loss自然是用第一行來計算,但是對於1附近的其他點來說,就要考慮\((1-Y)^\beta\)了。越靠近1的點的\(Y\)越大,那麼\((1-Y)^\beta\)就會越小,這樣從而降低1附近的權重值。其實這裡我也講不太明白,就是根據距離1的距離降低負樣本的權重值,從而可以實現樣本過多的類別的權重較小


我們回到主題,對output進行sigmoid之後,與output一起放到了neg_loss中。我們來看什麼是neg_loss:

def _neg_loss(pred, gt, mask):
    pos_inds = gt.eq(1).float() * mask
    neg_inds = gt.lt(1).float() * mask

    neg_weights = torch.pow(1 - gt, 4)

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * \
               neg_weights * neg_inds

    num_pos = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
    return loss

先說一下,這裡面的mask是根據特定任務中加上的一個小功能,就是在該任務中,一張圖片中有一部分是不需要計算loss的,所以先用過mask把那個部分過濾掉。這裡直接忽視mask就好了。

neg_weights = torch.pow(1 - gt, 4)可以得知\(\beta=4\),從下面的程式碼中也不難推出,\(\alpha=2\),剩下的內容就都一樣了。

把每一個pixel的損失都加起來,除以目標物體的數量即可。