揭祕:人工智慧深度神經網路的4種精簡除錯方法

人工智慧老李發表於2019-04-20

當你花了幾個星期構建一個資料集、編碼一個神經網路並訓練好了模型,然後發現結果並不理想,接下來你會怎麼做?

深度學習通常被視為一個黑盒子,我並不反對這種觀點——但是你能講清楚學到的上萬引數的意義嗎?

但是黑盒子的觀點為機器學習從業者指出了一個明顯的問題:你如何除錯模型?

在這篇文章中,我將會介紹一些我們在 Cardiogram 中除錯 DeepHeart 時用到的技術,DeepHeart 是使用來自 Apple Watch、 Garmin、和 WearOS 的資料預測疾病的深度神經網路。

在 Cardiogram 中,我們認為構建 DNN 並不是鍊金術,而是工程學。

揭祕:人工智慧深度神經網路的4種精簡除錯方法

你的心臟暴露了很多你的資訊。DeepHeart 使用來自 Apple Watch、 Garmin、和 WearOS 的心率資料來預測你患糖尿病、高血壓以及睡眠窒息症(sleep apnea)的風險。

一、預測合成輸出 通過預測根據輸入資料構建的合成輸出任務來測試模型能力。

我們在構建檢測睡眠窒息症的模型時使用了這個技術。現有關於睡眠窒息症篩查的文獻使用日間和夜間心率標準差的差異作為篩查機制。因此我們為每週的輸入資料建立了合成輸出任務:

標準差 (日間心率)—標準差 (夜間心率)

為了學習這個函式,模型要能夠:

  1. 區分白天和黑夜

  2. 記住過去幾天的資料

這兩個都是預測睡眠窒息症的先決條件,所以我們使用新架構進行實驗的第一步就是檢查它是否能學習這個合成任務。

你也可以通過在合成任務上預訓練網路,以半監督的形式來使用類似這樣的合成任務。當標記資料很稀缺,而你手頭有大量未標記資料時,這種方法很有用。

二、視覺化啟用值 理解一個訓練好的模型的內部機制是很難的。你如何理解成千上萬的矩陣乘法呢?

在這篇優秀的 Distill 文章《Four Experiments in Handwriting with a Neural Network》中,作者通過在熱圖中繪製單元啟用值,分析了手寫模型。我們發現這是一個「開啟 DNN 引擎蓋」的好方法。

我們檢查了網路中幾個層的啟用值,希望能夠發現一些語義屬性,例如,當使用者在睡覺、工作或者焦慮時,啟用的單元是怎樣的?

用 Keras 寫的從模型中提取啟用值的程式碼很簡單。下面的程式碼片段建立了一個 Keras 函式 last_output_fn,該函式在給定一些輸入資料的情況下,能夠獲得一層的輸出(即它的啟用值)。

from keras import backend as K def extract_layer_output(model, layer_name, input_data): layer_output_fn = K.function([model.layers[0].input], [model.get_layer(layer_name).output]) layer_output = layer_output_fn([input_data])

layer_output.shape is (num_units, num_timesteps)

return layer_output[0] 我們視覺化了網路好幾層的啟用值。在檢查第二個卷積層(一個寬為 128 的時間卷積層)的啟用值時,我們注意到了一些奇怪的事:

揭祕:人工智慧深度神經網路的4種精簡除錯方法

卷積層的每個單元在每個時間步長上的啟用值。藍色的陰影代表的是啟用值。

啟用值竟然不是隨著時間變化的!它們不受輸入值影響,被稱為「死神經元」。

揭祕:人工智慧深度神經網路的4種精簡除錯方法

ReLU 啟用函式,f(x) = max(0, x)

這個架構使用了 ReLU 啟用函式,當輸入是負數的時候它輸出的是 0。儘管它是這個神經網路中比較淺的層,但是這確實是實際發生的事情。

在訓練的某些時候,較大的梯度會把某一層的所有偏置項都變成負數,使得 ReLU 函式的輸入是很小的負數。因此這層的輸出就會全部為 0,因為對小於 0 的輸入來說,ReLU 的梯度為零,這個問題無法通過梯度下降來解決。

當一個卷積層的輸出全部為零時,後續層的單元就會輸出其偏置項的值。這就是這個層每個單元輸出一個不同值的原因——因為它們的偏置項不同。

我們通過用 Leaky ReLU 替換 ReLU 解決了這個問題,前者允許梯度傳播,即使輸入為負時。

我們沒想到會在此次分析中發現「死神經元」,但最難找到的錯誤是你沒打算找的。

三、梯度分析 梯度的作用當然不止是優化損失函式。在梯度下降中,我們計算與Δparameter 對應的Δloss。儘管通常意義上梯度計算的是改變一個變數對另一個變數的影響。由於梯度計算在梯度下降方法中是必需的,所以像 TensorFlow 這樣的框架都提供了計算梯度的函式。

我們使用梯度分析來確定我們的深度神經網路能否捕捉資料中的長期依賴。DNN 的輸入資料特別長:4096 個時間步長的心率或者計步資料。我們的模型架構能否捕捉資料中的長期依賴非常重要。例如,心率的恢復時間可以預測糖尿病。這就是鍛鍊後恢復至休息時的心率所耗的時間。為了計算它,深度神經網路必須能夠計算出你休息時的心率,並記住你結束鍛鍊的時間。

衡量模型能否追蹤長期依賴的一種簡單方法是去檢查輸入資料的每個時間步長對輸出預測的影響。如果後面的時間步長具有特別大的影響,則說明模型沒有有效地利用早期資料。

對於所有時間步長 t,我們想要計算的梯度是與Δinput_t 對應的Δoutput。下面是用 Keras 和 TensorFlow 計算這個梯度的程式碼示例:

def gradient_output_wrt_input(model, data):

[:, 2048, 0] means all users in batch, midpoint timestep, 0th task (diabetes)

output_tensor = model.model.get_layer('raw_output').output[:, 2048, 0]

output_tensor.shape == (num_users)

Average output over all users. Result is a scalar.

output_tensor_sum = tf.reduce_mean(output_tensor) inputs = model.model.inputs # (num_users x num_timesteps x num_input_channels) gradient_tensors = tf.gradients(output_tensor_sum, inputs)

gradient_tensors.shape == (num_users x num_timesteps x num_input_channels)

Average over users

gradient_tensors = tf.reduce_mean(gradient_tensors, axis=0)

gradient_tensors.shape == (num_timesteps x num_input_channels)

eg gradient_tensor[10, 0] is deriv of last output wrt 10th input heart rate

Convert to Keras function

k_gradients = K.function(inputs=inputs, outputs=gradient_tensors)

Apply function to dataset

return k_gradients([data.X]) 在上面的程式碼中,我們在平均池化之前,在中點時間步長 2048 處計算了輸出。我們之所以使用中點而不是最後的時間步長的原因是,我們的 LSTM 單元是雙向的,這意味著對一半的單元來說,4095 實際上是第一個時間步長。我們將得到的梯度進行了視覺化:

揭祕:人工智慧深度神經網路的4種精簡除錯方法

Δoutput_2048 / Δinput_t

請注意我們的 y 軸是 log 尺度的。在時間步長 2048 處,與輸入對應的輸出梯度是 0.001。但是在時間步長 2500 處,對應的梯度小了一百萬倍!通過梯度分析,我們發現這個架構無法捕捉長期依賴。

四、分析模型預測

你可能已經通過觀察像 AUROC 和平均絕對誤差這樣的指標分析了模型預測。你還可以用更多的分析來理解模型的行為。

例如,我們好奇 DNN 是否真的用心率輸入來生成預測,或者說它的學習是不是嚴重依賴於所提供的後設資料——我們用性別、年齡這樣的使用者後設資料來初始化 LSTM 的狀態。為了理解這個,我們將模型與在後設資料上訓練的 logistic 迴歸模型做了對比。

DNN 模型接收了一週的使用者資料,所以在下面的散點圖中,每個點代表的是一個使用者周。

揭祕:人工智慧深度神經網路的4種精簡除錯方法

這幅圖驗證了我們的猜想,因為預測結果並不是高度相關的。

除了進行彙總分析,檢視最好和最壞的樣本也是很有啟發性的。對一個二分類任務而言,你需要檢視最令人震驚的假陽性和假陰性(也就是預測距離標籤最遠的情況)。嘗試鑑別損失模式,然後過濾掉在你的真陽性和真陰性中出現的這種模式。

一旦你對損失模式有了假設,就通過分層分析進行測試。例如,如果最高損失全部來自第一代 Apple Watch,我們可以用第一代 Apple Watch 計算我們的調優集中使用者集的準確率指標,並將這些指標與在剩餘調優集上計算的指標進行比較。

原文連結:blog.cardiogr.am/4-ways-to-d…

相關文章