Residual Attention
文章: Residual Attention: A Simple but Effective Method for Multi-Label Recognition, ICCV2021
下面說一下我對這篇文章的淺陋之見, 如有錯誤, 請多包涵指正.
文章的核心方法
如下圖所示為其處理流程:
圖中 X
為CNN骨幹網路提取得到的feature, 其大小為 d*h*w
, 為1個batch資料. 一般 d*h*w=2048*7*7
.
從圖中可以看到, 有2個分支, 一個是 average pooling
, 一個是 spatial pooling
, 最後二者加權融合得到 residual attention
.
Spatial pooling
其過程為:
這裡有個 1*1
的卷積操作FC
, 其大小為 C*d*1*1
, C
為類別數, 如果直接使用矩陣乘法計算, FC(X)
後的大小為 C*h*w
.
但文章中的公式是將其展開為對每個空間點單獨計算, 其中 \(\pmb{m_i}\) 為 FC
第i
個類別的引數, 其大小為 d*1*1
, 計算得到的 \(s^i_j\) 為第 i
個類別在第 j
個位置的概率, \(\pmb{a^i}\) 為第 i
個類別的特徵, 其大小為 d*1
.
如果, \(\pmb{m_i}\) 和 \(\pmb{a^i}\) 計算就可以得到第 i
個類別的概率. 這樣就可以用到每個空間點的特徵, 有利於不同目標不同類別物體的分類識別.
公式中有個溫度引數 T
用來控制 \(s^i_j\) 的大小, 當 T
趨於無窮時, spatial pooling
就變成了 max pooling
Average pooling
其過程為:
上式其實就是一般分類模型的做法, 全域性均值池化.
Residual Attention
如下所示, 將上述2個過程進行加權融合:
其中, \(\pmb{f^i}\) 大小為 d*1
, \(\pmb{m_i}^T \pmb{f^i}\) 為第 i
個類別的概率.
至於為什麼叫 Residual Attention
, 文章中的說法是:
the max pooling among different spatial regions for every class, is in fact a class-specific attention operation, which can be further viewed as a residual component of the class-agnostic global average pooling.
我的理解是, 公式5形式有點像 residual 形式.
文章實驗結果
多標籤
如下表所示為作者對多個資料集的測試, 除了ImageNet
為單標籤外, 其它都為多標籤. 可以看到多標籤提升還是不錯的.
熱力圖
由於利用到了不同位置空間點的資訊, 獲得的 heatmap
會更加準確, 文章中給出了一張結果, 如下:
我覺得這裡有個遺憾的是, 文中沒有進行對比.
個人理解
關於原理
根據流程圖, 結合文中作者給出的核心程式碼, 其基本原理就是 average pooling
+ max pooling
.
上述程式碼中: y_avg
大小為 C*1
, 為 average pooling
; y_max
大小為 C*1
, 為 max pooling
.
下面是上述程式碼的一個例子, y_raw
的大小為 1*3*9
, B=1, C=3, H3H, W=3
:
可以看到, y_avg
剛好為 average pooling
, y_max
剛好為 max pooling
.
關於公式
公式中的溫度引數 T
用於調整引數大小, 而給出的核心程式碼中, 只有T
趨於無窮的情況(等價於max pooling
), 對於多個 Head
的情況, T=2,3,4,5
等, 程式碼中是如何體現出來的?
關於效果
對於 multi-label
, 使用了 spatial pooling
和 multi-head
來提高效果, 從實驗結果來看, 確實有效果, 但對於單標籤情況, max pooling
應該改善不大, 從實驗結果上看也確實可以看到, 單標籤資料集上, 最高提升了0.02個百分點.
測試程式碼
測試程式碼如下, 可以參考這裡.
import torch
from torch import nn
class ResidualAttention(nn.Module):
def __init__(self, channel=512, num_class=1000, la=0.2):
super().__init__()
self.la = la
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)
def forward(self, x):
y_raw = self.fc(x).flatten(2) # b, num_class, h*w
y_avg = torch.mean(y_raw, dim=2) # b, num_class
y_max = torch.max(y_raw, dim=2)[0] # b, num_class
score = y_avg + self.la * y_max
return score
if __name__ == '__main__':
channel = 4
num_class = 3
batchsize = 1
input = torch.randn(batchsize, channel, 3, 3)
resatt = ResidualAttention(channel=channel, num_class=num_class, la=0.2)
output = resatt(input)
print(output.shape)