最近在看一些深度學習相關的書,感覺對於參考文獻1中的mini-batch部分理解得不是很透徹,主要是因為神經網路的輸入開始變成批資料,加之對python的numpy不是很熟了。所以總想寫點什麼,一來有助於加深對於知識的理解,二來也算是分享知識咯。
閒話少敘,讓我們進入正題。
在機器學習中,學習的目標是選擇期望風險\(R_{exp}\)(expected loss)最小的模型,但在實際情況下,我們不知道資料的真實分佈(包含已知樣本和訓練樣本),僅知道訓練集上的資料分佈。因此,我們的目標轉化為最小化訓練集上的平均損失,這也被稱為經驗風險\(R_{emp}\)(empirical loss)。
嚴格地說,我們應該計算所有訓練資料的損失函式的總和,以此來更新模型引數(Batch Gradient Descent)。但隨著資料集的不斷增大,以ImagNet資料集為例,該資料集的資料量有百萬之多,計算所有資料的損失函式之和顯然是不現實的。若採用計算單個樣本的損失函式更新引數的方法(Stochastic Gradient Descent),會導致\(R_{emp}\)難以達到最小值,而且在數值處理上不能使用向量化的方法提高運算速度。
於是,我們採取一種折衷的想法,即取一部分資料,作為全部資料的代表,讓神經網路從這每一批資料中學習,這裡的“一部分資料”稱為mini-batch,這種方法稱為mini-batch學習。
以下圖為例,藍色的線表示Batch Gradient Descent,紫色的線表示Stochastic Gradient Descent,綠色的線表示Mini-Batch Gradient Descent。
從上圖可以看出,Mini-Batch相當於結合了Batch Gradient Descent和Stochastic Gradient Descent各自的優點,既能利用向量化方法提高運算速度,又能基本接近全域性最小值。
對於mini-batch學習的介紹到此為止。下面我們將MINIST資料集上的分類問題作為背景,以交叉熵cross-entropy損失函式為例,來實現一下mini-bacth版的cross-entropy error。
給出cross-entropy error的定義如下:
其中\(y_k\)表示神經網路輸出,\(t_k\)表示正確解標籤。
等式1表示的是針對單個資料的損失函式,現在我們給出在mini-batch下的損失函式,如下
其中N表示這一部分資料的數量,\(t_{nk}\)表示第n個資料在第k個元素的值(\(y_{nk}\)表示神經網路輸出,\(t_{nk}\)表示監督資料)
我們來看一下用Python如何實現mini-batch版的cross-entropy error。針對監督資料\(t_{nk}\)的標籤形式是否為one-hot,我們分類討論處理。
此外,需要明確的一點是,對於一個分類神經網路,最後一層經過softmax函式處理後,輸出\(y_{nk}\)是一個\(n\)x\(k\)的矩陣,\(y_{ij}\)表示第i個資料被預測為\(j(0 \leq j\leq10)\)的機率,特別地,當\(N=1\)時,\(y\)是一個包含10個元素的向量,類似於[0.1,0.2...0.3],其中0.1表示輸入資料預測為0的機率為0.1,0.2表示將輸入資料預測為1的機率為0.2,其他情況以此類推。
首先,對於\(t_{nk}\)為one-hot表示的情況,程式碼塊1如下
def cross_entropy_error(y,t):
batch_size = y.shape[0]
return -np.sum(t * np.log(y + 1e-7)) / batch_size
在上面的程式碼中,我們在y上加了一個微小值,防止出現np.log(0)的情況,因為np.log(0)會變成負無窮大-inf,從而導致後續的計算無法繼續進行。在等式2中\(y_{nk}\)與\(t_{nk}\)下標相同,所以我們直接使用*
做element-wise運算,即對應元素相乘。
但當我們希望同時能夠處理單個資料和批次資料時,程式碼塊1還不能滿足我們的要求。因為當\(N=1\)時,\(y\)是一個包含10個元素的一維向量,輸入到函式中,batch_size將等於10而不是1,於是我們將程式碼塊1進行進一步完善,如下:
def cross_entropy_error(y,t):
if y.ndim == 1:
y = y.reshape(1,y.size)
t = t.reshape(1,t.size)
batch_size = y.shape[0]
return -np.sum(t * np.log(y + 1e-7)) / batch_size
最後,來討論一下\(t_{nk}\)為非one-hot表示的情況。在one-hot情況的計算中,t為0的元素cross-entropy error也為0,所以對於這些元素的計算可以忽略。換言之,在非one-hot表示的情況下,我們只需要計算正確解標籤的交叉熵誤差即可。程式碼如下:
def cross_entropy_error(y,t):
if y.ndim == 1:
y = y.reshape(1,y.size)
t = t.reshape(1,t.size)
batch_size = y.shape[0]
return -np.sum(1 * np.log(y[np.arange(batch_size),t]+1e-7))/batch_size
在上面的程式碼中,y[np.arange(batch_size),t]
表示將從神經網路的輸出中抽出與正確解標籤相對應的元素。
參考文獻
[1] 深度學習入門
[2] DeepLearning.ai深度學習課程筆記
[3] 統計學習方法