你真的理解反向傳播嗎?面試必備
王小新 編譯自 Medium
量子位 出品 | 公眾號 QbitAI
當前,訓練機器學習模型的唯一方式是反向傳播演算法。
深度學習框架越來越容易上手,訓練一個模型也只需簡單幾行程式碼。但是,在機器學習面試中,也會考量面試者對機器學習原理的掌握程度。反向傳播問題經常出現,不少人碰到時仍覺得十分棘手。
最近,Medium上的一位機器學習初學者Ryan Gotesman,在學習吳恩達老師廣受歡迎的Machine Learning課程時,遇到一些困難。
Ryan眼中的學習過程是這樣的:線性迴歸,檢查,邏輯迴歸,檢查,梯度下降,檢查,檢查,再檢查……
接著,Ryan又學習了神經網路和模型訓練的相應演算法,即反向傳播。儘管吳老師花了很多時間來解釋它,但他還是無法理解這個技術的原理。
吳老師也在視訊中提到了,即使沒有深入理解這個演算法,也可以在神經網路中使用它,而且他也這麼做了很多年,但是Ryan想找資料來更好地理解這個概念。
因此,Ryan找到了Geoffrey Hinton等人在1986年發表在Nature上的關於反向傳播原始論文“Learning representations by back-propagating errors”,這篇論文到現在已經有近15000次引用。
論文地址:
http://www.iro.umontreal.ca/~pift6266/A06/refs/backprop_old.pdf
這篇文章十分值得認真閱讀,且只有短短的四頁。Ryan在詳細研讀後,對反向傳播有了新的認識並做了一些筆記。
量子位搬運過來幫助大家更好地理解反向傳播,以下為他部落格的譯文:
反向傳播的本質只是對鏈式法則的巧妙運用。
鏈式法則是在本科課程中導數的一個基本屬性。它指出,假設有三個函式f、g和h,其中f是g的函式,g是h的函式,那麼f相對於h的導數等於f相對於g的導數和g相對於h的導數的乘積,用公式表示如下:
△ 鏈式法則
我們要用這個法則來解釋反向傳播的工作原理。
下面,我們使用最簡單的神經網路來說明。這個網路只有3層,分別是藍色的輸入層、綠色的隱藏層和紅色的輸出層。上一層中的每個單元都連線到下一層中的每個單元,而且每個連線都具有一個權重,當某個單元向另一個單元傳遞資訊時,會乘以該連線的權重得到更新資訊。某個單元會把連線到它的上一層所有單元的輸入值相加,並對這個總和執行Logistic函式並向下一層網路傳遞該值。
△ 三層神經網路,單元數分別為3、4和3
假設給了m個訓練樣本,第i個輸入輸出對錶示為:
其中,x和y是3維向量。對於輸入x,我們把g稱作神經網路的預測(輸出)值,它也是一個3維向量,每個向量元素對應一個輸出單元。所以,對於每個訓練樣本來說,有:
△ 網路中輸入值、輸出值和預測值的向量形式
給定輸入x,我們要找到使得預測值g與輸出值y相等或比較相近的一組網路權重。因此,我們加入了誤差函式,定義如下:
△ 神經網路的誤差函式
為了計算總誤差,我們使用了訓練集中的所有樣本,並對紅色輸出層中的每個單元計算該單元預測值與真實輸出間的平方誤差。對每個樣本分別計算並求和,得到總誤差。
由於g為網路的預測值,取決於網路的權重,可以看到總誤差會隨權重變化而變化,網路的訓練目標就是找到一組誤差最小的權重。
我們可以使用梯度下降來做到這一點,但梯度下降方法要求算出總誤差E對每個權重的導數,這也是結合反向傳播要實現的目標。
現在,我們推廣到一般情況,而不是之前的3個輸出單元。假設輸出層有任意數量的輸出單元,設為n,對於這種情況此時的總誤差為:
△ 計算誤差
這裡為了簡潔,刪去了上標i,因為它是不變的。
大家可能有個疑問,這個誤差值是怎麼隨著某個輸出單元的預測值變化而變化的?導數在這裡起了作用:
△ 總誤差相對每個輸出單元的導數,這裡使用鏈式法則得到平方項的導數
我們還可以發現,隨著輸出單元預測值的變化,該誤差會根據預測值與真值間的差值,以同樣速率在變化。
這裡你可能還有疑問,當某個輸出單元的總輸入變化時,誤差會如何變化。這裡只使用了導數。用z來代表某個輸出單元的總輸入,求出下面公式的值:
△ 誤差E相對於第j個輸出單元總輸入的導數
但是,g是關於z的函式,應用鏈式法則,把它重寫為:
△ 應用鏈式法則後的公式
要記住,在每個單元中,先使用Logistic函式處理輸入後再把它向前傳遞。這意味著,g作為Logistic函式,z是它的輸入,所以可以表示為:
△ Logistic函式及其導數
進而得到:
△ 總誤差相對於第j個輸出單元總輸入的導數
這裡已經計算出,總誤差與某個輸出單元總輸入的變化規律。
現在,我們已經得到誤差相對於某個權重的導數,這就是所求的梯度下降法。
設綠色單元的預測值為g’,綠色層中的單元k與紅色層(輸出層)中的單元j之間的連線權重設為:
△ 綠色層中的單元k與紅色層中的單元j之間的連線權重
考慮下圖中,黃色輸出單元對應的總輸入z。為了計算這個總輸入,先獲得每個綠色單元的輸出值,在把其與連線綠色單元和黃色單元的紅色箭頭權重相乘,並將它們全部相加。
△ 紅色箭頭表示為獲得黃色單元的總輸入在節點間新增的連線
進行推廣,假如有任意數量的綠色單元,設為n,這個n與上面定義的不同,可以表示為:
△ 從總輸入到輸出j
所以,我們不僅可以把z看作是自變數為連線權重的函式,也可以看作是自變數為連線單元輸出值的函式。
下面輪到鏈式法則發揮了。
當與輸出單元的連線權重變化時,誤差該如何變化,這表示為:
△ 總誤差相對於輸出單元連線權重的導數
上面已經計算出誤差相對於輸出單元連線權重的導數,這正是梯度下降所需的公式。
但是推導還沒有完成,我們仍需要計算誤差相對於第一層和第二層連線權重的導數,這裡還需要用到鏈式法則。
接下來,計算誤差與第k個綠色單元輸出值的變化關係:
△ 總誤差相對於綠色層中第k個單元輸出值的導數
由於第k個單元有j個連線權重,我們也考慮在內:
△ 總誤差相對於綠色層中第k個單元輸出值的導數
推導到這裡結束,我們得到了總誤差相對於某個單元輸出值的導數。現在,我們可以忽略紅色輸出層,把綠色層作為網路的最後一層,並重覆上述所有步驟來計算總誤差E相對於輸入權重的導數。
你會注意到,我們計算出的第一個導數與預測值和真實值之間的“誤差”相等。同樣地,最終的導數中也是這個誤差項與其他項的乘積。
這種演算法叫做反向傳播,因為我們把這種形式的誤差進行反向傳播,從最後一層反饋到第一層,並被用來計算誤差E相對於網路中每個單元權重的導數。
只要計算出這些導數後,可在梯度下降過程中使用它們來最小化誤差E並訓練神經網路。
希望這篇文章能讓你更好地理解反向傳播的工作原理~
作者系網易新聞·網易號“各有態度”簽約作者
— 完 —
加入社群
量子位AI社群16群開始招募啦,歡迎對AI感興趣的同學,加小助手微信qbitbot6入群;
此外,量子位專業細分群(自動駕駛、CV、NLP、機器學習等)正在招募,面向正在從事相關領域的工程師及研究人員。
進群請加小助手微訊號qbitbot6,並務必備註相應群的關鍵詞~通過稽核後我們將邀請進群。(專業群稽核較嚴,敬請諒解)
誠摯招聘
量子位正在招募編輯/記者,工作地點在北京中關村。期待有才氣、有熱情的同學加入我們!相關細節,請在量子位公眾號(QbitAI)對話介面,回覆“招聘”兩個字。
量子位 QbitAI · 頭條號簽約作者
վ'ᴗ' ի 追蹤AI技術和產品新動態
相關文章
- 你真的理解this嗎
- 你真的理解setState嗎?
- 你真的理解==和===嗎
- 反向傳播演算法的暴力理解反向傳播演算法
- 正向傳播和反向傳播反向傳播
- 你真的理解 getLocationInWindow 了嗎?
- [譯]你真的理解grok嗎
- 你真的理解 new 了嗎?
- 你真的理解Python中的賦值、傳參嗎?Python賦值
- 你真的理解 flex 佈局嗎?Flex
- 三層,你真的理解了嗎?
- 2.反向傳播反向傳播
- 騰訊面試,你真的懂HTTP嗎?面試HTTP
- 你真的理解什麼是死鎖嗎?
- 你真的理解JS的繼承了嗎?JS繼承
- 傳統企業玩網際網路你真的準備好了嗎?
- 【機器學習】李宏毅——何為反向傳播機器學習反向傳播
- CUDA教學(2):反向傳播反向傳播
- Java面試必問面試題,你掌握了嗎?Java面試題
- 高併發,你真的理解透徹了嗎?
- 你真的理解T-sql中的NULL嗎?SQLNull
- 真的沒面試嗎?其實就是你不想去!面試
- 面試官帶你學Android——面試中Handler 這些必備知識點你都知道嗎?面試Android
- 反向傳播演算法(BackPropagation)反向傳播演算法
- 深入理解JVM垃圾收集機制,下次面試你準備好了嗎JVM面試
- 你真的理解函數語言程式設計嗎?函數程式設計
- 你真的理解【函數語言程式設計】嗎?函數程式設計
- 您真的瞭解網站必備的SSL證書嗎?網站
- 面試官:你真的瞭解Redis分散式鎖嗎?面試Redis分散式
- 你真的理解 Webpack?Web
- 你真的理解 Spring Boot 專案中的 parent 嗎?Spring Boot
- C++中的i++和++i你真的理解嗎?C++
- Android效能優化(七)之你真的理解ANR嗎?Android優化
- MVP模式(2):你真的理解下抽象類和介面嗎??MVP模式抽象
- 【DL筆記4】神經網路詳解,正向傳播和反向傳播筆記神經網路反向傳播
- 面試真的很難嗎?面試
- WebView你真的熟悉嗎?WebView
- 你真的知道JS嗎JS