[論文閱讀] Residual Attention(Multi-Label Recognition)

耳王東令發表於2021-08-15

Residual Attention

文章: Residual Attention: A Simple but Effective Method for Multi-Label Recognition, ICCV2021

下面說一下我對這篇文章的淺陋之見, 如有錯誤, 請多包涵指正.

文章的核心方法

如下圖所示為其處理流程:

residualAttentionFlowchart

圖中 X 為CNN骨幹網路提取得到的feature, 其大小為 d*h*w , 為1個batch資料. 一般 d*h*w=2048*7*7 .

從圖中可以看到, 有2個分支, 一個是 average pooling, 一個是 spatial pooling, 最後二者加權融合得到 residual attention .

Spatial pooling

其過程為:

residualAttentionSpatial

這裡有個 1*1 的卷積操作FC , 其大小為 C*d*1*1 , C 為類別數, 如果直接使用矩陣乘法計算, FC(X) 後的大小為 C*h*w .

但文章中的公式是將其展開為對每個空間點單獨計算, 其中 \(\pmb{m_i}\)​​​ 為 FCi 個類別的引數, 其大小為 d*1*1, 計算得到的 \(s^i_j\)​​​ 為第 i 個類別在第 j 個位置的概率, \(\pmb{a^i}\)​​​​ 為第 i 個類別的特徵, 其大小為 d*1 .

如果, \(\pmb{m_i}\)\(\pmb{a^i}\) 計算就可以得到第 i 個類別的概率. 這樣就可以用到每個空間點的特徵, 有利於不同目標不同類別物體的分類識別.

公式中有個溫度引數 T 用來控制 \(s^i_j\)​​​​ 的大小, 當 T 趨於無窮時, spatial pooling 就變成了 max pooling

Average pooling

其過程為:

residualAttentionAverage

上式其實就是一般分類模型的做法, 全域性均值池化.

Residual Attention

如下所示, 將上述2個過程進行加權融合:

residualAttentionResidual

其中, \(\pmb{f^i}\) 大小為 d*1, \(\pmb{m_i}^T \pmb{f^i}\) 為第 i 個類別的概率.

至於為什麼叫 Residual Attention , 文章中的說法是:

the max pooling among different spatial regions for every class, is in fact a class-specific attention operation, which can be further viewed as a residual component of the class-agnostic global average pooling.

我的理解是, 公式5形式有點像 residual 形式.

文章實驗結果

多標籤

如下表所示為作者對多個資料集的測試, 除了ImageNet 為單標籤外, 其它都為多標籤. 可以看到多標籤提升還是不錯的.

residualAttentionMultiResult

熱力圖

由於利用到了不同位置空間點的資訊, 獲得的 heatmap 會更加準確, 文章中給出了一張結果, 如下:

residualAttentionHeatmap

我覺得這裡有個遺憾的是, 文中沒有進行對比.

個人理解

關於原理

根據流程圖, 結合文中作者給出的核心程式碼, 其基本原理就是 average pooling + max pooling.

residualAttentionCode

上述程式碼中: y_avg 大小為 C*1, 為 average pooling ; y_max 大小為 C*1, 為 max pooling .

下面是上述程式碼的一個例子, y_raw 的大小為 1*3*9 , B=1, C=3, H3H, W=3:

residualAttentionExample

可以看到, y_avg 剛好為 average pooling , y_max 剛好為 max pooling .

關於公式

公式中的溫度引數 T 用於調整引數大小, 而給出的核心程式碼中, 只有T趨於無窮的情況(等價於max pooling), 對於多個 Head 的情況, T=2,3,4,5 等, 程式碼中是如何體現出來的?

關於效果

對於 multi-label , 使用了 spatial poolingmulti-head 來提高效果, 從實驗結果來看, 確實有效果, 但對於單標籤情況, max pooling 應該改善不大, 從實驗結果上看也確實可以看到, 單標籤資料集上, 最高提升了0.02個百分點.

測試程式碼

測試程式碼如下, 可以參考這裡.

import torch
from torch import nn

class ResidualAttention(nn.Module):
    def __init__(self, channel=512, num_class=1000, la=0.2):
        super().__init__()
        self.la = la
        self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        y_raw = self.fc(x).flatten(2) # b, num_class, h*w
        y_avg = torch.mean(y_raw, dim=2) # b, num_class
        y_max = torch.max(y_raw, dim=2)[0] # b, num_class
        score = y_avg + self.la * y_max
        return score

if __name__ == '__main__':

    channel = 4
    num_class = 3
    batchsize = 1
    input = torch.randn(batchsize, channel, 3, 3)
    resatt = ResidualAttention(channel=channel, num_class=num_class, la=0.2)
    output = resatt(input)
    print(output.shape)

相關文章