論文資訊
論文標題:GraphSMOTE: Imbalanced Node Classification on Graphs with Graph Neural Networks
論文作者:Tianxiang Zhao, Xiang Zhang, Suhang Wang
論文來源:2021, WSDM
論文地址:download
論文程式碼:download
1 Introduction
節點分類受限與不同類的節點數量不平衡,本文提出過取樣方法解決這個問題。
圖中類不平衡的例子:
圖中:每個藍色節點表示一個真實使用者,每個紅色節點表示一個假使用者,邊表示關係。任務是預測未標記的使用者(虛線圈)是真實的還是假的。這些類在本質上是不平衡的,因為假使用者通常還不到所有使用者的1%。半監督設定進一步放大了類不平衡問題,因為我們只給出了有限的標記資料,這使得標記的少數樣本的數量非常小。
在不平衡的節點分類中,多數類主導著損失函式,使得訓練後的 GNNs 對這些類過度分類,無法準確預測樣本。
目前解決類不平衡問題的方法可以分為:
-
- data-level approaches(資料級方法)
- algorithm-level approaches(演算法級方法)
- hybrid approaches(混合方法)
資料級方法尋求使類分佈更加平衡,使用過取樣(over-sampling)或降取樣(down-sampling)技術 [8,26];演算法級方法通常對不同的類[11,22,44]引入不同的錯誤分類懲罰或先驗概率;混合方法[9,23]將這兩者結合起來。
以前的演算法並不容易適用於圖。
-
- 首先,對產生的新樣本,很難生成邊關係。主流過取樣技術[26]利用目標示例與其最近鄰之間的插值來生成新的訓練示例。然而,插值不適合於邊,因為它們通常是離散的和稀疏的,插值可以破壞拓撲結構。
- 第二,產生的新樣本可能質量較低。節點屬性是高維的,直接對其進行插值很容易生成域外的例子,對訓練分類器不利。
2 Related work:Class Imbalance Problem
2.1 data-level method
通過過取樣和下采樣的方法調整類結構。
過取樣的一般形式是直接複製現有樣本,帶來的問題是沒有引入額外資訊,容易導致過擬合問題。
SMOTE[8] 通過生成新樣本來解決這個問題,在少數類和最近鄰的樣本之間執行插值,在此基礎上的方法:
-
- Borderline-SMOTE [15] 將過取樣限制在類邊界附近的樣本;
- Safe-Level-SMOTE [7]使用多數類鄰居計算每個插值的安全方向,以使生成的新樣本更安全;
- Cluster-based Over-sampling [17] 首先將樣本聚為不同的組,而不是單獨的過樣本,考慮到輸入空間中經常存在小區域;
下采樣丟棄多數類中的一些樣本,也可使類保持平衡,但代價是丟失一些資訊。為此,提出只刪除冗餘的樣本,如 [3,20]。
2.2 hybrid method
Cost sensitive learning [22,44] 通常構造一個成本矩陣,為不同的類分配不同的錯誤分類懲罰。效果類似於普通的的過取樣。[28] 提出了一種近似於 $F$ 測量的方法,它可以通過梯度傳播直接進行優化。
2.3 algorithm-level method
它結合了來自上述一個或兩個類別的多個演算法。[23] 使用一組分類器,每個分類器都在多數類和少數類的一個子集上進行訓練。[9] 結合了 boosting 與SMOTE,[16] 結合了過取樣與成本敏感學習。[33] 引入了三種成本敏感的增強方法,它們迭代地更新每個類的影響以及 AdaBoost引數。
3 Problem definition
在本文中,我們使用 $\mathcal{G}=\{\mathcal{V}, \mathrm{A}, \mathrm{F}\}$ 來表示一個屬性網路,其中 $\mathcal{V}=\left\{v_{1}, \ldots, v_{n}\right\}$ 是一組 $n$ 節點。$\mathrm{A} \in \mathbb{R}^{n \times n}$ 為 $\mathcal{G}$ 的鄰接矩陣, $\mathrm{F} \in \mathbb{R}^{n \times d}$ 表示節點屬性矩陣,其中 $\mathrm{F}[j,:] \in \mathbb{R}^{1 \times d}$ 為節點 $j$ 的節點屬性,$?$ 為節點屬性的維數。$\mathrm{Y} \in \mathbb{R}^{n}$ 是 $\mathcal{G}$ 中節點的類資訊。
在訓練過程中,只有 $Y$ 的一個子集 $\mathcal{V}_{L}$ 可用,其中包含節點子集 $\mathcal{V}_{L}$ 的標籤。總共有 $m$ 類,$\left\{C_{1}, \ldots, C_{m}\right\} $。$\left|C_{i}\right|$ 是第 $i$ 類的大小,指的是屬於該類的樣本數量。我們使用不平衡率 $\frac{\min _{i}\left(\left|C_{i}\right|\right)}{\max _{i}\left(\left|C_{i}\right|\right)}$ 來衡量類不平衡的程度。在不平衡設定下,$\mathrm{Y}_{L}$ 的不平衡比較小。
給定節點類不平衡的 $\mathcal{G}$,以及節點 $\mathcal{V}_{L}$子集的標籤,目標是學習一個節點分類器 $f$,可適用於多數類和少數類,即:
$f(\mathcal{V}, \mathbf{A}, \mathbf{F}) \rightarrow \mathbf{Y}\quad\quad\quad(1)$
4 Method
框架如下:
GraphSMOTE 的組成部分:
-
- a GNNbased feature extractor;
- Synthetic Node Generation;
- Edge Generator;
- GNN Classifier;
4.1 Feature Extractor
SMOTE 用於原始節點特徵空間,帶來的問題是:
-
- 原始特徵空間可能是稀疏和高維的,且特徵空間不好;
- 未考慮圖的結構,可能會導致次優的合成節點;
本文提出的過取樣方法同時考慮了節點表示和拓撲結構,並且遵循了同質性假設。本文研究中使用 GraphSage 作為主幹模型結構來提取特徵:
$\mathbf{h}_{v}^{1}=\sigma\left(\mathbf{W}^{1} \cdot \operatorname{CONCAT}(\mathbf{F}[v,:], \mathbf{F} \cdot \mathbf{A}[:, v])\right)\quad\quad\quad(2)$
其中,
-
- $F$ 表示輸入特徵矩陣,$\mathbf{F}[v,:]$ 表示節點 $v$ 的特徵;
- $\mathrm{A}[:, v]$ 為鄰接矩陣中的第 $v$ 列;
- $\mathbf{h}_{v}^{1}$ 為節點 $v$ 的嵌入;
- $\mathbf{W}^{1}$ 為權值引數;
- $\sigma$ 為 ReLU等啟用函式;
4.2 Synthetic Node Generation
對少數類節點採用 SMOTE 演算法思想:對目標少數類的樣本與嵌入空間中屬於同一類的最近鄰樣本進行插值。
設 $\mathbf{h}_{v}^{1}$ 為一個帶標記的少數類節點,標記為 $Y_{v}$。第一步是找到與 $\mathbf{h}_{v}^{1}$ 在同一個類中的最近的標記節點,即:
$n n(v)=\underset{u}{\operatorname{argmin}}\left\|\mathbf{h}_{u}^{1}-\mathbf{h}_{v}^{1}\right\|, \quad \text { s.t. } \quad \mathbf{Y}_{u}=\mathbf{Y}_{v}\quad\quad\quad(3)$
其中,$n n(v)$ 是指同一類中 $v$ 的最近鄰,可以生成合成節點為:
$\mathbf{h}_{v^{\prime}}^{1}=(1-\delta) \cdot \mathbf{h}_{v}^{1}+\delta \cdot \mathbf{h}_{n n(v)}^{1}\quad\quad\quad(4)$
其中,$\delta$ 為一個隨機變數,在 $[0,1]$ 範圍內呈均勻分佈。由於 $\mathbf{h}_{v}^{1}$ 和 $\mathbf{h}_{n n(v)}^{1}$ 屬於同一個類,且非常接近,因此生成的合成節點 $\mathbf{h}_{v^{\prime}}^{1}$ 也應屬同一個類。
4.3 Edge Generator
邊生成器是一個加權內積:
$\mathbf{E}_{v, u}=\operatorname{softmax}\left(\sigma\left(\mathbf{h}_{v}^{1} \cdot \mathbf{S} \cdot \mathbf{h}_{u}^{1}\right)\right)\quad\quad\quad(5)$
其中,$\mathbf{E}_{v, u}$ 為節點 $v$ 和 $u$ 之間的預測關係資訊,$\mathrm{S}$ 為捕獲節點間相互作用的引數矩陣。
邊生成器的損失函式為:
$\mathcal{L}_{e d g e}=\|\mathbf{E}-\mathbf{A}\|_{F}^{2}\quad\quad\quad(6)$
此時,並沒有合成節點,而是學習一個好的引數矩陣 $S$ 。利用邊生成器,本文嘗試了兩種策略:
第一種,該生成器只使用邊重建來進行優化,而合成節點 $v^{\prime}$ 的邊是通過設定一個閾值 $\eta$ 生成:
$\tilde{\mathrm{A}}\left[v^{\prime}, u\right]=\left\{\begin{array}{ll}1, & \text { if } \mathbf{E}_{v^{\prime}, u}>\eta \\0, & \text { otherwise }\end{array}\right.$
其中,$\tilde{\mathrm{A}}$ 是過取樣後的鄰接矩陣,通過在 $A$ 中插入新的節點和邊,並將其傳送給分類器。
第二種,對於合成節點 $v^{\prime}$,使用軟邊而不是二進位制邊:
$\tilde{\mathbf{A}}\left[v^{\prime}, u\right]=\mathbf{E}_{v^{\prime}, u}$
在這種情況下,$\tilde{A}$ 上的梯度可以從分類器中傳播,因此可以同時使用邊緣預測損失和節點分類損失對生成器進行優化。
4.4 GNN Classifier
設 $\tilde{\mathbf{H}}^{1}$ 為將 $\mathbf{H}^{1}$ 與合成節點的嵌入連線起來的增廣節點表示集,$\tilde{V}_{L}$ 為將合成節點合併到 $\tilde{V}_{L}$ 中的增廣標記集。
對於當前的增強圖 $\tilde{\mathcal{G}}= \{\tilde{\mathrm{A}}, \tilde{\mathbf{H}}\}$ 與標記節點集 $\tilde{V}_{L}$。在 $\tilde{G}$ 中,不同類的資料大小變得平衡,並且一個無偏的GNN分類器將可以在這上面進行訓練。
本文采用另一個 GraphSage 塊,在 $\tilde{G}$ 上附加一個線性層進行節點分類,如下:
$\mathbf{h}_{v}^{2}=\sigma\left(\mathbf{W}^{2} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{1}, \tilde{\mathbf{H}}^{1} \cdot \tilde{\mathbf{A}}[:, v]\right)\right)\quad\quad\quad(9)$
$\mathbf{P}_{v}=\operatorname{softmax}\left(\sigma\left(\mathbf{W}^{c} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{2}, \mathbf{H}^{2} \cdot \tilde{\mathbf{A}}[:, v]\right)\right)\right)\quad\quad\quad(10)$
式中,$\mathbf{P}_{v}$ 是節點 $v$ 在類標籤上的概率分佈,利用交叉熵損失進行優化,如下:
$\mathcal{L}_{n o d e}=\sum\limits _{u \in \tilde{V}_{L}} \sum\limits_{c}\left(1\left(Y_{u}==c\right) \cdot \log \left(\mathrm{P}_{v}[c]\right)\right.\quad\quad\quad(11)$
在測試過程中,將節點 $v$ 的預測類設定為概率最高的類 $\mathrm{Y}_{v}^{\prime}$。
$\mathbf{Y}_{v}^{\prime}=\underset{c}{\operatorname{argmax}} \mathbf{P}_{v, c}\quad\quad\quad(12)$
4.5 Optimization Objective
GraphSMOTE 的最終目標函式可以寫成:
$\underset{\theta, \phi, \varphi}{\text{min }} \mathcal{L}_{\text {node }}+\lambda \cdot \mathcal{L}_{e d g e}\quad\quad\quad(13)$
其中,$ \theta,$、$\phi$、$\varphi$ 分別為特徵提取器、邊緣生成器和節點分類器的引數。由於模型的效能依賴於嵌入空間和生成的邊的質量,為了使訓練短語更加穩定,我們還嘗試了使用 $\mathcal{L}_{e d g e}$ 進行訓練前的特徵提取器和邊生成器。
4.6 Training Algorithm
完整演算法如 Algorithm 1 :
5 Experiment
資料集
baseline
- Over-sampling:重複少數節點,$n_s$ 代表重複的少數節點數量,在每次訓練迭代中,$\mathcal{V}$ 被過取樣以包含 $n+n_{s}$ 節點,和 $\mathrm{A} \in \mathbb{R}^{\left(n+n_{s}\right) \times\left(n+n_{s}\right)}$。
- Re-weight [41]:一種成本敏感的方法,給少數樣本分配較高的損失權重,以緩解多數類主導損失函式的問題。
- SMOTE [8]: 合成少數過取樣技術通過插值一個少數樣本及其同類的最近鄰來生成合成少數樣本。對於新生成的節點,將其邊被設定為與目標節點相同。
- Embed-SMOTE [1]:SMOTE 的擴充套件,在中間嵌入層而不是輸入域執行過取樣。我們將其設定為最後一個GNN層的輸出,因此不需要生成邊。
- $\text { GraphSMOTE }_{T}$:邊生成器僅使用邊預測任務中的損失進行訓練。
- $\text { GraphSMOTE }_{O}$: 預測邊緣被設定為連續,以便從基於gnn的分類器計算和傳播梯度。利用邊緣生成任務和節點分類任務的訓練訊號,將邊緣生成器與其他元件一起進行訓練;
- $\text { GraphSMOTE }_{preT}$:是 $\text { GraphSMOTE }_{T}$ 的擴充套件,對特徵提取器和邊緣生成器進行預訓練,然後對 $\text{Eq.13}$ 進行微調。 在微調過程中,邊緣生成器的優化僅使用 $\mathcal{L}_{\text {edges }}$ ;
- $\text { GraphSMOTE }_{preO}$:是 $\text { GraphSMOTE }_{O}$ 的擴充套件,一個訓練前的過程也會在微調之前進行,比如 $\text { GraphSMOTE }_{preT}$。
不平衡的節點分類
過取樣量的影響
設定不平平衡率為 $0.5$,過取樣的尺度為 $\{0.2,0.4,0.6,0.8,1.0,1.2\}$。
不平衡比的影響
設定過取樣的尺度為 $1$ ,不平衡率為 $\{0.1,0.2,0.4,0.6\}$。
基礎模型的影響
基礎模型一個採用 GCN,一個採用 GraphSAGE。
引數敏感性分析
6 Conclusion
圖中節點的類不平衡問題廣泛存在於現實世界的任務中,如假使用者檢測、網頁分類、惡意機器檢測等。這個問題會顯著影響分類器在這些少數類上的效能,但在以前的工作中沒有被考慮。因此,在本工作中,我們研究了這個不平衡的節點分類任務。具體來說,我們提出了一個新的框架GraphSMOTE,它將以前的i.i.d資料的過取樣演算法擴充套件到這個圖設定。具體地說,GraphSMOTE構造了一個具有特徵提取器的中間嵌入空間,並在此基礎上同時訓練一個邊緣生成器和一個基於gnn的節點分類器。在一個人工資料集和兩個真實資料集上進行的實驗證明了它的有效性,即大幅度地優於所有其他基線。進行消融研究是為了瞭解GraphSMOTE在各種場景下的效能。並進行了引數敏感性分析,以瞭解GraphSMOTE對超引數的敏感性。