Paper Information
Title:Fast Multi-Resolution Transformer Fine-tuning for Extreme Multi-label Text Classification
Authors:Jiong Zhang, Wei-Cheng Chang, Hsiang-Fu Yu, I. Dhillon
Sources:2021, ArXiv
Other:3 Citations, 61 References
Paper:download
Code:download
1 背景知識
訓練集 $\left\{\mathbf{x}_{i}, \mathbf{y}_{i}\right\}_{i=1}^{N} $,$\mathbf{x}_{i} \in \mathcal{D}$ 代表著第 $i$ 個文件,$\mathbf{y}_{i} \in\{0,1\}^{L}$ 是第$i$個樣本的第 $\ell$ 個標籤。
eXtreme Multi-label Text Classification (XMC) 目標是尋找一個這樣的函式 $f: \mathcal{D} \times[L] \mapsto \mathbb{R}$,$f(x,\ell)$ 表示輸入 $x$ 與標籤 $\ell$ 之間的相關性。
實際上,得到 $top-k$ 個最大值的索引作為給定輸入 $x$ 的預測相關標籤。最直接的模型是一對全(OVA)模型:
$f(\mathbf{x}, \ell)=\mathbf{w}_{\ell}^{\top} \Phi(\mathbf{x}) ; \ell \in[L]\quad\quad\quad(1)$
其中
-
- $\mathbf{W}=\left[\mathbf{w}_{1}, \ldots, \mathbf{w}_{L}\right] \in \mathbb{R}^{d \times L}$ 是權重向量
- $\Phi(\cdot)$ 是一個文字向量轉換器,$\Phi: \mathcal{D} \mapsto \mathbb{R}^{d}$用於將 $\mathbf{x}$轉換為 $d$ 維特徵向量
為了處理非常大的輸出空間,最近的方法對標籤空間進行了劃分,以篩選在訓練和推理過程中考慮的標籤。特別是 [7, 12, 13, 34, 35, 39] 遵循三個階段的框架:partitioning、shortlisting 和 ranking。
首先 partitioning 過程,將標籤分成 $K$ 個簇 $\mathbf{C} \in\{0,1\}^{L \times K}$ ,$C_{\ell, k}=1$ 代表這標籤 $\ell $ 在第 $k$ 個簇中。
然後 shortlisting 過程,將輸入 $x$ 對映到相關的簇當中:
$g(\mathbf{x}, k)=\hat{\mathbf{w}}_{k}^{\top} \Phi_{g}(\mathbf{x}) ; k \in[K]\quad\quad\quad(2)$
最後 ranking 過程,在 shortlisted 上訓練一個輸出大小為 $L $ 的分類模型:
$f(\mathbf{x}, \ell)=\mathbf{w}_{\ell}^{\top} \Phi(\mathbf{x}) ; \ell \in S_{g}(\mathbf{x})\quad\quad\quad(3)$
其中 $S_{q}(\mathbf{x}) \subset[L]$ 是標籤集的一個子集。
對於基於 transformer 的方法,主要花費的時間是 $\Phi(\mathbf{x})$ 的評價。但是 $K$ 值太大或太小仍然可能會有問題。實證結果表明,當 cluster 的大小 $B$ 太大時,模型的效能會下降。典型的 X-Transformer 和 LightXML ,他們的簇大小$B$ 通常 $B(\leq 100)$ ,聚類數 $K$ 通常為 $K \approx L / B$。
2 XR-Transformer 方法
在 XR-Transformer 中,我們遞迴地對 shortlisting 問題應用相同的三階段框架,直到達到一個相當小的輸出大小 $\frac{L}{B^{D}}$。
2.1 Hierarchical Label Tree (HLT)
遞迴生成標籤簇 $D$ 次,相當於構建一個深度為 $D$ 的 HLT。我們首先構建標籤特徵 $\mathbf{Z} \in \mathbb{R}^{L \times \hat{d}}$。這可以通過在標籤文字上應用文字向量量化器,或者從 Positive Instance Feature Aggregation(PIFA) 中實現:
$\mathbf{Z}_{\ell}=\frac{\mathbf{v}_{\ell}}{\left\|\mathbf{v}_{\ell}\right\|} ; \text { where } \mathbf{v}_{\ell}=\sum\limits _{i: y_{i, \ell}=1} \Phi\left(\mathbf{x}_{i}\right), \forall \ell \in[L]\quad\quad\quad(4)$
其中:$\Phi: \mathcal{D} \mapsto \mathbb{R}^{d}$是文字向量化轉換器。
使用平衡的 k-means($k=B$) 遞迴地劃分標籤集,並以自上而下的方式生成 HLT。
$\left\{\mathbf{C}^{(t)}\right\}_{t=1}^{D}$
其中 $\mathbf{C}^{(t)} \in\{0,1\}^{K_{t} \times K_{t-1}}$ with $K_{0}=1$、$K_{D}=L$
2.2 Multi-resolution Output Space
粗粒度的標籤向量可以通過對原始標籤進行max-pooling得到(在標籤空間中)。第 $t$ 層的真實標籤(偽標籤)為:
$\mathbf{Y}^{(t)}=\operatorname{binarize}\left(\mathbf{Y}^{(t+1)} \mathbf{C}^{(t+1)}\right)\quad\quad\quad(5)$
然而,直接用以上訓練方式會造成資訊損失。直接做max-pooling的方法無法區分:一個cluster中有多個真實標籤和一個cluster中有一個真實標籤。直觀上,前者應該有更高的權重。
因而,通過一個非負的重要性權重指示每個樣本對每個標籤的重要程度:
$\mathbf{R}^{(t)} \in \mathbb{R}_{+}^{N \times K_{t}}$
該重要性權重矩陣通過遞迴方式構建,最底層的重要性權重為原始 標籤歸一化。之後遞迴地將上一層的結果傳遞到下一層。
$\mathbf{R}^{(t)}=\mathbf{R}^{(t+1)} \mathbf{C}^{(t+1)} \quad \quad (6)$
$\mathbf{R}^{(D)}=\mathbf{Y}^{(D)}$
其中:
$\hat{R}_{i, j}^{(t)}=\left\{\begin{array}{ll}\frac{R_{i, j}^{(t)}}{\left\|\mathbf{R}_{i}^{(t)}\right\|_{1}} & \text { if } Y_{i, j}^{(t)}=1 \\ \alpha & \text { otherwise } \end{array}\right.$
2.3 Label Shortlisting
在每一層,不能只關注於少量真實的標籤,還需要關注於一些高置信度的非真實標籤。(因為分類不是100%準確,要給演算法一些容錯度,之後用 beam search 矯正)
在每一層,將模型預測出的 top-k relevant clusters 作為父節點。因而,在第 $t$ 層我們需要考慮 $t-1$ 層的標籤列表。
$\begin{aligned}&\mathbf{P}^{(t-1)} =\operatorname{Top}\left(\mathbf{W}^{(t-1) \top} \Phi\left(\mathbf{X}, \Theta^{(t-1)}\right), k\right)\quad\quad\quad(7)\\&\mathbf{M}^{(t)} =\operatorname{binarize}\left(\mathbf{P}^{(t-1)} \mathbf{C}^{(t) \top}\right)+\operatorname{binarize}\left(\mathbf{Y}^{(t-1)} \mathbf{C}^{(t) \top}\right)\quad\quad\quad(8)\end{aligned}$
2.4 Training with bootstrapping
我們利用遞迴學習結構,通過模型自舉來解決這個問題。
$\mathbf{W}_{i n i t}^{(t)}:=\underset{\mathbf{W}^{(t)}}{\operatorname{argmin}} \sum\limits _{i=1}^{N} \sum\limits_{\ell: \mathbf{M}_{i, \ell}^{(t)} \neq 0} \hat{R}_{i, \ell}^{(t)} \mathcal{L}\left(Y_{i, \ell}^{(t)}, \mathbf{W}_{\ell}^{(t) \top} \Phi_{d n n}\left(\mathbf{x}_{i}, \boldsymbol{\theta}^{(t-1) *}\right)\right)+\lambda\left\|\mathbf{W}^{(t)}\right\|^{2}\quad\quad\quad(11)$
3 Algorithm