[論文簡析]How Do Vision Transformers Work?[2202.06709]
-
論文題目:How Do Vision Transformers Work?
-
ICLR2022 - Reviewer Kvf7:
- 這個文章整理的太難懂了
- 很多trick很有用,但是作者並沒有完全說明
-
行文線索 Emporocal Observations:
- MSAs(多頭自注意力機制 / 一般取代CNN)能夠提高CNN的預測效能,VIT裡面能夠很好的去預測 well-calibrated uncertainty P(模型輸出的預測概率值)
- 魯棒性,對於data corruptions、image occlusions、adversarial attacks、特別是對high-frequency noisy 高頻噪聲
- 靠近最後幾層的MSAs能夠顯著的提高我們的效能
Q1:Inductive Biases 歸納偏置
-
歸納偏置可以理解為,從現實生活中觀察到的現象中歸納出一定的規則(heuristics),然後對模型做一定的約束,從而可以起到“模型選擇”的作用,即從假設空間中選擇出更符合現實規則的模型。其實,貝葉斯學習中的“先驗(Prior)”這個叫法,可能比“歸納偏置”更直觀一些。
- CNN的inductive bias應該是locality和spatial invariance,即空間相近的grid elements有聯絡而遠的沒有,和空間不變性(kernel權重共享)
- RNN的inductive bias是sequentiality和time invariance,即序列順序上的timesteps有聯絡,和時間變換的不變性(rnn權重共享)
-
MSA本質上是一個
generalized spatial smoothing
廣義空間平滑,由幾個value在進行求和,權重由Q和K來給定- 相對比於CNN和RNN而言,MSA的歸納偏置是
weak Inductive Biases
,由於是全域性soft-attention
,因此會有長距離的關係long-range dependencies
- 因此適當的約束
Appropriate constraints
/ 強歸納偏置可能會對模型具有幫助:swin、twins
(做區域性attention)
- 相對比於CNN和RNN而言,MSA的歸納偏置是
Conclusion 1
- 因此
Conclusion 1
:歸納偏置越強,預測/特徵學出來的就越強。- RX:ResNeXt,R:ResNet,Mixes是最自由的
- 疑問1:會不會VIT這些模型在CIFAR-100小資料集上overfit過擬合才會導致效能差?- 實際上不是如此,可以看出隨著訓練集大小和訓練時間長短變化,NLLtrain和Error仍然是相關,並沒有出現過擬合現象。
Conclusion 2
-
疑問2:能不能衡量convexity? - 統計很多個海塞矩陣的最大特徵值,再用來做一個平譜來反映loss function的形狀特性
Hessian Max Eigenvalue Spectrum
(是這個作者的另一篇論文) -
歸納偏置的作用:強的能夠壓制負的特徵值
negative eigenvalues
-
- VIT 6%表示少量資料集下,負特徵值非常明顯,隨著樣本數量增加,負特徵值得到抑制
-
loss平滑是由MSA導致的,但是這並不一定是一個壞事:
regarding generalization & performance
具有更好的表達能力,泛化可塑,隨著樣本數量的增加,可以去壓制負特徵值,讓loss function變得沒有那麼平坦,讓凸一些,與VIT適用於大資料量樣本的觀察是契合的。- ResNet更加陡峭,在大資料量樣本下很容易陷入區域性最優
- 夾角越大說明離初始模型越來越不一樣,
Transformer
過度的十分平滑,越強的歸納偏置會導致優化的更加曲折,可以理解為執行力
-
以上就是說明,需要在歸納偏置和資料量之間尋找到一個平衡:
- patch_size ≈ CNN中的kernel大小,大小越大,偏執歸納越強,右圖可以看出歸納偏置越強對負特徵值的抑制越強(本質就是如何作用於loss function)
- 在ImageNet上,就是由於CNN其歸納偏置太強,導致沒有MSA那麼靈活,沒有那麼強的表達能力 / 泛化能力表達複雜的資料集。
-
- local MSA / swin 滑窗機制產生了很多負特徵值,但是它很大程度上減少了特徵值的度量。有圖看出swim相比其他的模型更加集中於靠近0的位置,說明swim的loss更加的平滑,甚至當訓練過後也是。
- PIT相對於VIT而言是一個
multi-stage
多級的,不斷的把模型shape縮小,深度增加,這種結構同樣抑制了負的海塞矩陣特徵值。 - 兩種常用的設計如何影響特徵值平譜。
-
除此之外:
- MSA中head的數量 = loss landspace convexity,head越多歸納偏置越強,因為每個head只跟自己交流。
- NEP衡量負樣本比例,APE衡量loss陡峭程度,隨著head增加,都是下降的趨勢,有圖表示head越大,離0越近表示越來越平坦,面積越小表示越來越凸。
- 此圖就是說明head的深度
Embed dim
。
Conclusion 3
- loss landscape smoothing methods aisds 更平滑更凸
- 首先用GAP
Global Average Pooling
而不用CLS - SAM
Sharpness-Aware Minimization
改善VIT的效能(另外兩篇論文,一種梯度下降優化演算法)
- 首先用GAP
總結
- 歸納偏置 <=> loss landscape convexity / smoothness & flatness <=> 海塞矩陣最大值norm
- MSA為什麼讓效能更好?
- 平坦 / loss不凸 / 資料量 / 平滑
另一個觀察
- Data Specificity (not long-range dependency) 資料特異性(非全域性關聯)
- 實驗:用NLP思想,將
global MSA
替換2D conv MSA
(1d區域性,2d區域性相鄰head做attention) - 8X8kernel等效為全域性attention,區域性更加有利
- Data Specificity(attention)替代long-range dependency,用資料特異性來替代全域性attention
- 實驗:用NLP思想,將
Q2:MSAs和Convs差異
- Convs:data-agnostic and channel-specific 資料無關和通道特定,不管資料怎麼樣,權重都是固定好的,按照同樣的權重去提取特徵,特徵放到特定的位置/channel
- MSA:data-specific and channel-agnostic 資料特定和通道無關,attention計算是和資料本身有關的,都是進來相乘做attention,進來順序就不重要了。
- 可以看出這兩者是相互的。
Conclusion 1
- MSA是低通濾波器(本質上就是把所有空間上的值求個平均),Conv是高通濾波器
- 作者對兩者輸入不同頻率得到輸出後的視覺化,可以看到,Convs對高頻損失不大,低波段下降明顯,而MSA相反。同樣右圖,對加入不同頻率噪聲的影響,Convs對高頻噪聲反應大,MAS對低頻噪聲反應大。
- 不過對於swim而言,它同樣可以保持一定的高頻訊號。
- 看到可以看出,灰色部分的高層時可以減少高頻訊號的響應,而在白色部分都是增加高頻訊號的響應。低層的時候與高層相反。
Conclusion 2
- MSAs聚合特徵,而Convs相反讓特徵更加多樣。
- 白色卷積/MLP多層感知機部分提高方差,藍色部分下采樣降低方差。
- 有意思的是在swim中,表現得更像是一個Convs的結構,方差逐漸的增加,甚至在做下采樣的時候方差還在增加,直到最後的時候方差才降下來,不知道如何解釋。
- 總的來說這意味著就是,我們或許可以將兩者的性質結合起來設計一個更好的網路。
Q3:結合MSA+Conv
-
可以看出來,在Resnet/PIT/Swin中是多層結構,(小塊)層之內相關性明顯,層之間相關性很弱 / PIT也是一樣。而在VIT裡面,由於本身就沒有Multi-stage的概念,所以一大塊都是相關的。
-
這個圖使用:Minibatch Centered Kernel Alignment(CKA)計算出來的。
-
從已經訓練好的Res和Swin裡面移除網路單元進行測試效能
-
對於ResNet而言,移除早期的模組比後期的模組更重要,同一層中移除前面的比移除後面的更致命。
-
對於Swin而言,在stage(以藍色下采樣為劃分)中,開頭移除MLP會大幅影響準確性,結尾移除MSA會大幅影響準確性。
-
基於以上觀察,把Convs逐漸替換成Attention有三個準則:
- 1、從全域性最後開始,每隔一層把Conv塊替換成MSA塊
- 2、如果替換的MSA並不能增加我們模型的效能,那麼我們去找到上一個stage,在該stage的最後把Conv替換成MSA(同時不能增加效能的MSA還原成Conv)
- 3、在相對比較靠後的stage裡,我們的MSA需要有更多的head以及更大的
hidden dimensions
- 1、Alternately replace Conv blocks with MSA blocks from the end.
- 2、lf the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA block .
- 3、Use more heads and higher hidden dimensions for MSA blocks in late stages.
-
可以看到替換後的結果:精度有提升,魯棒性有提升,eigenvalue頻譜可以看出AlterNet更加平滑(特徵值更加接近0)。魯棒性計算另一篇論文。
-
不同資料集針對的修改網路不一樣,ImageNet相比於CIFAR上多了兩個MSA層,這意味著我們有更弱的歸納偏置,更加強的表達效能,更多的資料才能支撐起更多的MSA。
-
我們可以看到,相比於swin/ twins這種純MSA的網路而言,在大資料集下效能肯定是不如的,因為他們的歸納偏置更加的弱。
一個發現
data augmention
資料增強- 1、
Data augmentation can harm uncertainty calibration
資料增強會損害不確定性校準 - 在沒有進行資料增強前,模型對自己的預測比較自信,在進行資料增強後,模型對自己的預測反而不太自信了。
- 2、
Data augmentation reduces the magnitude of Hessian eigenvalues
資料增強降低了 Hessian 特徵值的大小 - 可以看出加了資料增強後的特徵值更加趨向於0了,即表明loss變得更加平滑了,負的特徵值變多了表示loss函式不凸了。
- 1、
總結
-
附錄B:
- MSA is Spatial smoothing: Appendix B
Taking all these observations together, we provide an explanation of how MSAs work by addressing themselves as a general form of spatial smoothing or an implementation of data-complemented BNN. Spatial smoothing improves performance in the following ways:
1、Spatial smoothing helps in NN optimization by flattening the loss landscapes. Even a small 2×2 box blurfilter significantly improves performance.
2、Spatial smoothing is a low-pass filter. CNNs are vulnerable to high-frequency noises, but spatial smoothing improves the robustness against suchnoises by significantly reducing these noises.
3、Spatial smoothing is effective when applied atthe end of a stage because it aggregates all transformed feature maps. This paper shows that these mechanisms also apply to MSAs.
- MSA is Spatial smoothing: Appendix B