TensorFlow: 薛定諤的管道

深度學習大講堂發表於2018-06-13

要說史上最著名的貓,大概就是薛定諤的那隻了。它被關在裝有少量鐳和氰化物的密閉容器裡,當鐳發生衰變時,就會觸發機關打碎裝有氰化物的瓶子,貓就會死亡;如果鐳不發生衰變,貓就會存活下來。在量子力學理論中,由於放射性的鐳處於衰變和沒有衰變兩種狀態的疊加,這隻貓也處於生死疊加態,只有對其進行觀測,才能決定這隻貓的生死。

所以,哈姆雷特說:

生存還是死亡,這是一個問題!

今天我們從貓說起,來討論一個管道,一個蘊含著某些不確定性的管道:TensorFlow。我們從TensorFlow中的一個計算例項出發,在這個例子中的一個計算節點像薛定諤的貓一樣具備不確定性的輸出結果:如下圖所示的例子,同時計算節點assign和節點c時,c的計算結果out_c是不確定的,這是因為TensorFlow會盡可能的對計算過程並行化,所以out_c的結果依賴assignc誰先執行。

TensorFlow: 薛定諤的管道

那麼面對這樣一個可能會產生“薛定諤現象”的框架,我們如何利用它來實現模型呢?

我們在使用TensorFlow這個軟體庫構建模型時,實際上是在TensorFlow提供的這套api系統裡編寫TF程式,這裡可以把TensorFlow看成是一門進行數值計算的“程式語言”。那麼為了更好的掌握TF這麼“語言”,我們可以從三個層次來學習:

  • 理解TensorFlow的基本概念和基本元件

  • 基於對基本概念的理解,利用基本元件來構建模型

  • 除錯模型,優化模型速度,優化模型精度

這裡我們從理解基本概念以及一個除錯模型的例子出發,來介紹其實現模型計算的過程。

TensorFlow: 薛定諤的管道

TensorFlow的核心是通過資料流圖的方式來實現數值計算,這裡最核心的概念就是資料流圖,TensorFlow是以靜態圖(這裡不強調其動態圖的特性)的方式來表達計算,那麼一旦計算模型以圖的方式表達完成,就要通過Session來驅動計算,整體示意如上圖所示,節點和邊構成了你的計算模型,而實際計算時資料(Tensor)沿著圖的邊被驅動著進行計算從而流動起來,這也形象的表示了TensorFlow=Tensor+Flow。因此TensorFlow程式就可以分為兩個階段:

  • 階段一:組裝一個計算圖,這裡只是用TFapi來表達計算模型,生成的是一個靜態圖,圖由計算的節點以及節點之間的連線表示,這個階段只是靜態的表示了計算,因此得不到任何實際的計算值。

  • 階段二:通過一個Session(會話)來執行計算,這裡可以計算某個節點,而這個節點所依賴的父節點都會被驅動先行執行。

比如我們想從下圖所示的資料(X,Y)中學習一個線性關係y=w*x+b

TensorFlow: 薛定諤的管道

對於這樣一個機器學習任務,一般分為測試過程和訓練過程,測試過程一般比較簡單,這裡我們介紹如何使用TensorFlow來實現訓練過程,對於機器學習模型的訓練過程的一般可以如下面流程圖所示:

TensorFlow: 薛定諤的管道

具體步驟為:

1.    定義輸入和輸出標籤

2.    定義模型引數

3.    初始化模型引數

4.    基於輸入和模型引數,由模型的推理過程計算模型的預測結果

5.    基於模型的預測結果和標籤值,由損失函式來計算loss

6.    優化器通過更新引數來最小化loss

7.    不斷重複4-6直到迭代次數達到或者loss低於設定的閾值

使用TensorFlow來完成以上計算時,我們需要:

1)  使用靜態圖的方式表達上面的計算過程(對應階段一)

2)  使用Session(會話)來驅動上面的計算(對應階段二)

可以如下面程式碼所示, 

TensorFlow: 薛定諤的管道

TensorFlow: 薛定諤的管道

所有我們想要進行的計算都需要在階段一進行表達,如我們需要進行模型初始化這樣一個計算過程,那麼我們需要在階段一構造一個init操作節點,我們需要最小化loss,更新模型引數,我們可以構造一個train_op操作節點,每個計算對應計算圖中的一個計算節點,一旦計算圖構建完成,我們就可以在階段二過程通過執行這個節點來進行實際的計算如sess.run(init),sess.run(train_op)。由此通過階段一階段二兩部分程式完成我們想要的計算邏輯,學習到的線性模型如下圖所示。

TensorFlow: 薛定諤的管道

正是因為TF這種graphsession兩階段的劃分,導致我們在除錯TF的時候也會分為兩個階段:

1.    錯誤發生在組裝圖部分。這裡TF會進行型別檢測,以及shape推理,所以一般dtypeshape相關的錯誤會與這一部分程式碼相關。

2.    錯誤發生在執行圖部分。這裡TF會進行執行時的計算,所以NaN等問題會發生在這個階段。

相關文章