目錄
- 概
- 符號說明
- Unlearning Layers
- Fusing Unlearning Layers
- 程式碼
Chen J. and Yang D. Unlearn what you want to forget: efficient unlearning for llms. 2024.
概
本文提出一種 Unlearning layer 去幫助 LLMs '遺忘' 一些資料.
符號說明
- \(F(\cdot)\), large language model (LLM):
- \(F'(\cdot)\), updated model;
- \(D = \{(x, y)\}\), training dataset;
- \(D^f = \{(x^f, y^f)\}\), data to be forgot;
- \(D^r = D - D^f = \{(x^r, y^r)\}\).
Unlearning Layers
-
作者希望透過微調 Unlearning Layer \(f\) 來使得模型能夠忘掉資料 \(D^f\), 如上圖所示, 就是載入每個 block 中, 它的結構式一個簡單的線性層.
-
為了達到這個目標作者首先引入 KL 散度:
\[L_{KL} = \alpha \sum_{x^r} KL(F(x^r) \| F'(x^r)) -\sum_{x^f} KL(F(x^f) \| F'(x^f)), \]即對於一般的資料點, \(F'\) 的輸出要和原來的 \(F\) 靠近, 對於需要遺忘的資料點, 則需要和原來的資料點原理 (難道遠離就是遺忘嗎? 我感覺比較均勻分佈會不會更好一點?)
-
其次為了保證下游任務的效能, 引入 task loss:
\[L_{Task} = \sum_{x^r} l(F'(x), y^r). \] -
最後是 LM 的預訓練損失, 確保 LM 本身也忘掉資料 \(D^f\),
\[L_{LM} = -\sum_{x^f} l(F'(X^f)). \] -
最後總的損失為:
\[L_{EUL} = L_{KL} + \lambda L_{TASK} + \gamma L_{LM}. \]
Fusing Unlearning Layers
- 作者還討論了, 假如我們依次遺忘了 \(m\) 次資料, 即有 \(f_1, f_2, \ldots, f_m\), 如何將這些 unlearning layers 綜合起來呢? 作者選擇求解如下的 \(W\):\[\min_{W_m} \sum_i \|W_m^T X_i^f - W_i^T X_i^f\|^2, \]它有顯式解如下:\[W_m = (\sum_i {X_i^f}^T X_i^f)^{-1} \sum_i ({X_i^f}^T X_i^f W_i). \]
程式碼
[official-code]