深度互學習-Deep Mutual Learning:三人行必有我師

深度學習大講堂發表於2019-07-31

編者按:更高效能的深度神經網路往往伴隨著愈加龐大的引數量,而大量的計算需求使其難以部署在移動端。為此,精巧的網路結構設計(如MobileNet、ShuffleNet)、模型壓縮策略(剪枝二值化等)及其他優化方法應運而生。

Hinton等人在2015年提出的模型蒸餾演算法,利用預訓練好的大網路當作教師來向小網路傳遞知識,從而提高小網路效能。而模型蒸餾演算法需要有提前預訓練好的大網路,且僅可對小網路進行單向的知識傳遞。古人云“三人行必有我師焉”,本文作者提出了一種“深度互學習Deep Mutual Learning”策略,使得小網路之間能夠互相學習共同進步。

1.研究動機

近幾年來,深度神經網路計算機視覺語音識別、語言翻譯等領域中取得了令人矚目的成果,為了完成更加複雜的資訊處理任務,網路在設計上不斷增加深度或寬度,使得模型引數量越來越大,如經典的VGG、Inception、ResNet系列網路。儘管更深或更寬的神經網路取得了更好的效能,大量計算需求使得它們難以部署在資源條件有限的環境中,如手機、平板、車載等移動端應用。這促使研究者們採用各種各樣的方法去探索更高效的模型,如更精巧的網路結構設計MobileNet和ShuffleNet,還有網路壓縮、剪枝二值化,以及比較有趣的模型蒸餾等。

模型蒸餾演算法由Hinton等人在2015年提出,利用一個預訓練好的大網路當作教師來提供小網路額外的知識即平滑後的概率估計,實驗表明小網路通過模仿大網路估計的類別概率,優化過程變得更容易,且表現出與大網路相近甚至更好的效能。然而模型蒸餾演算法需要有提前預訓練好的大網路,且大網路在學習過程中保持固定,僅對小網路進行單向的知識傳遞,難以從小網路的學習狀態中得到反饋資訊來對訓練過程進行優化調整。

我們嘗試探索一種能夠學習到更強大小網路的訓練機制—深度互學習,即採用多個網路同時進行訓練,每個網路在訓練過程中不僅接受來自真值標記的監督,還參考同伴網路的學習經驗來進一步提升泛化能力。在整個過程中,兩個網路之間不斷分享學習經驗,實現互相學習共同進步。

2.演算法描述深度互學習-Deep Mutual Learning:三人行必有我師圖1 深度互學習演算法框架具體來說,每個網路在學習過程中有兩個損失函式,一個是傳統的監督損失函式,採用交叉熵損失來度量網路預測的目標類別與真實標籤之間的差異,另一個是網路間的互動損失函式,採用KL散度來度量兩個網路預測概率分佈之間的差異。公式表示為

深度互學習-Deep Mutual Learning:三人行必有我師

採用這兩種損失函式,不僅可以使得網路學習到如何區分不同的類別,還能夠使其參考另一個網路的概率估計來提升自身泛化能力。

接下來我們給出網路的優化策略。對於單塊GPU,我們採用交替迭代的方式依次更新兩個網路,當有多塊GPU時,我們可以採用分散式訓練,每次迭代時兩個網路同時計算概率估計差異並更新模型引數。實驗發現分散式訓練可以獲得更好的效能。目前關於分散式訓練為何能比序列訓練獲得更好的效能還未有比較好的理論解釋,一些研究者認為在分散式訓練中每個worker對附近引數空間的探索實際上提高了模型在連續梯度下降方面的統計效能。

我們提出的互學習演算法也很容易擴充套件到多網路學習和半監督學習場景中。當有K個網路時,深度互學習學習每個網路時將其餘K-1個網路分別作為教師來提供學習經驗。另外一種策略是將其餘K-1個網路融合後得到一個教師來提供學習經驗 。在半監督互學習場景中,我們對有標籤的資料計算監督損失和互動損失,而針對無標籤資料我們僅計算互動損失來幫助網路從訓練資料中挖掘更多有用資訊。

3.實驗結果

我們首先在CIFAR-10和CIFAR-100上用不同的網路做了實驗,從表中可以看出,所有不同的網路組合採用深度互學習演算法均可以提升分類準確率,這表明了我們演算法具有較高的靈活性,對網路結構的適應性較強。一般來說小網路從互學習訓練中獲益更多,比如Resnet-32和MobileNet。儘管WRN-28-10網路引數量很大,與其它網路進行互學習訓練依然可以獲得效能提升。因此,不同於模型蒸餾演算法需要預訓練大網路來幫助小網路提升效能,我們提出的深度互學習演算法也可以幫助參與訓練的大網路來提升其效能。

深度互學習-Deep Mutual Learning:三人行必有我師表1 資料集CIFAR-10與CIFAR-100實驗結果我們在ImageNet上也做了實驗,從圖2中可以看出採用互學習訓練均可以提升網路在大規模分類任務上的效能。
深度互學習-Deep Mutual Learning:三人行必有我師圖2 ImageNet實驗結果針對多網路互學習,我們從圖3看出增加網路數量可以提升互學習策略下的單個網路效能,這說明更多教師網路提供了更多學習經驗,幫助網路學習到更好的特徵。另一方面,多網路互學習中多個獨立教師(DML)的效能會優於融合教師(DML_e),這說明多個不同教師網路可以提供更多樣化的學習經驗,更有益於每個網路的學習。
深度互學習-Deep Mutual Learning:三人行必有我師圖3 多網路互學習實驗結果針對半監督學習,從圖4中可以看出,僅採用有標籤資料參與訓練時,深度互學習策略可以提高演算法分類準確率。而當我們將未標記資料加入互學習訓練中,網路的效能可以得到進一步提升,當標記樣本數量較少時,其優勢更明顯。
深度互學習-Deep Mutual Learning:三人行必有我師圖4 半監督深度互學習實驗結果4.作用機制分析

那麼,為什麼互學習機制能起作用呢?為什麼網路從頭開始互學習訓練也能收斂到更好的解而不是被互相拉低?當兩個網路均從頭開始訓練時額外的知識從哪裡來?為什麼約束兩個網路的概率估計相近可以提升泛化能力?經過互學習訓練後兩個網路是不是更相似了?

首先,為什麼網路從頭開始互學習訓練也能收斂到更好的解而不是被互相拉低?直觀解釋如下:每個網路一開始採用隨機初始化,類別概率估計接近於均勻分佈,這使得它們在訓練初期的監督損失較大,互動損失較小,每個網路主要由傳統的監督損失函式引導,這樣可以保證網路的效能在逐漸提升。隨著模型引數更新,每個網路在自己的學習過程中獲得不同的知識,它們對樣本類別的概率估計也會有所不同,這時互動損失開始促進網路互相參考學習經驗。

接下來是最關鍵的問題,為什麼互學習機制起作用?當兩個網路均從頭開始訓練時額外的知識從哪裡來?為什麼約束兩個網路的概率估計相近可以提升泛化能力?我們從三個角度來嘗試理解這些問題。

首先我們認為類別概率估計蘊含了網路挖掘到的資料本質規律。網路的泛化能力越強,則表示網路越有可能挖掘到了資料的內在本質特性,並可以通過類別概率估計表現出來。例如我們希望網路學習區分貓、狗、桌子三個類別,如圖5所示,網路在對貓進行分類時除了要最大化貓的類別概率估計,還會給錯誤類別如狗和桌子分配一定概率,儘管該概率值很低,但我們仍希望分配給狗的概率要大於分配給桌子的概率,即希望網路除了學習到貓的特徵,還能學習到和狗共有的一些特徵,認為貓與狗的類別距離要小於貓與桌子的類別距離。這樣網路在新的測試資料上就更有可能捕捉貓的多種特性,表現出較強的泛化能力。真值標籤提供的資訊僅包含樣本是否屬於某一類,但缺少不同類別之間的聯絡,而網路輸出的類別概率估計則能夠在一定程度上恢復該資訊,因此網路之間進行類別概率估計互動可以傳遞學習到的資料分佈特性,從而幫助網路改善泛化效能。

其次我們認為約束類別概率相近起到正則化作用。深度神經網路在訓練過程中一般採用one-hot-vector方式編碼真實類別分佈,即認為觀測樣本屬於某一類時,其概率值為1,否則為0。InceptionV3論文中認為這種真值標籤編碼會使得模型在訓練過程中對預測結果太過確信,容易導致過擬合,於是提出標籤平滑(Label Smoothing)策略,將正確類的概率分配一些給錯誤類,防止模型把預測值過度集中在較大概率上。Chaudhar等在ICLR2017論文中提出增加熵正則,約束網路預測輸出的概率稍微平滑一點。在互學習演算法中,當我們將網路2的類別概率傳遞給網路1時,本質上也是提供額外的類別先驗約束,防止網路1過度擬合真值標籤的0-1分佈,有效降低過擬合發生概率。然而不一樣的是,標籤平滑和熵正則的類別概率約束是盲目的,而互學習演算法中會有更多類別資訊。

最後,我們認為網路在訓練過程中會參考同伴網路的經驗來調整自己的學習過程,最終能夠收斂到一個更平緩的極小值點,從而具備更好的泛化效能。關於神經網路泛化效能的一些研究認為,儘管深度神經網路可以找到很多解(即網路學習到的引數)使得訓練損失降到零,但一些解能夠比其它解具有更好的泛化效能,其原因在於這些解處於更平緩的極小點,這意味著小的波動不會對網路的預測結果造成劇烈影響。

那麼我們的深度互學習演算法是不是幫助網路找到了一個更平緩的極小點呢?我們進行了實驗驗證,首先我們觀測了兩種訓練策略下網路在訓練資料集上的損失函式變化,從圖(a)可以看出單獨訓練及互學習訓練的網路都可以充分擬合訓練資料,訓練集上的分類準確率都可以達到100%,且訓練損失都可以降到幾乎相同的極小值。這說明深度互學習演算法並沒有幫助網路找一個更深的極小值點來幫助網路在訓練集上實現損失更小,而是有可能找到了一個深度相同但更平緩的極小值點。

深度互學習-Deep Mutual Learning:三人行必有我師圖5 深度互學習作用機制分析為了驗證該猜想,我們對兩種策略訓練好的網路引數新增高斯噪聲,並在圖(b)中比較了新增不同方差高斯噪聲後網路損失函式值的變化。從圖中可以看出,單獨訓練的網路在新增噪聲後損失函式值波動很大,而互學習訓練網路的損失函式值則增加很小。該實驗現象表明深度互學習演算法幫助網路找到了一個更平緩的極小點,針對噪聲具有更強的魯棒性,從而具有更好的泛化效能。

那麼深度互學習是如何幫助網路找到更好的解呢? 我們注意到深度互學習演算法要求一個網路1的概率估計與同伴網路2的概率估計相匹配,網路1在某個類別上估計概率為為零而網路2估計不為零時,就會產生比較大的懲罰。因此當多個網路參與訓練時,每個網路針對樣本估計的概率值會分佈在不同的類別上,監督損失函式會使得網路在第一最大類上產生較大的概率估計,而剩餘的概率值會依次分佈在第二最大類及之後的類別上。當兩個網路類別概率估計在這些第二類別有差異時,KL損失函式會使兩個網路相互妥協,每個網路將分出一些概率值給更接近真值類的第二最大類及之後類別,幫助網路挖掘更多類別資訊來找到更好的解。從圖上可以看出,採用深度互學習演算法可以使得訓練集上類別概率分佈估計更平緩,且不同類別的相對距離也更明顯。

5.結論與展望

我們提出了一個簡單有效的互學習演算法,通過採用兩個網路聯合訓練來提升深度神經網路的泛化效能。該演算法不僅可以用於訓練高效的小網路,也可以進一步提升大網路效能,且容易擴充套件到多網路學習及半監督學習場景中。我們對演算法的作用機制進行了探索分析,嘗試從網路泛化能力和尋找到解的性質來分析深度互學習演算法有效的原因。

程式碼:

https://github.com/YingZhangDUT/Deep-Mutual-Learning

論文:

http://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf

作者簡介:

深度互學習-Deep Mutual Learning:三人行必有我師張瑩,大連理工大學2015級博士生,導師盧湖川教授,研究方向為行人搜尋,包括行人再識別和跨模態行人搜尋。目前已發表論文9篇,其中第一作者論文5篇,包括2篇CVPR,ECCV等。2016年赴博二期間就讀於倫敦瑪麗女王大學進行聯合培養,指導教師為向滔教授,合作導師Timothy M. Hospedales。

相關文章