負對數似然(NLL)和困惑度(PPL)

csjywu1發表於2024-08-29

讓我們透過一個簡單的例子來演示這段程式碼的計算過程,包括負對數似然(NLL)和困惑度(PPL)的計算。為了簡化,我們將假設一個非常小的模型輸出和資料。

假設:

  • 我們有兩個樣本(即 batch size 為 2)。
  • 每個樣本有 3 個可能的類別,S_logits 是模型輸出的 logits。
  • smask 是一個掩碼,假設全部為 True,即我們對所有樣本和所有類別都進行處理。
  • s_batch_id 是一個表示每個樣本的索引的向量,用於 scatter_mean 的計算。

1. 模型輸出的 logits:

假設 r_pred_S_logits 的最後一層輸出如下(為了簡單,假設只有一個時間步長):

import torch

# 假設的 logits
r_pred_S_logits = [torch.tensor([[[2.0, 1.0, 0.1], [1.0, 2.5, 0.5]]])]

# 掩碼
smask = torch.tensor([True, True])

# 批次 ID(假設第一個樣本和第二個樣本)
s_batch_id = torch.tensor([0, 1])

2. 計算 softmax 機率分佈:

首先,對 S_logits 進行 softmax 操作:

S_logits = r_pred_S_logits[-1][0][smask]  # shape: (2, 3)
S_dists = torch.softmax(S_logits, dim=-1)  # shape: (2, 3)
print(S_dists)

這將輸出:

tensor([[0.6590, 0.2424, 0.0986],
        [0.2312, 0.6285, 0.1403]])

每一行是一個樣本的機率分佈。

3. 取樣類別:

然後,從 S_dists 中使用 torch.multinomial 取樣類別:

pred_S = torch.zeros_like(smask, dtype=torch.long)
pred_S[smask] = torch.multinomial(S_dists, num_samples=1).squeeze()
print(pred_S)

假設取樣結果為:

tensor([0, 1])

這意味著第一個樣本預測為類別 0,第二個樣本預測為類別 1。

4. 計算 NLL:

我們從 S_dists 中提取出預測類別的機率,並計算負對數似然(NLL):

S_probs = S_dists[torch.arange(s_batch_id.shape[0]), pred_S[smask]]
print(S_probs)

假設輸出為:

tensor([0.6590, 0.6285])

計算 NLL:

nlls = -torch.log(S_probs)
print(nlls)

輸出:

tensor([0.4170, 0.4642])

這兩個值是每個樣本的 NLL 值。

5. 計算 PPL:

最後,透過 scatter_mean 計算每個 batch 的平均 NLL(在這個簡單的例子中,每個樣本有一個唯一的 ID,所以直接取平均值),然後困惑度(PPL)可以透過取指數得到:

from torch_scatter import scatter_mean

ppl = scatter_mean(nlls, s_batch_id)
print(ppl)

結果:

tensor([0.4170, 0.4642])

困惑度可以透過取指數計算得到:

ppl_value = torch.exp(ppl)
print(ppl_value)

結果:

tensor([1.5172, 1.5901])

總結

在這個簡單的例子中:

  • NLL: 是 0.41700.4642,分別對應兩個樣本的負對數似然。
  • 困惑度(PPL): 透過 exp(NLL) 計算得出 1.51721.5901

這些值反映了模型預測的質量,困惑度越小,表示模型對樣本的預測越好。

負對數似然是怎麼計算的

負對數似然(Negative Log-Likelihood, NLL)是用來衡量模型預測的機率與真實類別之間的差異的一種損失函式。在分類任務中,它可以用於評估模型預測的準確性。NLL 的計算是基於模型輸出的機率(通常透過 softmax 函式生成的機率分佈)和真實類別標籤的。

負對數似然的計算步驟

  1. 模型輸出 logits:

    • 模型通常輸出 logits,也就是未歸一化的分數。這些分數代表了模型對每個類別的信心。
  2. Softmax 函式:

    • 將 logits 轉換為機率分佈。這一步驟使用 softmax 函式,使得每個類別的分數被歸一化為一個機率值,所有類別的機率之和為 1。

    公式如下:

    [
    P(y_i) = \frac{\exp(\text{logit}i)}{\sum^{N} \exp(\text{logit}_j)}
    ]

    其中,( P(y_i) ) 是類別 ( y_i ) 的機率,logit 是模型輸出的原始分數。

  3. 選擇預測類別的機率:

    • 從 softmax 生成的機率分佈中,選擇實際發生的類別(或模型預測的類別)的機率。
  4. 計算負對數似然:

    • 取所選機率的負對數作為負對數似然值(NLL)。

    公式如下:

    [
    \text{NLL} = -\log(P(y_{\text{true}}))
    ]

    其中,( P(y_{\text{true}}) ) 是真實類別的預測機率。

例子

假設我們有一個三分類問題,模型輸出的 logits 為:

import torch

logits = torch.tensor([2.0, 1.0, 0.1])

1. Softmax 計算機率分佈:

probs = torch.softmax(logits, dim=-1)
print(probs)

這將輸出:

tensor([0.6590, 0.2424, 0.0986])

即,類別 0 的機率是 0.6590,類別 1 的機率是 0.2424,類別 2 的機率是 0.0986。

2. 假設真實類別是 0,那麼選擇類別 0 的機率:

P_true = probs[0]
print(P_true)

輸出:

tensor(0.6590)

3. 計算負對數似然:

nll = -torch.log(P_true)
print(nll)

輸出:

tensor(0.4170)

這個值 ( 0.4170 ) 就是類別 0 的負對數似然,它反映了模型對這個類別的預測質量。

總結

  • 負對數似然(NLL) 是模型對某個類別預測機率的負對數。
  • NLL 越小,說明模型對真實類別的預測機率越高,模型的表現越好。
  • NLL 越大,說明模型對真實類別的預測機率越低,模型的表現越差。

相關文章