Java開發者的神經網路進階指南:深入探討交叉熵損失函式

努力的小雨發表於2024-06-24

前言

今天來講一下損失函式——交叉熵函式,什麼是損失函式呢?大體就是真實與預測之間的差異,這個交叉熵(Cross Entropy)是Shannon資訊理論中一個重要概念,主要用於度量兩個機率分佈間的差異性資訊。在資訊理論中,交叉熵是表示兩個機率分佈 p,q 的差異,其中 p 表示真實分佈,q 表示預測分佈,那麼 \(H(p,q)\)就稱為交叉熵:

\(H(p,q) = -\sum_{i=0}^n p(i)ln^{q(i)}\)

交叉熵是一種常用的損失函式,特別適用於神經網路訓練中。在這種函式中,我們用 p 來表示真實標記的分佈,用 q 來表示經過訓練後模型預測的標記分佈。透過交叉熵損失函式,我們可以有效地衡量模型預測分佈 q 與真實分佈 p 之間的相似性。

交叉熵函式是邏輯迴歸(即分類問題)中常用的一種損失函式。

前置知識

有些同學和我一樣,長時間沒有接觸數學,已經完全忘記了。除了基本的加減乘除之外,對於交叉熵函式中的一些基本概念,他們可能只記得和符號。今天我會和大家一起回顧一下,然後再詳細解釋交叉熵函式。首先,我們來簡單瞭解一下指數和對數的基本概念。

指數

\(x^3\) 是一個典型的立方函式,大家對平方和立方可能都有所瞭解。指數級增長的函式具有特定的增長規律,讓我們更深入地記憶和理解它們的分佈特性。

image

這個概念非常簡單,無需舉例子來說明。重要的是要記住一個關鍵點:指數函式的一個特殊性質是它們都經過點(0,1),這意味著任何數的0次冪都等於1。

對數

好的,鋪墊已經完成了。現在讓我們繼續探討對數函式的概念。前面講解了指數函式,對數函式則是指數函式的逆運算。如果有一個指數函式表示式為\(y = a^x\),那麼它的對數表示式就是\(x = \log_a y\)。為了方便表示,我們通常將左側的結果記為\(y\),右側的未知函式記為\(x\),因此對數函式最終表示為\(y = \log_a x\)。為了更加深刻地記憶這一點,讓我們看一下它的分佈圖例。

image

當討論指數函式時,我們瞭解到其影像在( (0,1) ) 處穿過橫軸。然而,當我們轉而討論對數函式時,其表示形式導致了這一點被調換至( (1,0) ),因此對於對數函式而言,它的恆過點即為( (1,0) )。

剩下關於對數的變換我就不再詳細講解了。現在讓我們深入探討一下熵的概念。

交叉熵函式

在探討交叉熵之前,我們先來了解一下熵的概念。熵是根據已知的實際機率計算資訊量的度量,那麼資訊量又是什麼呢?

資訊理論中,資訊量的表示方式:\(I(x_j) = -ln^{(px_j)}\)

\(x_j\):表示一個事件。

\(px_j\):表示一個事件發生的機率。

\(-ln^{(px_j)}\):表示某一個事件發生後會有多大的資訊量,機率越低,所發生的資訊量也就越大。

這裡為了更好地說明,我來舉個例子。比如說有些人非常喜歡追星。那麼,按照一般的邏輯來說,我們可以談談明星結婚這件事的機率分佈:

事件編號 事件 機率p 資訊量 I
\(x_1\) 兩口子都在為事業奮鬥照顧家庭 0.7 \(I(x_1) = -ln^{0.7}= 0.36\)
\(x_2\) 兩口子吵架 0.2 \(I(x_2) = -ln^{0.2}= 1.61\)
\(x_3\) 離婚了 0.1 \(I(x_3) = -ln^{0.1}= 2.30\)

從上面的例子可以看出,如果一個事件的機率很低,那麼它所帶來的資訊量就會很大。比如,某某明星又離婚了!這個訊息的資訊量就非常大。相比之下,“奮鬥”事件的資訊量就顯得小多了。

按照熵的公式進行計算,那麼這個故事的熵即為:

熵:\(H(p) = -\sum_j^n(px_j)ln^{(px_j)}\)

計算得出:\(H(p) = -[(px_1)ln^{(px_1)}+(px_2)ln^{(px_2)}+(px_3)ln^{(px_3)}] = -[0.7*0.36+0.2*1.61+0.1*2.3] = 0.804\)

相對熵(KL散度)

上面我們討論了熵的概念及其應用,熵僅考慮了真實機率分佈。然而,我們的損失函式需要考慮真實機率分佈與預測機率分佈之間的差異。因此,我們需要進一步研究相對熵(KL散度),其計算公式為:

\(H(p) = \sum_j^n(px_j)ln^{(px_j) \over (qx_j)}\)

哎,這其實就是在原先的公式中加了一個\(q(x_j)\)而已。對了,這裡的\(q(x_j)\)指的是加上了預測機率分佈\(q\)。我們知道對數函式的對稱點是(1,0)。因此,很容易推斷出,當真實分佈\(p\)和預測分佈\(q\)越接近時,KL散度\(D\)的值就越小。當它們完全相等時,KL散度恆為0,即在點(1,0)。這樣一來,我們就能夠準確地衡量真實值與預測值之間的差異分佈了。但是沒有任何一個損失函式是能為0 的。

當談到相對熵已經足夠時,為何需要進一步討論交叉熵呢?讓我們繼續深入探討這個問題。

交叉熵

重頭戲來了,我們繼續看下相對熵函式的表示式:\(H(p) = \sum_j^n(px_j)ln^{(px_j) \over (qx_j)}\)

這裡注意下,\(log^{p \over q}\)是可以變換的,也就是說\(log^{p \over q}\) = \(log^p -log^ q\),這麼說,相對熵轉換後的公式就是:$H(p) = \sum_jn(px_j)ln - \sum_jn(px_j)ln = -H(p) + H(p,q) $

當我們考慮到\(H(p)\)在處理不同分佈時並沒有太大作用時,這是因為\(p\)的熵始終保持不變,它是由真實的機率分佈計算得出的。因此,損失函式只需專注於後半部分\(H(p,q)\)即可。

所以最終的交叉熵函式為:\(-\sum_j^n(px_j)ln^{(qx_j)}\)

這裡需要注意的是,上面顯示的是一個樣本計算出的多個機率的熵值。通常情況下,我們考慮的是多個樣本,而不僅僅是單一樣本。因此,我們需要在前面新增樣本的數量,最終表示為:\(-\sum_i^m\sum_j^n(px_j)ln^{(qx_j)}\)

程式碼實現

這裡主要使用Python程式碼來實現,因為其他語言實現起來沒有必要。好的,讓我們來看一下程式碼示例:

import numpy as np

def cross_entropy(y_true, y_pred):
    # 用了一個最小值
    epsilon = 1e-15
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    
    # Computing cross entropy
    ce = - np.sum(y_true * np.log(y_pred))
    return ce

# Example usage:
y_true = np.array([1, 0, 1])
y_pred = np.array([0.9, 0.1, 0.8])

ce = cross_entropy(y_true, y_pred)
print(f'Cross Entropy: {ce}')

這裡需要解釋一下為什麼要使用一個最小值。因為對數函式的特性是,其引數 ( x ) 可以無限接近於0,但不能等於0。因此,如果引數等於0,就會導致對數函式計算時出現錯誤或無窮大的情況。為了避免這種情況,我們選擇使用一個足夠小的最小值作為閾值,以確保計算的穩定性和正確性。

總結

在本文中,我們深入探討了交叉熵函式作為一種重要的損失函式,特別適用於神經網路訓練中。交叉熵透過衡量真實標籤分佈與模型預測分佈之間的差異,幫助最佳化模型的效能。我們從資訊理論的角度解釋了交叉熵的概念,它是基於Shannon資訊理論中的熵而來,用於度量兩個機率分佈之間的差異。

在討論中,我們還回顧了指數和對數函式的基本概念,這些函式在交叉熵的定義和理解中起著重要作用。指數函式展示了指數級增長的特性,而對數函式則是其逆運算,用於計算相對熵和交叉熵函式中的對數項。

進一步探討了熵的概念及其在資訊理論中的應用,以及相對熵(KL散度)作為衡量兩個機率分佈差異的指標。最後,我們詳細介紹了交叉熵函式的定義和實際應用,以及在Python中的簡單實現方式。

透過本文,希望讀者能夠對交叉熵函式有一個更加深入的理解,並在實際應用中運用此知識來最佳化和改進機器學習模型的訓練效果。

相關文章