論文解讀(Graph-MLP)《Graph-MLP: Node Classification without Message Passing in Graph》

發表於2022-04-02

論文資訊

論文標題:Graph-MLP: Node Classification without Message Passing in Graph
論文作者:Yang Hu, Haoxuan You, Zhecan Wang, Zhicheng Wang,Erjin Zhou, Yue Gao
論文來源:2021, ArXiv
論文地址:download 
論文程式碼:download

1 介紹

  本文工作:

    不使用基於訊息傳遞模組的GNNs,取而代之的是使用Graph-MLP:一個僅在計算損失時考慮結構資訊的MLP。

  任務:節點分類。在這個任務中,將由標記和未標記節點組成的圖輸入到一個模型中,輸出是未標記節點的預測。

2 方法

2.1 GNN 框架

  普通的 GNN 框架:

    $\mathbf{X}^{(l+1)}=\sigma\left(\widehat{A} \mathbf{X}^{(l)} W^{(l)}\right)\quad\quad\quad(1)$  

    $\widehat{A}=\mathbf{D}^{-\frac{1}{2}}(A+I) \mathbf{D}^{-\frac{1}{2}}\quad\quad\quad(2)$

2.2 Graph-MLP

  整體框架如下:

  

2.2.1 MLP-based Structure

  結構: linear-activation-layer normalization-dropout-linear-linear

  即:

    $\begin{array}{c} \mathbf{X}^{(1)}=\text { Dropout }\left(L N\left(\sigma\left(\mathbf{X} W^{0}\right)\right)\right) \quad\quad\quad(3)\\ \mathbf{Z}=\mathbf{X}^{(1)} W^{1} \quad\quad\quad(4)\\ \mathbf{Y}=\mathbf{Z} W^{2}\quad\quad\quad(5) \end{array}$

  其中:$Z$ 用於 NConterast 損失,$ Y$ 用於分類損失。

2.2.2 Neighbouring Contrastive Loss

  在 NContast 損失中,認為每個節點的 $\text{r-hop}$ 鄰居為正樣本,其他節點為負樣本。這種損失鼓勵正樣本更接近目標節點,並根據特徵距離推動負樣本遠離目標節點。取樣 $B$ 個鄰居,第 $i$ 個節點的 NContrast loss 可以表述為:

    ${\large \ell_{i}=-\log \frac{\sum\limits _{j=1}^{B} \mathbf{1}_{[j \neq i]} \gamma_{i j} \exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}, \boldsymbol{z}_{j}\right) / \tau\right)}{\sum\limits _{k=1}^{B} \mathbf{1}_{[k \neq i]} \exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}, \boldsymbol{z}_{k}\right) / \tau\right)}} \quad\quad\quad(6)$

  其中:$\gamma_{i j} $ 表示節點 $i$ 和節點 $j$ 之間的連線強度,這裡定義為 $\gamma_{i j}=\widehat{A}_{i j}^{r}$。

  $\gamma_{i j}$ 為非 $0$ 值當且僅當結點 $j$ 是結點 $i$ 的 $r$  跳鄰居,即: 

    $\gamma_{i j}\left\{\begin{array}{ll}=0, & \text { node } j \text { is the } r \text {-hop neighbor of node } i \\\neq 0, & \text { node } j \text { is not the } r \text {-hop neighbor of node } i \end{array}\right.$

  總 NContrast loss 為 $loss_{NC}$,而分類損失採用的是傳統的交叉熵(用 $loss_{CE}$ 表示 ),因此上述 Graph-MLP 的總損失函式如下:

    $\begin{aligned}\operatorname{loss}_{NC} &=\alpha \frac{1}{B} \sum\limits _{i=1}^{B} \ell_{i}\quad\quad\quad(7)\\\text { loss }_{\text {final }} &=\operatorname{loss}_{C E}+\operatorname{loss}_{N C}\quad\quad\quad(8) \end{aligned}$

2.2.3 Training

  整個模型以端到端的方式進行訓練。【端到端的學習正規化:整個學習的流程並不進行人為的子問題劃分,而是完全交給深度學習模型直接學習從原始資料到期望輸出的對映 】

  $\text{Graph-MLP}$ 模型不需要使用鄰接矩陣,在計算訓練期間的損失時只參考圖結構資訊。

  在每個 $batch$ 中,我們隨機抽取 $B$ 個節點並取相應的鄰接資訊 $\widehat{A} \in \mathbb{R}^{B \times B}$ 和節點特徵 $\mathbf{X} \in R^{\mathbb{R} \times d}$。對於某些節點 $i$,由於 $batch$ 抽樣的隨機性,可能會發生 $batch$ 中沒有 $\text{positive samples}$。在這種情況下,將刪除節點 $i$ 的損失。本文模型對 $\text{positive samples}$ 和  $\text{negative samples}$ 的比例是穩健的,而沒有特別調整的比例。

  演算法如  Algorithm 1 所示:

  

2.2.4 Inference

  在推斷過程中,傳統的圖模型如 GNN 同時需要鄰接矩陣和節點特徵作為輸入。不同的是,我們基於MLP的方法只需要節點特徵作為輸入。因此,當鄰接資訊被損壞或丟失時,Graph-MLP仍然可以提供一致可靠的結果。在傳統的圖建模中,圖資訊被嵌入到輸入的鄰接矩陣中。對於這些模型,圖節點轉換的學習嚴重依賴於內部訊息傳遞,而內部訊息傳遞對每個鄰接矩陣輸入中的連線都很敏感。然而,我們對圖形結構的監督是應用於損失水平的。因此,我們的框架能夠在節點特徵轉換過程中學習一個圖結構的分佈,而不需要進行前饋訊息傳遞。這使得我們的模型在推理過程中對特定連線的敏感性較低。

3 實驗

3.1 資料集

  

3.2 對引文網路節點分類資料集的效能

  

3.3 Graph-MLP 與 GNN 的效率

  

3.4 關於超引數的消融術研究

  

3.5 嵌入的視覺化 

  

3.6 魯棒性

  為了證明Graph-MLP在缺失連線下進行推斷仍具有良好的魯棒性,作者在測試過程中的鄰接矩陣中新增了噪聲,缺失連線的鄰接矩陣的計算公式如下:

    $A_{\text {corr }}=A \otimes  mask  +(1-  mask  ) \otimes \mathbb{N} \quad\quad\quad(9)$

    $\operatorname{mask}\left\{\begin{array}{ll} =1, & p=1-\delta \\ =0, & p=\delta \end{array}\right.\quad\quad\quad(10)$

  其中  $\delta$  表示缺失率,$mask  \in n \times n$  決定鄰接矩陣中缺失的位置,$mask$ 中的元素取  $1 / 0$  的概率為  $1-\delta / \delta$ 。 $\mathbb{N} \in n \times n$  中的元素取  $1 / 0$  的 概率都為  $0.5$  。

  

  結論:從上圖可以看出隨著缺失率的增加,GCN的推斷效能急劇下降,而Graph-MLP卻基本不受影響。

 

相關文章