殘差網路(Residual Networks, ResNets)

wuliytTaotao發表於2018-09-15

1. 什麼是殘差(residual)?

  “殘差在數理統計中是指實際觀察值與估計值(擬合值)之間的差。”“如果迴歸模型正確的話, 我們可以將殘差看作誤差的觀測值。”

  更準確地,假設我們想要找一個 $x$,使得 $f(x) = b$,給定一個 $x$ 的估計值 $x_0$,殘差(residual)就是 $b-f(x_0)$,同時,誤差就是 $x-x_0$。

  即使 $x$ 不知道,我們仍然可以計算殘差,只是不能計算誤差罷了。

2. 什麼是殘差網路(Residual Networks,ResNets)?

  在瞭解殘差網路之前,先了解下面這個問題。

  Q1:神經網路越深越好嗎?(Deeper is better?)

  A1:如圖 1 所示,在訓練集上,傳統神經網路越深效果不一定越好。而 Deep Residual Learning for Image Recognition 這篇論文認為,理論上,可以訓練一個 shallower 網路,然後在這個訓練好的 shallower 網路上堆幾層 identity mapping(恆等對映) 的層,即輸出等於輸入的層,構建出一個 deeper 網路。這兩個網路(shallower 和 deeper)得到的結果應該是一模一樣的,因為堆上去的層都是 identity mapping。這樣可以得出一個結論:理論上,在訓練集上,Deeper 不應該比 shallower 差,即越深的網路不會比淺層的網路效果差。但為什麼會出現圖 1 這樣的情況呢,隨著層數的增多,訓練集上的效果變差?這被稱為退化問題(degradation problem),原因是隨著網路越來越深,訓練變得原來越難,網路的優化變得越來越難。理論上,越深的網路,效果應該更好;但實際上,由於訓練難度,過深的網路會產生退化問題,效果反而不如相對較淺的網路。而殘差網路就可以解決這個問題的,殘差網路越深,訓練集上的效果會越好。(測試集上的效果可能涉及過擬合問題。過擬合問題指的是測試集上的效果和訓練集上的效果之間有差距。)

圖 1 不同深度的傳統神經網路效果對比圖

(“plain” network指的是沒有使用 shortcut connection 的網路)

  殘差網路通過加入 shortcut connections,變得更加容易被優化。包含一個 shortcut connection 的幾層網路被稱為一個殘差塊(residual block),如圖 2 所示。

圖 2 殘差塊

  2.1 殘差塊(residual block)

  如圖 2 所示,$x$ 表示輸入,$F(x)$ 表示殘差塊在第二層啟用函式之前的輸出,即 $F(x) = W_2\sigma(W_1x)$,其中 $W_1$ 和 $W_2$ 表示第一層和第二層的權重,$\sigma$ 表示 ReLU 啟用函式。(這裡省略了 bias。)最後殘差塊的輸出是 $\sigma(F(x) + x)$。

  當沒有 shortcut connection(即圖 2 右側從 $x$ 到 $\bigoplus$ 的箭頭)時,殘差塊就是一個普通的 2 層網路。殘差塊中的網路可以是全連線層,也可以是卷積層。設第二層網路在啟用函式之前的輸出為 $H(x)$。如果在該 2 層網路中,最優的輸出就是輸入 $x$,那麼對於沒有 shortcut connection 的網路,就需要將其優化成 $H(x) = x$;對於有 shortcut connection 的網路,即殘差塊,最優輸出是 $x$,則只需要將 $F(x) = H(x) - x$ 優化為 0 即可。後者的優化會比前者簡單。這也是殘差這一叫法的由來。

  2.2 殘差網路舉例

  圖 3 最右側就是就是一個殘差網路。34-layer 表示含可訓練引數的層數為34層,池化層不含可訓練引數。圖 3 右側所示的殘差網路和中間部分的 plain network 唯一的區別就是 shortcut connections。這兩個網路都是當 feature map 減半時,filter 的個數翻倍,這樣保證了每一層的計算複雜度一致。

  ResNet 因為使用 identity mapping,在 shortcut connections 上沒有引數,所以圖 3 中 plain network 和 residual network 的計算複雜度都是一樣的,都是 3.6 billion FLOPs.


圖 3  VGG-19、plain network、ResNet

  殘差網路可以不是卷積神經網路,用全連線層也可以。當然,殘差網路在被提出的論文中是用來處理影象識別問題。

  2.3 為什麼殘差網路會work?

  我們給一個網路不論在中間還是末尾加上一個殘差塊,並給殘差塊中的 weights 加上 L2 regularization(weight decay),這樣圖 1 中 $F(x) = 0$ 是很容易的。這種情況下加上一個殘差塊和不加之前的效果會是一樣,所以加上殘差塊不會使得效果變得差。如果殘差塊中的隱藏單元學到了一些有用資訊,那麼它可能比 identity mapping(即 $F(x) = 0$)表現的更好。

  "The main reason the residual network works is that it's so easy for these extra layers to learn the identity function that you're kind of guaranteed that it doesn't hurt performance. And then lot of time you maybe get lucky and even helps performance, or at least is easier to go from a decent baseline of not hurting performance, and then creating the same can only improve the solution from there."

相關文章