Unlearn What You Want to Forget Efficient Unlearning for LLMs

馒头and花卷發表於2024-06-02

目錄
  • 符號說明
  • Unlearning Layers
    • Fusing Unlearning Layers
  • 程式碼

Chen J. and Yang D. Unlearn what you want to forget: efficient unlearning for llms. 2024.

本文提出一種 Unlearning layer 去幫助 LLMs '遺忘' 一些資料.

符號說明

  • \(F(\cdot)\), large language model (LLM):
  • \(F'(\cdot)\), updated model;
  • \(D = \{(x, y)\}\), training dataset;
  • \(D^f = \{(x^f, y^f)\}\), data to be forgot;
  • \(D^r = D - D^f = \{(x^r, y^r)\}\).

Unlearning Layers

  • 作者希望透過微調 Unlearning Layer \(f\) 來使得模型能夠忘掉資料 \(D^f\), 如上圖所示, 就是載入每個 block 中, 它的結構式一個簡單的線性層.

  • 為了達到這個目標作者首先引入 KL 散度:

    \[L_{KL} = \alpha \sum_{x^r} KL(F(x^r) \| F'(x^r)) -\sum_{x^f} KL(F(x^f) \| F'(x^f)), \]

    即對於一般的資料點, \(F'\) 的輸出要和原來的 \(F\) 靠近, 對於需要遺忘的資料點, 則需要和原來的資料點原理 (難道遠離就是遺忘嗎? 我感覺比較均勻分佈會不會更好一點?)

  • 其次為了保證下游任務的效能, 引入 task loss:

    \[L_{Task} = \sum_{x^r} l(F'(x), y^r). \]

  • 最後是 LM 的預訓練損失, 確保 LM 本身也忘掉資料 \(D^f\),

    \[L_{LM} = -\sum_{x^f} l(F'(X^f)). \]

  • 最後總的損失為:

    \[L_{EUL} = L_{KL} + \lambda L_{TASK} + \gamma L_{LM}. \]

Fusing Unlearning Layers

  • 作者還討論了, 假如我們依次遺忘了 \(m\) 次資料, 即有 \(f_1, f_2, \ldots, f_m\), 如何將這些 unlearning layers 綜合起來呢? 作者選擇求解如下的 \(W\):

    \[\min_{W_m} \sum_i \|W_m^T X_i^f - W_i^T X_i^f\|^2, \]

    它有顯式解如下:

    \[W_m = (\sum_i {X_i^f}^T X_i^f)^{-1} \sum_i ({X_i^f}^T X_i^f W_i). \]

程式碼

[official-code]

相關文章