文章來自公眾號【機器學習煉丹術】
1 focal loss的概述
焦點損失函式 Focal Loss(2017年何凱明大佬的論文)被提出用於密集物體檢測任務。
當然,在目標檢測中,可能待檢測物體有1000個類別,然而你想要識別出來的物體,只是其中的某一個類別,這樣其實就是一個樣本非常不均衡的一個分類問題。
而Focal Loss簡單的說,就是解決樣本數量極度不平衡的問題的。
說到樣本不平衡的解決方案,相比大家是知道一個混淆矩陣的f1-score的,但是這個好像不能用在訓練中當成損失。而Focal loss可以在訓練中,讓小數量的目標類別增加權重,讓分類錯誤的樣本增加權重。
先來看一下簡單的二值交叉熵的損失:
- y’是模型給出的預測類別概率,y是真實樣本。就是說,如果一個樣本的真實類別是1,預測概率是0.9,那麼\(-log(0.9)\)就是這個損失。
- 講道理,一般我不喜歡用二值交叉熵做例子,用多分類交叉熵做例子會更舒服。
【然後看focal loss的改進】:
這個增加了一個\((1-y')^\gamma\)的權重值,怎麼理解呢?就是如果給出的正確類別的概率越大,那麼\((1-y')^\gamma\)就會越小,說明分類正確的樣本的損失權重小,反之,分類錯誤的樣本的損權重大。
【focal loss的進一步改進】:
這裡增加了一個\(\alpha\),這個alpha在論文中給出的是0.25,這個就是單純的降低正樣本或者負樣本的權重,來解決樣本不均衡的問題。
兩者結合起來,就是一個可以解決樣本不平衡問題的損失focal loss。
【總結】:
- \(\alpha\)解決了樣本的不平衡問題;
- \(\beta\)解決了難易樣本不平衡的問題。讓樣本更重視難樣本,忽視易樣本。
- 總之,Focal loss會的關注順序為:樣本少的、難分類的;樣本多的、難分類的;樣本少的,易分類的;樣本多的,易分類的。
2 GHM
- GHM是Gradient Harmonizing Mechanism。
這個GHM是為了解決Focal loss存在的一些問題。
【Focal Loss的弊端1】
讓模型過多的關注特別難分類的樣本是會有問題的。樣本中有一些異常點、離群點(outliers)。所以模型為了擬合這些非常難擬合的離群點,就會存在過擬合的風險。
2.1 GHM的辦法
Focal Loss是從置信度p的角度入手衰減loss的。而GHM是一定範圍內建信度p的樣本數量來衰減loss的。
首先定義了一個變數g,叫做梯度模長(gradient norm):
可以看出這個梯度模長,其實就是模型給出的置信度\(p^*\)與這個樣本真實的標籤之間的差值(距離)。g越小,說明預測越準,說明樣本越容易分類。
下圖中展示了g與樣本數量的關係:
【從圖中可以看到】
- 梯度模長接近於0的樣本多,也就是易分類樣本是非常多的
- 然後樣本數量隨著梯度模長的增加迅速減少
- 然後當梯度模長接近1的時候,樣本的數量又開始增加。
GHM是這樣想的,對於梯度模長小的易分類樣本,我們忽視他們;但是focal loss過於關注難分類樣本了。關鍵是難分類樣本其實也有很多!,如果模型一直學習難分類樣本,那麼可能模型的精確度就會下降。所以GHM對於難分類樣本也有一個衰減。
那麼,GHM對易分類樣本和難分類樣本都衰減,那麼真正被關注的樣本,就是那些不難不易的樣本。而抑制的程度,可以根據樣本的數量來決定。
這裡定義一個GD,梯度密度:
- \(GD(g)\)是計算在梯度g位置的梯度密度;
- \(\delta(g_k,g)\)就是樣本k的梯度\(g_k\)是否在\([g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]\)這個區間內。
- \(l(g)\)就是\([g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]\)這個區間的長度,也就是\(\epsilon\)
總之,\(GD(g)\)就是梯度模長在\([g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]\)內的樣本總數除以\(\epsilon\).
然後把每一個樣本的交叉熵損失除以他們對應的梯度密度就行了。
- \(CE(p_i,p_i^*)\)表示第i個樣本的交叉熵損失;
- \(GD(g_i)\)表示第i個樣本的梯度密度;
2.2 論文中的GHM
論文中呢,是把梯度模長劃分成了10個區域,因為置信度p是從0~1的,所以梯度密度的區域長度就是0.1,比如是0~0.1為一個區域。
下圖是論文中給出的對比圖:
【從圖中可以得到】
- 綠色的表示交叉熵損失;
- 藍色的是focal loss的損失,發現梯度模長小的損失衰減很有效;
- 紅色是GHM的交叉熵損失,發現梯度模長在0附近和1附近存在明顯的衰減。
當然可以想到的是,GHM看起來是需要整個樣本的模型估計值,才能計算出梯度密度,才能進行更新。也就是說mini-batch看起來似乎不能用GHM。
在GHM原文中也提到了這個問題,如果光使用mini-batch的話,那麼很可能出現不均衡的情況。
【我個人覺得的處理方法】
- 可以使用上一個epoch的梯度密度,來作為這一個epoch來使用;
- 或者一開始先使用mini-batch計算梯度密度,然後模型收斂速度下降之後,再使用第一種方式進行更新。
3 python實現
上面講述的關鍵在於focal loss實現的功能:
- 分類正確的樣本的損失權重小,分類錯誤的樣本的損權重大。
- 樣本過多的類別的權重較小
在CenterNet中預測中心點位置的時候,也是使用了Focal Loss,但是稍有改動。
3.1 概述
- 假設\(Y=1\),那麼預測的\(\hat{Y}\)越靠近1,說明預測的約正確,然後\((1-\hat{Y})^\alpha\)就會越小,從而體現分類正確的樣本的損失權重小;otherwize的情況也是這樣。
- 但是這裡的otherwize中多了一個\((1-Y)^\beta\),這個是用來平衡樣本不均衡問題的,在後面的程式碼部分會提到CenterNet的熱力圖。就會明白這個了。
3.2 程式碼講解
下面通過程式碼來理解:
class FocalLoss(nn.Module):
def __init__(self):
super().__init__()
self.neg_loss = _neg_loss
def forward(self, output, target, mask):
output = torch.sigmoid(output)
loss = self.neg_loss(output, target, mask)
return loss
這裡面的output可以理解為是一個1通道的特徵圖,每一個pixel的值都是模型給出的置信度,然後通過sigmoid函式轉換成0~1區間的置信度。
而target是CenterNet的熱力圖,這一點可能比較難理解。打個比方,一個10*10的全都是0的特徵圖,然後這個特徵圖中只有一個pixel是1,那麼這個pixel的位置就是一個目標檢測物體的中心點。有幾個1就說明這個圖中有幾個要檢測的目標物體。
然後,如果一個特徵圖上,全都是0,只有幾個孤零零的1,未免顯得過於稀疏了,直觀上也非常的不平滑。所以CenterNet的熱力圖還需要對這些1為中心做一個高斯
可以看作是一種平滑:
可以看到,數字1的四周是同樣的數字。這是一個以1為中心的高斯平滑。
這裡我們回到上面說到的\((1-Y)^\beta\):
對於數字1來說,我們計算loss自然是用第一行來計算,但是對於1附近的其他點來說,就要考慮\((1-Y)^\beta\)了。越靠近1的點的\(Y\)越大,那麼\((1-Y)^\beta\)就會越小,這樣從而降低1附近的權重值。其實這裡我也講不太明白,就是根據距離1的距離降低負樣本的權重值,從而可以實現樣本過多的類別的權重較小。
我們回到主題,對output進行sigmoid之後,與output一起放到了neg_loss中。我們來看什麼是neg_loss:
def _neg_loss(pred, gt, mask):
pos_inds = gt.eq(1).float() * mask
neg_inds = gt.lt(1).float() * mask
neg_weights = torch.pow(1 - gt, 4)
loss = 0
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * \
neg_weights * neg_inds
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if num_pos == 0:
loss = loss - neg_loss
else:
loss = loss - (pos_loss + neg_loss) / num_pos
return loss
先說一下,這裡面的mask是根據特定任務中加上的一個小功能,就是在該任務中,一張圖片中有一部分是不需要計算loss的,所以先用過mask把那個部分過濾掉。這裡直接忽視mask就好了。
從neg_weights = torch.pow(1 - gt, 4)
可以得知\(\beta=4\),從下面的程式碼中也不難推出,\(\alpha=2\),剩下的內容就都一樣了。
把每一個pixel的損失都加起來,除以目標物體的數量即可。