何愷明團隊最新力作SimSiam:消除表徵學習“崩潰解”,探尋對比表達學習成功之根源

極市平臺發表於2020-11-24

該文是FAIR的陳鑫磊&何愷明大神在無監督學習領域又一力作,提出了一種非常簡單的表達學習機制用於避免表達學習中的“崩潰”問題,從理論與實驗角度證實了所提方法的有效性;與此同時,還側面證實了對比學習方法成功的關鍵性因素:孿生網路。

paper: https://arxiv.org/abs/2011.10566

本文為極市平臺原創,作者Happy,轉載需獲授權。

Abstract

孿生網路已成為無監督表達學習領域的通用架構,現有方法通過最大化同一影像的兩者增廣的相似性使其避免“崩潰解(collapsing solutions)”問題。在這篇研究中,作者提出一種驚人的實證結果:**Simple Siamese(SimSiam)**網路甚至可以在無((1) negative sample pairs;(2)large batch;(3)momentum encoders)的情形下學習有意義的特徵表達。

作者通過實驗表明:對於損失與結構而言,“崩潰解”確實存在,但是“stop-gradient”操作對於避免“崩潰解”有非常重要的作用。作者提出了一種新穎的“stop-gradient”思想並通過實驗對其進行了驗證,該文所提SimSiam在ImageNet及下游任務上均取得了有競爭力的結果。作者期望:這個簡單的基準方案可以驅動更多研員重新思考無監督表達學習中的孿生結構。

Method

上圖給出了該文所提SimSiam的示意圖,它以影像 x x x的兩個隨機變換 x 1 , x 2 x_1, x_2 x1,x2作為輸入,通過相同的編碼網路 f f f(它包含一個骨幹網路和一個投影MLP頭模組,表示為h)提取特徵並變換到高維空間。此外作者還定義了一個預測MLP頭模組h,對其中一個分支的結果進行變換並與另一個分支的結果進行匹配,該過程可以描述為 p 1 = h ( f ( x 1 ) ) , z 2 = f ( x 2 ) p_1 = h(f(x_1)), z_2 = f(x_2) p1=h(f(x1)),z2=f(x2),SimSiam對上述特徵進行負cosine相似性最小化:

D ( p 1 , z 2 ) = − p 1 ∥ p 1 ∥ 2 ⋅ z 2 ∥ z 2 ∥ 2 \mathcal{D}(p_1, z_2) = - \frac{p_1}{\|p_1\|_2} \cdot \frac{z_2}{\|z_2\|_2} D(p1,z2)=p12p1z22z2

注:上述公式等價於 l 2 l_2 l2規範化向量的MSE損失。與此同時,作者還定義了一個對稱損失:

L = 1 2 D ( p 1 , z 2 ) + 1 2 D ( p 2 , z 1 ) \mathcal{L} = \frac{1}{2}\mathcal{D}(p_1, z_2) + \frac{1}{2}\mathcal{D}(p_2, z_1) L=21D(p1,z2)+21D(p2,z1)
上述兩個損失作用於每一張影像,總損失是所有影像損失的平均,故最小的可能損失為-1.

需要的是:該文一個非常重要的概念是Stop-gradient操作(即上圖的右分支部分)。可以通過對上述公式進行簡單的修改得到本文的損失函式:

D ( p 1 , s t o p g r a d ( z x ) ) L = 1 2 D ( p 1 , s t o p g r a d ( z 2 ) ) + 1 2 D ( p 2 , s t o p g r a d ( z 1 ) ) \mathcal{D}(p_1, stopgrad(z_x)) \\ \mathcal{L} = \frac{1}{2}\mathcal{D}(p_1, stopgrad(z_2)) + \frac{1}{2}\mathcal{D}(p_2, stopgrad(z_1)) D(p1,stopgrad(zx))L=21D(p1,stopgrad(z2))+21D(p2,stopgrad(z1))

也就是說:在損失 L \mathcal{L} L的第一項, x 2 x_2 x2不會從 z 2 z_2 z2接收梯度資訊;在其第二項,則會從 p 2 p_2 p2接收梯度資訊。

SimSiam的實現虛擬碼如下,有沒有一種“就這麼簡單”的感覺???

# Algorithm1 SimSiam Pseudocode, Pytorch-like
# f: backbone + projection mlp
# h: prediction mlp
for x in loader: # load a minibatch x with n samples
	x1, x2 = aug(x), aug(x) # random augmentation
	z1, z2 = f(x1), f(x2) # projections, n-by-d
	p1, p2 = h(z1), h(z2) # predictions, n-by-d
	L = D(p1, z2)/2 + D(p2, z1)/2 # loss
	L.backward() # back-propagate
	update(f, h) # SGD update
    
def D(p, z): # negative cosine similarity
	z = z.detach() # stop gradient
	p = normalize(p, dim=1) # l2-normalize
	z = normalize(z, dim=1) # l2-normalize
return -(p*z).sum(dim=1).mean()

我們再來看一下SimSiam的基礎配置:

  • Optimizer: SGD用於預訓練,學習率為 l r × B a t c h S i z e / 256 lr \times BatchSize/256 lr×BatchSize/256, 基礎學習率為 l r = 0.05 lr=0.05 lr=0.05,學習率採用consine衰減機制,weight decay=0.0001,momentum=0.9。BatchSize預設512,採用了SynBatchNorm。
  • Projection MLP:編碼網路中投影MLP部分的每個全連線層後接BN層,其輸出層 f c fc fc後無ReLU,隱含層的 f c fc fc的維度為2048,MLP包含三個全連線層。
  • Prediction MLP:預測MLP中同樣適用了BN層,但其輸出層 f c fc fc後無BN與ReLU。MLP有2個全連線層,第一個全連線層的輸入與輸出維度為2048,第二個的輸出維度為512.
  • Backbone:作者選用了ResNet50作為骨幹網路。

作者在ImageNet上線進行無監督預訓練,然後採用監督方式凍結骨幹網路訓練分類頭,最後在驗證集上驗證其效能。

Empirical Study

在該部分內容中,我們將實證研究SimSiam的表現,主要聚焦於哪些行為有助於避免“崩潰解”。

Stop-gradient

上圖給出了Stop-gradient新增與否的效能對比,注網路架構與超參保持不變,區別僅在於是否新增Stop-gradient

上圖left表示訓練損失,可以看到:在無Stop-gradient時,優化器迅速找了了一個退化解並達到了最小可能損失-1。為證實上述退化解是“崩潰”導致的,作者研究了輸出的 l 2 l_2 l2規範化結果的標準差。如果輸出“崩潰”到了常數向量,那麼其每個通道的標準差應當是0,見上圖middle。

作為對比,如果輸出具有零均值各項同性高斯分佈,可以看到其標準差為 1 d \frac{1}{\sqrt{d}} d 1。上圖middle中的藍色曲線(即新增了Stop-gradient)接近 1 d \frac{1}{\sqrt{d}} d 1,這也就意味著輸出並沒有“崩潰”。

上圖right給出了KNN分類器的驗證精度,KNN分類器可用於訓練過程的監控。在無Stop-gradient時,其分類進度僅有0.1%,而新增Stop-gradient後最終分類精度可達67.7%。

上述實驗表明:“崩潰”確實存在。但“崩潰”的存在不足以說明所提方法可以避免“崩潰”,儘管上述對比中僅有“stop-gradient”的區別。

Predictor

上表給出了Predictor MLP的影響性分析,可以看到:

  • 當移除預測MLP頭模組h(即h為恆等對映)後,該模型不再有效(work);

  • 如果預測MLP頭模組h固定為隨機初始化,該模型同樣不再有效;

  • 當預測MLP頭模組採用常數學習率時,該模型甚至可以取得比基準更好的結果(多個實驗中均有類似發現).

Batch Size

上表給出了Batch Size從64變換到4096過程中的精度變化,可以看到:該方法在非常大範圍的batch size下表現均非常好

Batch Normalization

上表比較了投影與預測MLP中不同BN的配置對比,可以看到:

  • 移除所有BN層後,儘管精度只有34.6%,但不會造成“崩潰”;這種低精度更像是優化難問題,對隱含層新增BN後精度則提升到了67.4%;
  • 在投影MLP的輸出後新增BN,精度可以進一步提升到68.1%;
  • 在預測MLP的輸出新增BN後反而導致訓練變的不穩定。

總而言之,BN有助於訓練優化,這與監督學習中BN的作用類似;但並未看到BN有助於避免“崩潰”的證據

Similarity Function

所提方法除了與cosine相似性組合表現好外,其與交叉熵相似組合表現同樣良好,見上表。此時的交叉熵相似定義如下:
D = − s o f t m a x ( z x ) ⋅ log s o f t m a x ( p 1 ) \mathcal{D} = -softmax(z_x) \cdot \text{log} softmax(p_1) D=softmax(zx)logsoftmax(p1)
可以看到:交叉熵相似性同樣可以收斂到一個合理的解並不會導致“崩潰”,這也就是意味著“崩潰”避免行為與cosine相似性無關。

Symmetrization

儘管前述描述中用到了對稱損失,但上表的結果表明:SimSiam的行為不依賴於對稱損失:非對稱損失同樣取得了合理的結果,而對稱損失有助於提升精度,這與“崩潰”避免無關

Summary

通過上面的一些列消融實驗對比分析,可以看到:SimSiam可以得到有意義的結果而不會導致“崩潰”。優化器、BN、相似性函式、對稱損失可能會影響精度,但與“崩潰”避免無關;對於“崩潰”避免起關鍵作用的是stop-gradient操作。

Hypothesis

接下來,我們將討論:SimSiam到底在隱式的優化什麼?並通過實驗對其進行驗證。主要從定義、證明以及討論三個方面進行介紹。

Formulation

作者假設:SimSiam是類期望最大化演算法的一種實現。它隱含的包含兩組變數,並解決兩個潛在子問題,而stop-gradient操作是引入額外變換的結果。我們考慮如下形式的損失:
L ( θ , η ) = E x , τ [ ∥ F θ ( τ ( x ) ) − η x ∥ 2 2 ] \mathcal{L}(\theta, \eta) = E_{x, \tau}[\|\mathcal{F}_{\theta}(\tau(x)) - \eta_x\|_2^2] L(θ,η)=Ex,τ[Fθ(τ(x))ηx22]

其中 F , τ \mathcal{F}, \tau F,τ分別表示特徵提取網路與資料增廣方法,x表示影像。在這裡,作者引入了另外一個變數 η \eta η,其大小正比於影像數量,直觀上來講, η x \eta_x ηx是x的特徵表達。

基於上述表述,我們考慮如下優化問題:
m i n θ , η L ( θ , η ) min_{\theta, \eta} \mathcal{L}(\theta, \eta) minθ,ηL(θ,η)
這種描述形式類似於k-means聚類問題,變數 θ \theta θ與聚類中心類似,是一個可學習引數;變數 η x \eta_x ηx與樣本x的對應向量(類似k-means的one-hot向量)類似:即它是x的特徵表達。類似於k-means,上述問題可以通過交替方案(固定一個,求解另一個)進行求解:
θ t ← a r g m i n θ L ( θ , η t − 1 ) η t ← a r g m i n η L ( θ t , η ) \theta^t \leftarrow argmin_{\theta} \mathcal{L}(\theta, \eta^{t-1}) \\ \eta^t \leftarrow argmin_{\eta} \mathcal{L} (\theta^t, \eta) θtargminθL(θ,ηt1)ηtargminηL(θt,η)
對於 θ \theta θ的求解,可以採用SGD進行子問題求解,此時stop-gradient是一個很自然的結果,因為梯度先不要反向傳播到 η t − 1 \eta^{t-1} ηt1,在該子問題中,它是一個常數;對於 η \eta η的七屆,上述問題將轉換為:
η x t ← E τ [ F θ t ( τ ( x ) ) ] \eta^t_x \leftarrow E_{\tau} [\mathcal{F}_{\theta^t}(\tau(x))] ηxtEτ[Fθt(τ(x))]

結合前述介紹,SimSiam可以視作上述求解方案的一次性交替近似。

此外需要注意:(1)上述分析並不包含預測器h;(2) 上述分析並不包含對稱損失,對稱損失並非該方法的必選項,但有助於提升精度。

Proof of concept

作者假設:SimSiam是一種類似交錯優化的方案,其SGD更新間隔為1。基於該假設,所提方案在多步SGD更新下同樣有效。為此,作者設計了一組實驗驗證上述假設,結果見下表。

在這裡, 1 − s t e p 1-step 1step等價與SimSiam。可以看到:multi-step variants work well。更多步的SGD更新甚至可以取得比SimSiam更優的結果。這就意味著:交錯優化是一種可行的方案,而SimSiam是其特例。

Comparison

前述內容已經說明了所提方法的有效性,接下來將從ImageNet以及遷移學習的角度對比一下所提方法與其他SOTA方法。

上圖給出了所提方法與其他SOTA無監督學習方法在ImageNet的效能,可以看到:SimSiam可以取得具有競爭力的結果。在100epoch訓練下,所提方法具有最高的精度;但更長的訓練所得收益反而變小。

上表給出了所提方法與其他SOTA方法在遷移學習方面的效能對比。從中可以看到:SimSiam表達可以很好的遷移到ImageNet以外的任務上,遷移模型的效能極具競爭力

最後,作者對比了所提方法與其他SOTA方法的區別&聯絡所在,見上圖。

  • Relation to SimCLR:SimCLR依賴於負取樣以避免“崩潰”,SimSiam可以是作為“SimCLR without negative”。

  • Relation to SwAV:SimSiam可以視作“SwAV without online clustering”.

  • Relation to BYOL: SimSiam可以視作“BYOL without the momentum encoder”.

全文到此結束,對該文感興趣的同學建議去檢視原文的實驗結果與實驗分析。

Conclusion

該文采通過非常簡單的設計探索了孿生網路,所提方法方法的有效性意味著:孿生形狀是這些表達學習方法(SimCLR, MoCo,SwAR等)成功的關鍵原因所在。孿生網路天然具有建模不變性的特徵,而這也是表達學習的核心所在。

相關文章

  1. SimCLR: A simple framework for contrastive learning of visual representations
  2. SimCLRv2: Big self-supervised models are strong semi-supervised learners.
  3. SwAV:Unsupervised learning of visual features by contrasting cluster assignments
  4. MoCo: Momentum contrast for unsupervised visual representation learning.
  5. MoCov2:Improved baselines with momentum contrastive learning
  6. BYOL: Bootstrap your own latten: A new aproach to self-supervised learning.
  7. CPC: Data efficient image recognition with contrastive predictive coding.
  8. PIC: Parametric instance classification for unsupervised visual feature learning.

相關文章