探索無限大的神經網路
AI 科技評論按:今天我們要介紹對無限大的神經網路的探索。這個故事來自 CMU 杜少雷和胡威撰寫的部落格《Ultra-Wide Deep Nets and Neural Tangent Kernel (NTK)》,雷鋒網 AI 科技評論編譯。讓我們從頭講起。
被推翻的「複雜度甜點」
機器學習界有一條流傳已久的規矩:你需要在模型的訓練誤差和泛化能力之間謹慎地做出取捨。衡量泛化能力,有一個很便捷的指標是看看模型在訓練集和測試集上的誤差相差多大,那麼,一個較小的模型通常很難在訓練集上做到很小的訓練誤差,不過這個誤差和測試集上的測試誤差在同一水平;換用更大的模型以後一般都可以得到更小的訓練誤差,但是測試誤差往往會比訓練誤差大不少;也就是說,太大、太小的模型都無法得到較好的測試誤差。所以大家總結,要尋找模型複雜度的甜點(「sweet spot」),複雜度要足夠大,能夠達到足夠低的訓練誤差;複雜度也不能太大,避免測試誤差比訓練誤差大太多。下面這個經典的 U 型曲線就是根據這個理論繪製的 —— 測試誤差隨著模型複雜度的增加而先減小、再增大。
不過,隨著深度神經網路之類的高度複雜、高度過引數化(over-parameterized)的模型得到廣泛研究和使用,大家發現它們經常可以在訓練資料集上做到接近 0 的誤差,然後還能在測試資料上發揮出令人驚訝地好的表現,如今的網路大小也就隨著越來越大。Belkin 等人()用一種新的雙峰曲線描述了這個現象,他們在經典的 U 型曲線的右邊繼續延伸,描繪出:當模型的複雜度繼續增大,越過了「模型複雜度足以完全擬合訓練資料」(比如可以用模型為資料點取差值)的那個點之後,測試誤差就可以持續下降!有趣的是,越大的模型往往能給出越好的結果,已經跳出了以往的「複雜度甜點」的考慮範疇。如下圖。
有人懷疑深度學習中使用的最佳化演算法,比如梯度下降、隨機梯度下降以及各種變體,其實起到了隱式地限制模型複雜度的效果(也就是說,雖然整個模型中的引數很多,但其中真正獨立有效的引數只有一部分),也就避免了過擬合,避免了測試誤差和訓練誤差相差過大。
另外,「越大的模型往往能給出越好的結果」,所以很自然地有人會問「如果我們有一個無限大的網路,它的表現會如何?」按照上面那張雙峰圖,答案就對應著隱藏在影像的最右側的東西。上一年中這是一個熱門的研究問題:神經網路的寬度,也就是卷積層中的通道數目、或者全連線隱層中的神經元數目,趨近於無窮大的時候會得到怎樣的表現。
核方法帶來一個契機
乍看上去這個問題是無解的,要做實驗的話,有再多的計算資源也無法訓練一個真正「無限大」的網路;而要理論分析的話,有限大小的網路都還沒有研究清楚呢。不過,數學和物理領域一直都有研究「趨於無限大」從而得到新的見解的慣例,研究「趨於無限大」也在理論上更容易一點。
研究深度神經網路的學者們可能還記得無限寬的神經網路和核方法之間的聯絡,25 年前 Neal (~radford/pin.abstract.html)闡述過,Lee 等()和 Matthews 等()近期也做了擴充。這些核可以對應所有引數都隨機選擇、且只有最上層(分類器層)用梯度下降訓練過的的無限寬的深度神經網路。具體來說,如果我們用 θ 表示網路中的引數集,x 表示網路的輸入,就可以把輸出表示為 f(θ,x);接著,W 是 θ 之上的初始化分佈(通常是帶有一定縮放的高斯分佈),那麼對應的核就是
,其中 x、x' 是兩個輸入。
那麼更常見的「網路中所有的層都是訓練過的」這種情況呢?Jacot 等()近期發現這也和一種核有關係,他們把它稱為 neural tangent kernel(NTK,神經正切核),它的形式可以寫作
。
NTK 和之前提出的核的關鍵區別在於,NTK 是由網路的輸出相對於網路引數的梯度之間的內乘積來定義的;其中的梯度來自訓練網路時使用的梯度下降演算法。概括地說,對於一個梯度下降訓練出的足夠寬的深度神經網路,下面這個結論是成立的:
一個正確地隨機初始化的、 足夠寬的、由具有無窮小步長大小(也就是梯度流 gradient flow)的 梯度下降訓練的深度神經網路,和一個帶有 NTK 的 確定性核迴歸預測器是等效的。
這個結論在 Jacot 等最初的論文()中就基本確立了,不過他們要求網路的各個層依次趨近於無限大。在 Sanjeev Arora, 杜少雷, 胡威, Zhiyuan Li, Ruslan Salakhutdinov and Ruosong Wang 等人最新的論文()中,他們把這個結果做了進一步的改進,讓它對非對稱環境也適用,也就是每層的寬度不用依次變大,只需要都高過某個有限的閾值就可以。
杜少雷和胡威
NTK 是如何出現的?
詳細的推導過程在論文()中有介紹,這裡我們只簡單提一下。作者們在標準的有監督學習環境下考慮這個問題,透過最小化訓練資料上的二次方損失的方式訓練神經網路。經過一系列推導,作者們得到了含有網路梯度項的核矩陣的表示式。
不過到這裡為止作者們還沒有使用「網路非常寬」的這個條件。當網路足夠寬時,他們推導的核可以逼近某個確定性的固定核,也就是前面提到的 neural tangent kernel(NTK,神經正切核)。不過,確定「到底多寬才是足夠寬」需要一些假設和技巧,在這篇論文中作者們最終得到的是隻要網路的每一層的寬度各自大於某個閾值就可以,要比更早的結果中要求每一層寬度逐漸更趨近於無窮大的限制更弱一些。
最終作者們推匯出訓練後的無限寬神經網路和 NTK 是等效的。詳細的推導過程請見論文原文。
無限寬的神經網路實際表現如何?
在證明了無限寬的神經網路和 NTK 等效之後,我們就有機會實際看看無限寬的神經網路的表現 —— 只要測試對應的使用 NTK 的核迴歸預測器就可以了!作者們在標準的影像分類測試集 CIFAR-10 上進行了測試。由於這是基於影像的任務,想要得到好的結果一定少不了卷積結構的參與,所以作者們也推導了卷積 NTK,並和標準的卷積網路進行對比。分類準確率對比如下:
圖中 CNN-V 是不帶有池化的、正常寬度的 CNN,CNTK-V 是對應的卷積 NTK。作者們也測試了帶有全域性平均池化(GAP)的網路,也就是 CNN-GAP 和 CNTK-GAP。在所有實驗中都沒有使用批次標準化(batch normalization)、資料增強等等訓練技巧,只使用 SGD 訓練 CNN,以及 CNTK 使用核迴歸的解析方程。
實驗表明 CNTK 其實是很強的核方法。實驗中最強的是帶有全域性平均池化的、11 層的 CNTK,得到了 77.43% 的分類準確率。目前為止最強的完全基於核的方法來自 Novak 等(),而 CNTK 要比他們的準確率高出超過 10%。而且 CNTK 和正常 CNN 的表現都很接近,也就是說在 CIFAR-10 上超寬(無限寬)的 CNN 是可以取得不錯的表現的。
另外有趣的是,全域性池化不僅(如預期地)顯著提升了正常 CNN 的準確率,也同樣明顯提升了 CNTK 的準確率。也許提高神經網路表現的許多技巧要比我們目前認識到的更通用一些,它們可能也對核方法有效。
結論
想要理解為什麼過度引數化的深度神經網路還能有好得驚人的表現的確是一個很有挑戰的理論問題。不過現在起碼我們已經對一類非常寬的神經網路有了更多的瞭解:可以用 NTK 來表示它們。不過還有一個未解的困難是,關於核方法的經典泛化理論沒法給出泛化能力的現實上下界。好在我們至少知道對核方法的更深的理解也可以幫助我們理解神經網路了。
另外我們還算探索出了一個新的方向,那就是把不同的神經網路架構、訓練技巧轉換到核方法上來,並檢查它們的表現。作者們發現全域性平均池化可以大幅提升核方法的表現,那很有可能 BN、drop-out、最大池化之類的方法也能在核方法中發揮作用;反過來,我們也可以嘗試把 RNN、圖神經網路、Transformer 之類的神經網路轉換成核方法。
以及那個核心的問題:有限寬和無限寬的神經網路之間確實有效能區別,如何解釋這種區別的原因也是重要的理論研究課題。
原論文:On Exact Computation with an Infinitely Wide Neural Net,無限寬的神經網路的精確計算
論文地址:
本文編譯自技術部落格 ,雷鋒網 AI 科技評論編譯
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946223/viewspace-2659755/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 硬剛無限寬神經網路後,谷歌大腦有了12個新發現神經網路谷歌
- 探索日本網際網路:解鎖日本住宅IP的無限可能
- 谷歌開源Neural Tangents:簡單快速訓練無限寬度神經網路谷歌神經網路
- 神經網路:numpy實現神經網路框架神經網路框架
- 神經網路神經網路
- 探索Gameplay的無限可能GAM
- 人人都能搞定的大模型原理 - 神經網路大模型神經網路
- 【深度學習篇】--神經網路中的卷積神經網路深度學習神經網路卷積
- 8、神經網路神經網路
- LSTM神經網路神經網路
- 聊聊從腦神經到神經網路神經網路
- 圖神經網路GNN 庫,液體神經網路LNN/LFM神經網路GNN
- 神經網路篇——從程式碼出發理解BP神經網路神經網路
- 【神經網路篇】--RNN遞迴神經網路初始與詳解神經網路RNN遞迴
- 神經網路之所以強大的兩個原因 - tunguz神經網路
- 探索複合材料中的原子擴散,加州大學開發神經網路動力學方法神經網路
- 神經網路是如何工作的?神經網路
- 3.2 神經網路的通俗理解神經網路
- 3.3 神經網路的訓練神經網路
- 神經網路的發展史神經網路
- 神經網路(neural networks)神經網路
- 人工神經網路(ANN)神經網路
- 卷積神經網路卷積神經網路
- 迴圈神經網路神經網路
- 生成型神經網路神經網路
- 與神經網路相比,你對P圖一無所知神經網路
- (四)卷積神經網路 -- 8 網路中的網路(NiN)卷積神經網路
- 卷積神經網路十五問:CNN與生物視覺系統的研究探索卷積神經網路CNN視覺
- 卷積神經網路學習筆記——Siamese networks(孿生神經網路)卷積神經網路筆記
- Tensorflow系列專題(四):神經網路篇之前饋神經網路綜述神經網路
- 神經網路中常用的函式神經網路函式
- Tensor:Pytorch神經網路界的NumpyPyTorch神經網路
- 關於神經網路的討論神經網路
- 0603-常用的神經網路層神經網路
- 簡單的神經網路測試神經網路
- 神經網路原理的視覺化神經網路視覺化
- 三、淺層神經網路神經網路
- 殘差神經網路-ResNet神經網路