Objects as Points:預測目標中心,無需NMS等後處理操作 | CVPR 2019

曉飛的演算法工程筆記發表於2021-01-19

論文基於關鍵點預測網路提出CenterNet演算法,將檢測目標視為關鍵點,先找到目標的中心點,然後迴歸其尺寸。對比上一篇同名的CenterNet演算法,本文的演算法更簡潔且效能足夠強大,不需要NMS等後處理方法,能夠擴充到其它檢測任務中

來源:曉飛的演算法工程筆記 公眾號

論文: Objects as Points

Introduction


  論文認為當前的anchor-based方法雖然效能很高,但需要列舉所有目標可能出現的位置以及尺寸,實際上是很浪費的。為此,論文提出了簡單且高效的CenterNet,將目標表示為其中心點,再通過中心點特徵迴歸目標的尺寸。

  CenterNet將輸入的圖片轉換成熱圖,熱圖中的高峰點對應目標的中心,將高峰點的特徵向量用於預測目標的高和寬,如圖2所示。在推理時,只需要簡單的前向計算即可,不需要NMS等後處理操作。

  對比現有的方法,CenterNet在準確率和速度上有更好的trade-off。另外,CenterNet的架構是通用的,能夠擴充到其它任務,比如3D目標檢測以及人體關鍵點預測。

Preliminary


  定義輸入圖片$I\in R^{W\times H\times 3}$,預測關鍵點熱圖$\hat{Y}\in
[ 0, 1 ]^{\frac{W}{R}\times \frac{H}{R}\times C}$,其中$R$為熱圖的縮放比例,設定為4,$C$為關鍵點的型別。當$\hat{Y}{x,y,c}=1$時,畫素點為檢測的關鍵點,當$\hat{Y}{x,y,c}=0$時,畫素點為背景。在主幹網路方法,論文嘗試了多種全卷積encoder-decoder網路:Hourglass網路,帶反摺積的殘差網路以及DLA(deep layer aggregation)。

  關鍵點預測部分的訓練跟CornerNet一樣,對於類別$c$的GT關鍵點$p\in \mathcal{R}2$,計算其在熱圖上對應的位置$\tilde{p}=\lfloor\frac{p}{R}\rfloor$,然後使用高斯核$Y_{xyc}=exp(-\frac{(x-\tilde{p}_x)2+(y-\tilde{p}_y)2}{2\sigma2_p })$將GT關鍵點散射,即根據畫素位置到關鍵點的距離賦予不同的權值,得到GT熱圖$Y\in [ 0,1 ]^{(\frac{W}{R}\times \frac{H}{R}\times C)}$,$\sigma_p$為目標尺寸自適應的標準差,如圖3所示。如果相同類別的高斯核散射重疊了,則取element-wise的最大值。訓練的損失函式為懲罰衰減的邏輯迴歸,附加了focal loss:

  $\alpha$和$\beta$為focal loss的超引數,$N$為關鍵點數。為了恢復特徵圖縮放帶來的誤差,額外預測每個關鍵點的偏移值$\hat{O}\in \mathcal{R}^{\frac{W}{R}\times \frac{H}{R}\times 2}$,偏移值與類別無關,通過L1損失進行訓練:

  偏移值只使用GT關鍵點,其它位置的點不參與訓練。

Objects as Points


  定義$(x^{(k)}_1, y^{(k)}_1, x{(k)}_2,y{(k)}_2)$為目標$k$的GT框,類別為$c_k$,其中心點為$p_k=(\frac{x{(k)}_1+x{(k)}2}{2}, \frac{y{(k)}_1+y{(k)}2}{2})$。論文使用熱圖$\hat{Y}$得到所有的中心點,另外再回歸每個目標$k$的尺寸$s_k=(x{(k)}_{2}-x{(k)}{1}, y{(k)}_{2}-y{(k)}{1})$。為減少計算負擔,尺寸的預測與類別無關$\hat{S}\in \mathcal{R}^{\frac{W}{R}\times \frac{H}{R}\times 2}$,通過L1損失進行訓練,只使用GT關鍵點:

  完整的CenterNet損失函式為:

  CenterNet直接預測關鍵點熱圖$\hat{Y}$、偏移值$\hat{O}$和目標尺寸$\hat{S}$,每個位置共計預測$C+4$個輸出。所有的輸出共用主幹網路特徵,再接各自的$3\times 3$卷積、ReLU和$1\times 1$卷積。

  在推理時,首先獲取各類別熱圖上的高峰點,高峰點的值需高於周圍八個聯通點的值,最後取top-100高峰點。對於每個高峰點$(x_i, y_i)$,使用預測的關鍵點值$\hat{Y}_{x,y,c}$作為檢測置信度,結合預測的偏移值$\hat{O}=(\delta \hat{x}_i, \delta \hat{y}_i)$和目標尺寸$\hat{S}=(\hat{w}_i, \hat{h}_i)$生成預測框:

  由於高峰點的提取方法足以替代NMS的作用,所有的預測框都直接通過關鍵點輸出,不需要再進行NMS操作以及其它後處理。需要注意的是,論文采用了巧妙的方法實現高峰點獲取,先對特徵圖使用padding=1的$3\times 3$最大值池化,然後對比輸出特徵圖和原圖,值一樣的點即為滿足要求的高峰點。

Implementation details


  CenterNet的輸入為$512\times 512$,輸出的熱圖大小為$128\times 128$。實驗測試了4種網路結構:ResNet-18、ResNet-101、DLA-34和Hourglass-104,其中使用可變形卷積對ResNet和DLA-34進行了改進。

Hourglass

  Hourglass結構如圖a所示,框中的數字為特徵圖的縮放比例,包含兩個hourglass模組,每個模組有5個下采樣層以及5個上取樣層,上取樣和下采樣對應的層有短路連線。Hourglass的網路尺寸最大,關鍵點預測的效果也是最好的。

ResNet

  ResNet大體結構跟原版一致,加入了反摺積用來恢復特徵圖大小,反摺積的權值初始化為雙線性插值操作,虛線箭頭為$3\times 3$可變形卷積操作。

DLA

  DLA使用層級短路連線,原版的結構如圖c所示。論文將大部分的卷積操作修改為可變形卷積,並對每層的輸出進行了$3\times 3$卷積融合,最後使用$1\times 1$卷積輸出到目標維度,如圖d所示。

Experiment


  不同主幹網路在目標檢測上的準確率和速度對比。

  目標檢測效能對比。

  3D檢測效能對比。

  人體關鍵點檢測效能對比。

Conclusion


  論文基於關鍵點預測網路提出CenterNet演算法,將檢測目標視為關鍵點,先找到目標的中心點,然後迴歸其尺寸。對比上一篇同名的CenterNet演算法,本文的演算法更簡潔且效能足夠強大,不需要NMS等後處理方法,能夠擴充到其它檢測任務中 。



如果本文對你有幫助,麻煩點個贊或在看唄~
更多內容請關注 微信公眾號【曉飛的演算法工程筆記】

work-life balance.

相關文章