LLM可解釋性的未來希望?稀疏自編碼器是如何工作的,這裡有一份直觀說明

机器之心發表於2024-08-05

簡而言之:矩陣 → ReLU 啟用 → 矩陣

在解釋機器學習模型方面,稀疏自編碼器(SAE)是一種越來越常用的工具(雖然 SAE 在 1997 年左右就已經問世了)。

機器學習模型和 LLM 正變得越來越強大、越來越有用,但它們仍舊是黑箱,我們並不理解它們完成任務的方式。理解它們的工作方式應當大有助益。

SAE 可幫助我們將模型的計算分解成可以理解的元件。近日,LLM 可解釋性研究者 Adam Karvonen 釋出了一篇部落格文章,直觀地解釋了 SAE 的工作方式。

可解釋性的難題

神經網路最自然的元件是各個神經元。不幸的是,單個神經元並不能便捷地與單個概念相對應,比如學術引用、英語對話、HTTP 請求和韓語文字。在神經網路中,概念是透過神經元的組合表示的,這被稱為疊加(superposition)。

之所以會這樣,是因為世界上很多變數天然就是稀疏的。

舉個例子,某位名人的出生地可能出現在不到十億分之一的訓練 token 中,但現代 LLM 依然能學到這一事實以及有關這個世界的大量其它知識。訓練資料中單個事實和概念的數量多於模型中神經元的數量,這可能就是疊加出現的原因。

近段時間,稀疏自編碼器(SAE)技術越來越常被用於將神經網路分解成可理解的元件。SAE 的設計靈感來自神經科學領域的稀疏編碼假設。現在,SAE 已成為解讀人工神經網路方面最有潛力的工具之一。SAE 與標準自編碼器類似。

常規自編碼器是一種用於壓縮並重建輸入資料的神經網路

舉個例子,如果輸入是一個 100 維的向量(包含 100 個數值的列表);自編碼器首先會讓該輸入透過一個編碼器層,讓其被壓縮成一個 50 維的向量,然後將這個壓縮後的編碼表示饋送給解碼器,得到 100 維的輸出向量。其重建過程通常並不完美,因為壓縮過程會讓重建任務變得非常困難。

圖片

一個標準自編碼器的示意圖,其有 1x4 的輸入向量、1x2 的中間狀態向量和 1x4 的輸出向量。單元格的顏色表示啟用值。輸出是輸入的不完美重建結果。

解釋稀疏自編碼器

稀疏自編碼器的工作方式

稀疏自編碼器會將輸入向量轉換成中間向量,該中間向量的維度可能高於、等於或低於輸入的維度。在用於 LLM 時,中間向量的維度通常高於輸入。在這種情況下,如果不加額外的約束條件,那麼該任務就很簡單,SAE 可以使用單位矩陣來完美地重建出輸入,不會出現任何意料之外的東西。但我們會新增約束條件,其中之一是為訓練損失新增稀疏度懲罰,這會促使 SAE 建立稀疏的中間向量。

舉個例子,我們可以將 100 維的輸入擴充套件成 200 維的已編碼表徵向量,並且我們可以訓練 SAE 使其在已編碼表徵中僅有大約 20 個非零元素。

圖片

稀疏自編碼器示意圖。請注意,中間啟用是稀疏的,僅有 2 個非零值。

我們將 SAE 用於神經網路內的中間啟用,而神經網路可能包含許多層。在前向透過過程中,每一層中和每一層之間都有中間啟用。

舉個例子,GPT-3 有 96 層。在前向透過過程中,輸入中的每個 token 都有一個 12,288 維向量(一個包含 12,288 個數值的列表)。此向量會累積模型在每一層處理時用於預測下一 token 的所有資訊,但它並不透明,讓人難以理解其中究竟包含什麼資訊。

我們可以使用 SAE 來理解這種中間啟用。SAE 基本上就是「矩陣 → ReLU 啟用 → 矩陣」。

舉個例子,如果 GPT-3 SAE 的擴充套件因子為 4,其輸入啟用有 12,288 維,則其 SAE 編碼的表徵有 49,512 維(12,288 x 4)。第一個矩陣是形狀為 (12,288, 49,512) 的編碼器矩陣,第二個矩陣是形狀為 (49,512, 12,288) 的解碼器矩陣。透過讓 GPT 的啟用與編碼器相乘並使用 ReLU,可以得到 49,512 維的 SAE 編碼的稀疏表徵,因為 SAE 的損失函式會促使實現稀疏性。

通常來說,我們的目標讓 SAE 的表徵中非零值的數量少於 100 個。透過將 SAE 的表徵與解碼器相乘,可得到一個 12,288 維的重建的模型啟用。這個重建結果並不能與原始的 GPT 啟用完美匹配,因為稀疏性約束條件會讓完美匹配難以實現。

一般來說,一個 SAE 僅用於模型中的一個位置舉個例子,我們可以在 26 和 27 層之間的中間啟用上訓練一個 SAE。為了分析 GPT-3 的全部 96 層的輸出中包含的資訊,可以訓練 96 個分立的 SAE—— 每層的輸出都有一個。如果我們也想分析每一層內各種不同的中間啟用,那就需要數百個 SAE。為了獲取這些 SAE 的訓練資料,需要向這個 GPT 模型輸入大量不同的文字,然後收集每個選定位置的中間啟用。

下面提供了一個 SAE 的 PyTorch 參考實現。其中的變數帶有形狀註釋,這個點子來自 Noam Shazeer,參見:https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd 。請注意,為了儘可能地提升效能,不同的 SAE 實現往往會有不同的偏置項、歸一化方案或初始化方案。最常見的一種附加項是某種對解碼器向量範數的約束。更多細節請訪問以下實現:

  • OpenAI:https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/model.py#L16

  • SAELens:https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L97

  • dictionary_learning:https://github.com/saprmarks/dictionary_learning/blob/main/dictionary.py#L30

import torch
import torch.nn as nn

# D = d_model, F = dictionary_size
# e.g. if d_model = 12288 and dictionary_size = 49152
# then model_activations_D.shape = (12288,) and encoder_DF.weight.shape = (12288, 49152)

class SparseAutoEncoder (nn.Module):
    """
    A one-layer autoencoder.
    """
    def __init__(self, activation_dim: int, dict_size: int):
        super ().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size

        self.encoder_DF = nn.Linear (activation_dim, dict_size, bias=True)
        self.decoder_FD = nn.Linear (dict_size, activation_dim, bias=True)

    def encode (self, model_activations_D: torch.Tensor) -> torch.Tensor:
        return nn.ReLU ()(self.encoder_DF (model_activations_D))

    def decode (self, encoded_representation_F: torch.Tensor) -> torch.Tensor:
        return self.decoder_FD (encoded_representation_F)

    def forward_pass (self, model_activations_D: torch.Tensor) -> tuple [torch.Tensor, torch.Tensor]:
        encoded_representation_F = self.encode (model_activations_D)
        reconstructed_model_activations_D = self.decode (encoded_representation_F)
        return reconstructed_model_activations_D, encoded_representation_F

標準自編碼器的損失函式基於輸入重建結果的準確度。為了引入稀疏性,最直接的方法是向 SAE 的損失函式新增一個稀疏度懲罰項。對於這個懲罰項,最常見的計算方式是取這個 SAE 的已編碼表徵(而非 SAE 權重)的 L1 損失並將其乘以一個 L1 係數。這個 L1 係數是 SAE 訓練中的一個關鍵超引數,因為它可確定實現稀疏度與維持重建準確度之間的權衡。

請注意,這裡並沒有針對可解釋性進行最佳化。相反,可解釋的 SAE 特徵是最佳化稀疏度和重建的一個附帶效果。下面是一個參考損失函式

# B = batch size, D = d_model, F = dictionary_size
def calculate_loss (autoencoder: SparseAutoEncoder, model_activations_BD: torch.Tensor, l1_coeffient: float) -> torch.Tensor:
    reconstructed_model_activations_BD, encoded_representation_BF = autoencoder.forward_pass (model_activations_BD)
    reconstruction_error_BD = (reconstructed_model_activations_BD - model_activations_BD).pow (2)
    reconstruction_error_B = einops.reduce (reconstruction_error_BD, 'B D -> B', 'sum')
    l2_loss = reconstruction_error_B.mean ()

    l1_loss = l1_coefficient * encoded_representation_BF.sum ()
    loss = l2_loss + l1_loss
    return loss

圖片

稀疏自編碼器的前向透過示意圖。

這是稀疏自編碼器的單次前向透過過程。首先是 1x4 大小的模型向量。然後將其乘以一個 4x8 的編碼器矩陣,得到一個 1x8 的已編碼向量,然後應用 ReLU 將負值變成零。這個編碼後的向量就是稀疏的。之後,再讓其乘以一個 8x4 的解碼器矩陣,得到一個 1x4 的不完美重建的模型啟用。

假想的 SAE 特徵演示

理想情況下,SAE 表徵中的每個有效數值都對應於某個可理解的元件。

這裡假設一個案例進行說明。假設一個 12,288 維向量 [1.5, 0.2, -1.2, ...] 在 GPT-3 看來是表示「Golden Retriever」(金毛犬)。SAE 是一個形狀為 (49,512, 12,288) 的矩陣,但我們也可以將其看作是 49,512 個向量的集合,其中每個向量的形狀都是 (1, 12,288)。如果該 SAE 解碼器的 317 向量學習到了與 GPT-3 那一樣的「Golden Retriever」概念,那麼該解碼器向量大致也等於 [1.5, 0.2, -1.2, ...]。

無論何時 SAE 的啟用的 317 元素是非零的,那麼對應於「Golden Retriever」的向量(並根據 317 元素的幅度)會被新增到重建啟用中。用機械可解釋性的術語來說,這可以簡潔地描述為「解碼器向量對應於殘差流空間中特徵的線性表徵」。

也可以說有 49,512 維的已編碼表徵的 SAE 有 49,512 個特徵。特徵由對應的編碼器和解碼器向量構成。編碼器向量的作用是檢測模型的內部概念,同時最小化其它概念的干擾,儘管解碼器向量的作用是表示「真實的」特徵方向。研究者的實驗發現,每個特徵的編碼器和解碼器特徵是不一樣的,並且餘弦相似度的中位數為 0.5。在下圖中,三個紅框對應於單個特徵。

圖片

稀疏自編碼器示意圖,其中三個紅框對應於 SAE 特徵 1,綠框對應於特徵 4。每個特徵都有一個 1x4 的編碼器向量、1x1 的特徵啟用和 1x4 的解碼器向量。重建的啟用的構建僅使用了來自 SAE 特徵 1 和 4 的解碼器向量。如果紅框表示「紅顏色」,綠框表示「球」,那麼該模型可能表示「紅球」。

那麼我們該如何得知假設的特徵 317 表示什麼呢?目前而言,人們的實踐方法是尋找能最大程度啟用特徵並對它們的可解釋性給出直覺反應的輸入。能讓每個特徵啟用的輸入通常是可解釋的。

舉個例子,Anthropic 在 Claude Sonnet 上訓練了 SAE,結果發現:與金門大橋、神經科學和熱門旅遊景點相關的文字和影像會啟用不同的 SAE 特徵。其它一些特徵會被並不顯而易見的概念啟用,比如在 Pythia 上訓練的一個 SAE 的一個特徵會被這樣的概念啟用,即「用於修飾句子主語的關係從句或介詞短語的最終 token」。

由於 SAE 解碼器向量的形狀與 LLM 的中間啟用一樣,因此可簡單地透過將解碼器向量加入到模型啟用來執行因果乾預。透過讓該解碼器向量乘以一個擴充套件因子,可以調整這種干預的強度。當 Anthropic 研究者將「金門大橋」SAE 解碼器向量新增到 Claude 的啟用時,Claude 會被迫在每個響應中都提及「金門大橋」。

下面是使用假設的特徵 317 得到的因果乾預的參考實現。類似於「金門大橋」Claude,這種非常簡單的干預會迫使 GPT-3 模型在每個響應中都提及「金毛犬」。

def perform_intervention (model_activations_D: torch.Tensor, decoder_FD: torch.Tensor, scale: float) -> torch.Tensor:
    intervention_vector_D = decoder_FD [317, :]
    scaled_intervention_vector_D = intervention_vector_D * scale
    modified_model_activations_D = model_activations_D + scaled_intervention_vector_D
    return modified_model_activations_D

稀疏自編碼器的評估難題

使用 SAE 的一大主要難題是評估。我們可以訓練稀疏自編碼器來解釋語言模型,但我們沒有自然語言表示的可度量的底層 ground truth。目前而言,評估都很主觀,基本也就是「我們研究一系列特徵的啟用輸入,然後憑直覺闡述這些特徵的可解釋性。」這是可解釋性領域的主要限制。

研究者已經發現了一些似乎與特徵可解釋性相對應的常見代理指標。最常用的是 L0 和 Loss Recovered。L0 是 SAE 的已編碼中間表徵中非零元素的平均數量。Loss Recovered 是使用重建的啟用替換 GPT 的原始啟用,並測量不完美重建結果的額外損失。這兩個指標通常需要權衡考慮,因為 SAE 可能會為了提升稀疏性而選擇一個會導致重建準確度下降的解。

在比較 SAE 時,一種常用方法是繪製這兩個變數的圖表,然後檢查它們之間的權衡。為了實現更好的權衡,許多新的 SAE 方法(如 DeepMind 的 Gated SAE 和 OpenAI 的 TopK SAE)對稀疏度懲罰做了修改。下圖來自 DeepMind 的 Gated SAE 論文。Gated SAE 由紅線表示,位於圖中左上方,這表明其在這種權衡上表現更好。

圖片

Gated SAE L0 與 Loss Recovered

SAE 的度量存在多個難度層級。L0 和 Loss Recovered 是兩個代理指標。但是,在訓練時我們並不會使用它們,因為 L0 不可微分,而在 SAE 訓練期間計算 Loss Recovered 的計算成本非常高。相反,我們的訓練損失由一個 L1 懲罰項和重建內部啟用的準確度決定,而非其對下游損失的影響。

訓練損失函式並不與代理指標直接對應,並且代理指標只是對特徵可解釋性的主觀評估的代理。由於我們的真正目標是「瞭解模型的工作方式」,主觀可解釋性評估只是代理,因此還會有另一層不匹配。LLM 中的一些重要概念可能並不容易解釋,而且我們可能會在盲目最佳化可解釋性時忽視這些概念。

總結

可解釋性領域還有很長的路要走,但 SAE 是真正的進步。SAE 能實現有趣的新應用,比如一種用於查詢「金門大橋」導向向量(steering vector)這樣的導向向量的無監督方法。SAE 也能幫助我們更輕鬆地查詢語言模型中的迴路,這或可用於移除模型內部不必要的偏置。

SAE 能找到可解釋的特徵(即便目標僅僅是識別啟用中的模式),這一事實說明它們能夠揭示一些有意義的東西。還有證據表明 LLM 確實能學習到一些有意義的東西,而不僅僅是記憶表層的統計規律。

SAE 也能代表 Anthropic 等公司曾引以為目標的早期里程碑,即「用於機器學習模型的 MRI(磁共振成像)」。SAE 目前還不能提供完美的理解能力,但卻可用於檢測不良行為。SAE 和 SAE 評估的主要挑戰並非不可克服,並且現在已有很多研究者在攻堅這一課題。

有關稀疏自編碼器的進一步介紹,可參閱 Callum McDougal 的 Colab 筆記本:https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab

參考連結:

https://www.reddit.com/r/MachineLearning/comments/1eeihdl/d_an_intuitive_explanation_of_sparse_autoencoders/

https://adamkarvonen.github.io/machine_learning/2024/06/11/sae-intuitions.html

相關文章