How Do Vision Transformers Work?[2202.06709] - 論文研讀系列(2) 個人筆記

黃龍士發表於2022-03-18

[論文簡析]How Do Vision Transformers Work?[2202.06709]

  • 論文題目:How Do Vision Transformers Work?

  • 論文地址:http://arxiv.org/abs/2202.06709

  • 程式碼:https://github.com/xxxnell/how-do-vits-work

  • ICLR2022 - Reviewer Kvf7:

    • 這個文章整理的太難懂了
    • 很多trick很有用,但是作者並沒有完全說明
  • 行文線索 Emporocal Observations:

    • MSAs(多頭自注意力機制 / 一般取代CNN)能夠提高CNN的預測效能,VIT裡面能夠很好的去預測 well-calibrated uncertainty P(模型輸出的預測概率值)
    • 魯棒性,對於data corruptions、image occlusions、adversarial attacks、特別是對high-frequency noisy 高頻噪聲
    • 靠近最後幾層的MSAs能夠顯著的提高我們的效能

Q1:Inductive Biases 歸納偏置

  • 歸納偏置可以理解為,從現實生活中觀察到的現象中歸納出一定的規則(heuristics),然後對模型做一定的約束,從而可以起到“模型選擇”的作用,即從假設空間中選擇出更符合現實規則的模型。其實,貝葉斯學習中的“先驗(Prior)”這個叫法,可能比“歸納偏置”更直觀一些。

    • CNN的inductive bias應該是locality和spatial invariance,即空間相近的grid elements有聯絡而遠的沒有,和空間不變性(kernel權重共享)
    • RNN的inductive bias是sequentiality和time invariance,即序列順序上的timesteps有聯絡,和時間變換的不變性(rnn權重共享)
  • MSA本質上是一個generalized spatial smoothing廣義空間平滑,由幾個value在進行求和,權重由Q和K來給定

    • MSA公式
    • 相對比於CNN和RNN而言,MSA的歸納偏置是weak Inductive Biases,由於是全域性soft-attention,因此會有長距離的關係long-range dependencies
    • 因此適當的約束Appropriate constraints / 強歸納偏置可能會對模型具有幫助:swin、twins(做區域性attention)

Conclusion 1

  • 因此Conclusion 1:歸納偏置越強,預測/特徵學出來的就越強。
    • RX:ResNeXt,R:ResNet,Mixes是最自由的
    • 疑問1:會不會VIT這些模型在CIFAR-100小資料集上overfit過擬合才會導致效能差?- 實際上不是如此,可以看出隨著訓練集大小和訓練時間長短變化,NLLtrain和Error仍然是相關,並沒有出現過擬合現象。

Conclusion 2

  • 疑問2:能不能衡量convexity? - 統計很多個海塞矩陣的最大特徵值,再用來做一個平譜來反映loss function的形狀特性Hessian Max Eigenvalue Spectrum(是這個作者的另一篇論文)

  • 歸納偏置的作用:強的能夠壓制負的特徵值negative eigenvalues

    • VIT 6%表示少量資料集下,負特徵值非常明顯,隨著樣本數量增加,負特徵值得到抑制
  • loss平滑是由MSA導致的,但是這並不一定是一個壞事:regarding generalization & performance具有更好的表達能力,泛化可塑,隨著樣本數量的增加,可以去壓制負特徵值,讓loss function變得沒有那麼平坦,讓凸一些,與VIT適用於大資料量樣本的觀察是契合的。

    • ResNet更加陡峭,在大資料量樣本下很容易陷入區域性最優
    • 夾角越大說明離初始模型越來越不一樣,Transformer過度的十分平滑,越強的歸納偏置會導致優化的更加曲折,可以理解為執行力
  • 以上就是說明,需要在歸納偏置和資料量之間尋找到一個平衡:

    • patch_size ≈ CNN中的kernel大小,大小越大,偏執歸納越強,右圖可以看出歸納偏置越強對負特徵值的抑制越強(本質就是如何作用於loss function)
    • 在ImageNet上,就是由於CNN其歸納偏置太強,導致沒有MSA那麼靈活,沒有那麼強的表達能力 / 泛化能力表達複雜的資料集。
    • local MSA / swin 滑窗機制產生了很多負特徵值,但是它很大程度上減少了特徵值的度量。有圖看出swim相比其他的模型更加集中於靠近0的位置,說明swim的loss更加的平滑,甚至當訓練過後也是。
    • PIT相對於VIT而言是一個multi-stage多級的,不斷的把模型shape縮小,深度增加,這種結構同樣抑制了負的海塞矩陣特徵值。
    • 兩種常用的設計如何影響特徵值平譜。
  • 除此之外:

    • MSA中head的數量 = loss landspace convexity,head越多歸納偏置越強,因為每個head只跟自己交流。
    • NEP衡量負樣本比例,APE衡量loss陡峭程度,隨著head增加,都是下降的趨勢,有圖表示head越大,離0越近表示越來越平坦,面積越小表示越來越凸。
    • 此圖就是說明head的深度Embed dim

Conclusion 3

  • loss landscape smoothing methods aisds 更平滑更凸
    • 首先用GAPGlobal Average Pooling而不用CLS
    • SAMSharpness-Aware Minimization改善VIT的效能(另外兩篇論文,一種梯度下降優化演算法)

總結

  • 歸納偏置 <=> loss landscape convexity / smoothness & flatness <=> 海塞矩陣最大值norm
  • MSA為什麼讓效能更好?
    • 平坦 / loss不凸 / 資料量 / 平滑

另一個觀察

  • Data Specificity (not long-range dependency) 資料特異性(非全域性關聯)
    • 實驗:用NLP思想,將global MSA替換2D conv MSA (1d區域性,2d區域性相鄰head做attention)
    • 8X8kernel等效為全域性attention,區域性更加有利
    • Data Specificity(attention)替代long-range dependency,用資料特異性來替代全域性attention

Q2:MSAs和Convs差異

  • Convs:data-agnostic and channel-specific 資料無關和通道特定,不管資料怎麼樣,權重都是固定好的,按照同樣的權重去提取特徵,特徵放到特定的位置/channel
  • MSA:data-specific and channel-agnostic 資料特定和通道無關,attention計算是和資料本身有關的,都是進來相乘做attention,進來順序就不重要了。
  • 可以看出這兩者是相互的。

Conclusion 1

  • MSA是低通濾波器(本質上就是把所有空間上的值求個平均),Conv是高通濾波器
    • 作者對兩者輸入不同頻率得到輸出後的視覺化,可以看到,Convs對高頻損失不大,低波段下降明顯,而MSA相反。同樣右圖,對加入不同頻率噪聲的影響,Convs對高頻噪聲反應大,MAS對低頻噪聲反應大。
    • 不過對於swim而言,它同樣可以保持一定的高頻訊號。
    • 看到可以看出,灰色部分的高層時可以減少高頻訊號的響應,而在白色部分都是增加高頻訊號的響應。低層的時候與高層相反。

Conclusion 2

  • MSAs聚合特徵,而Convs相反讓特徵更加多樣。
  • 白色卷積/MLP多層感知機部分提高方差,藍色部分下采樣降低方差。
  • 有意思的是在swim中,表現得更像是一個Convs的結構,方差逐漸的增加,甚至在做下采樣的時候方差還在增加,直到最後的時候方差才降下來,不知道如何解釋。
  • 總的來說這意味著就是,我們或許可以將兩者的性質結合起來設計一個更好的網路。

Q3:結合MSA+Conv

  • image-20220318173932830

  • 可以看出來,在Resnet/PIT/Swin中是多層結構,(小塊)層之內相關性明顯,層之間相關性很弱 / PIT也是一樣。而在VIT裡面,由於本身就沒有Multi-stage的概念,所以一大塊都是相關的。

  • 這個圖使用:Minibatch Centered Kernel Alignment(CKA)計算出來的。

  • image-20220318174506060

  • 從已經訓練好的Res和Swin裡面移除網路單元進行測試效能

  • 對於ResNet而言,移除早期的模組比後期的模組更重要,同一層中移除前面的比移除後面的更致命。

  • 對於Swin而言,在stage(以藍色下采樣為劃分)中,開頭移除MLP會大幅影響準確性,結尾移除MSA會大幅影響準確性。

  • 基於以上觀察,把Convs逐漸替換成Attention有三個準則:

    • 1、從全域性最後開始,每隔一層把Conv塊替換成MSA塊
    • 2、如果替換的MSA並不能增加我們模型的效能,那麼我們去找到上一個stage,在該stage的最後把Conv替換成MSA(同時不能增加效能的MSA還原成Conv)
    • 3、在相對比較靠後的stage裡,我們的MSA需要有更多的head以及更大的hidden dimensions
    • 1、Alternately replace Conv blocks with MSA blocks from the end.
    • 2、lf the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA block .
    • 3、Use more heads and higher hidden dimensions for MSA blocks in late stages.
  • 可以看到替換後的結果:精度有提升,魯棒性有提升,eigenvalue頻譜可以看出AlterNet更加平滑(特徵值更加接近0)。魯棒性計算另一篇論文。

  • 不同資料集針對的修改網路不一樣,ImageNet相比於CIFAR上多了兩個MSA層,這意味著我們有更弱的歸納偏置,更加強的表達效能,更多的資料才能支撐起更多的MSA。

  • 我們可以看到,相比於swin/ twins這種純MSA的網路而言,在大資料集下效能肯定是不如的,因為他們的歸納偏置更加的弱。

一個發現

  • data augmention資料增強
    • 1、Data augmentation can harm uncertainty calibration 資料增強會損害不確定性校準
    • 在沒有進行資料增強前,模型對自己的預測比較自信,在進行資料增強後,模型對自己的預測反而不太自信了。
    • 2、Data augmentation reduces the magnitude of Hessian eigenvalues資料增強降低了 Hessian 特徵值的大小
    • 可以看出加了資料增強後的特徵值更加趨向於0了,即表明loss變得更加平滑了,負的特徵值變多了表示loss函式不凸了。

總結

  • 附錄B:

    • MSA is Spatial smoothing: Appendix B
      Taking all these observations together, we provide an explanation of how MSAs work by addressing themselves as a general form of spatial smoothing or an implementation of data-complemented BNN. Spatial smoothing improves performance in the following ways:
      1、Spatial smoothing helps in NN optimization by flattening the loss landscapes. Even a small 2×2 box blurfilter significantly improves performance.
      2、Spatial smoothing is a low-pass filter. CNNs are vulnerable to high-frequency noises, but spatial smoothing improves the robustness against suchnoises by significantly reducing these noises.
      3、Spatial smoothing is effective when applied atthe end of a stage because it aggregates all transformed feature maps. This paper shows that these mechanisms also apply to MSAs.

相關文章