論文解讀(S^3-CL)《Structural and Semantic Contrastive Learning for Self-supervised Node Representation Learning》

發表於2022-04-09

論文資訊

論文標題:Structural and Semantic Contrastive Learning for Self-supervised Node Representation Learning
論文作者: Kaize Ding 、Yancheng Wang 、Yingzhen Yang、Huan Liu
論文來源:2022, arXiv
論文地址:download
論文程式碼:download

1 摘要

  Graph Contrastive Learning (GCL) 研究方向:
    • encoding architecture
    • augmentation
    • contrastive objective

2 介紹

  現存兩階段對比學習框架存在的問題:

    • Shallow Encoding Architecture  
    • Arbitrary Augmentation Design  
    • Semanticless Contrastive Objective  

  定義1:

    • 自監督節點表示學習: 給定一個屬性圖$\mathcal{G}=   (\mathbf{X}, \mathbf{A})$ ,目標是學習一個圖編碼器  $f_{\boldsymbol{\theta}}: \mathbb{R}^{N \times D} \times   \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{N \times D^{\prime}}$ ,不使用標籤資訊,這樣生成的節點表示 $\mathbf{H} \in \mathbb{R}^{N \times D^{\prime}}=f_{\boldsymbol{\theta}}(\mathbf{X}, \mathbf{A})$ 可以用於不同的下游任務。

  本文提出的方法: Simple Neural Networks with Structural and Semantic Contrastive Leanring($S^3-CL$),包括三個部分:

    • An encoder network
    • A structural contrastive learning module
    • A semantic contrastive learning module

  框架如下所示:

  

3 方法

3.1 Structural Contrastive Learning

  在無監督表示學習中,對比學習方法將每個樣本視為一個不同的類,旨在實現例項區分。以類似的方式,現有的GCL方法通過最大化不同增強檢視中相同圖元素的表示之間的一致性來實現節點級識別。

3.1.1 Structure Augmentation via Graph Diffusion

  Step1 :Graph Diffusion

    Graph Diffusion :

    $\mathbf{S}=\sum\limits _{l=0}^{\infty} \theta_{l} \mathbf{T}^{l} \in \mathbb{R}^{N \times N}\quad\quad\quad(1)$

  其中:

    • $\mathbf{T} \in \mathbb{R}^{N \times N}$ 是轉移概率矩陣,$\mathbf{T} = \tilde{\mathbf{A}}_{s y m}=\tilde{\mathbf{D}}^{-1 / 2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1 / 2}$;  
    • $\theta_{l}$ 是 $ l^{t h} -hop$ 鄰居的權重引數,此處設定 $\theta_{l}=1 $,其他的權重引數設定為 $0$;  

  通過上述引數設定,那麼:$\mathbf{S}^{(l)}=\tilde{\mathbf{A}}_{s y m}^{l}$ 。

  Step2:Feature propagation

  然後,使用增廣圖結構進行特徵傳播,並使用一層編碼器網路 $f_{\boldsymbol{\theta}}(\cdot )$ 進一步計算節點表示,如下:

    $\mathbf{H}^{(l)}=f_{\boldsymbol{\theta}}\left(\mathbf{S}^{(l)} \mathbf{X}\right)=\operatorname{ReLU}\left(\mathbf{S}^{(l)} \mathbf{X} \Theta\right)\quad\quad\quad(2)$

  其中,$ \Theta \in \mathbb{R}^{N \times D^{\prime}}$ 代表著權重引數,計算出的節點表示 $\mathbf{H}^{(l)}$ 可以編碼圖中來自 $l-hops$ 鄰域的特徵資訊。

  為更好的利用區域性-全域性資訊,本文進一步設定不同的 $\text{l}$ 來執行多個資料增強。具體來說,$\mathbf{H}^{(1)}$ 是從區域性檢視中學習的,因為只用了在直接鄰居之間傳遞的資訊,而 $\left\{\mathbf{H}^{(l)}\right\}_{l=2}^{L}$ 是從一組高階檢視中學習的,這些高階檢視編碼了不同級別的全域性結構資訊。  

  本文的對比學習架構中,目標是通過最大化每個節點的互資訊,最大化每個節點的區域性檢視和高階檢視之間的一致性。

  這裡並沒有直接使用 $f_{\boldsymbol{\theta}}(\cdot )$ 輸出的表示,而是進一步對生成的表示 $\mathbf{H}^{(l)}$ 使用一個投影頭 $g_{\psi}(\cdot)$ 【一個兩層MLP】,即生成 $\left\{\mathbf{U}^{(l)}\right\}_{l=1}^{L} $,其中 $ \mathbf{U}^{(l)}=g_{\psi}\left(\mathbf{H}^{(l)}\right)$。

3.1.2 Structural Contrastive Objective

  最大化區域性檢視 $\mathbf{U}^{1}$ ,和高階檢視 $\mathrm{U}^{(l)}$ 之間的一致性(節點級):

    $\mathcal{L}_{s t r}^{(l)}=-\frac{1}{N} \sum\limits_{i=1}^{N} \log \frac{\exp \left(\mathbf{u}_{i}^{(1)} \cdot \mathbf{u}_{i}^{(l)} / \tau_{1}\right)}{\sum\limits_{j=0}^{M} \exp \left(\mathbf{u}_{i}^{(1)} \cdot \mathbf{u}_{j}^{(l)} / \tau_{1}\right)}\quad\quad\quad(3)$

  其中:

    • $\mathbf{u}_{i}^{(1)}$  和  $\mathbf{u}_{i}^{(l)}$  分別代表著 $\mathbf{U}^{(1)}$ 和 $\mathbf{U}^{(l)}$ 的第 $i$ 行表示向量;
    • 定義$\left\{\mathbf{u}_{j}^{(l)}\right\}_{j=0}^{M}$ 擁有一個正樣本,$M$ 個負樣本;

  所以總損失為:

    $\mathcal{L}_{s t r}=\sum\limits _{l=2}^{L} \mathcal{L}_{s t r}^{(l)}\quad\quad\quad(4)$

  最終的節點表示 $H$ 可以通過將 $\tilde{\mathbf{X}}=\frac{1}{L} \sum_{l=1}^{L} \mathbf{S}^{(l)} \mathbf{X}$ 輸入編碼器網路來計算,以保留區域性和全域性結構資訊。

3.2 Semantic Contrastive Learning

  為探討輸入圖的語義資訊,進一步提出了一個語義對比學習模組,該模組通過鼓勵簇內緊湊性和簇間可分離性來明確地捕獲資料語義結構。

  它通過迭代的推斷節點和其對應原型之間的聚類,並進行語義對比學習,促進節點在潛在空間中對應的聚類原型進行語義相似的聚類。Cluster 原型的表示用一個矩陣表示 $\mathbf{C} \in \mathbb{R}^{K \times D^{\prime}} $ ,這裡的 $K$ 代表著原型數目。$\mathbf{c}_{k}$ 代表著 $\mathbf{C} $ 的第 $k$ 行,即第 $k$ 個原型的表示向量。節點 $v_i$ 的原型分配定義為 $\mathcal{Z}=\left\{z_{i}\right\}_{i=1}^{n}$,其中$z_{i} \in\{1, \ldots, K\} $ 。

3.2.1 Bayesian Non-parametric Prototype Inference.

  我們的語義對比學習模組的一個關鍵組成部分是推斷出具有高度代表性的 cluster prototypes。然而,在自監督節點表示學習的設定下,最優的聚類數量是未知的,因此很難直接採用 K-means 等聚類方法對節點進行聚類。為了解決這個問題,我們提出了一種貝葉斯非引數原型推理(  Bayesian non-parametric prototype inference)演算法來近似最優的聚類數量並計算聚類原型。

  我們建立了一個狄利克雷過程混合模型(Dirichlet Process Mixture Model (DPMM)),並假設節點表示的分佈是一個高斯混合模型(GMM),其分量具有相同的固定協方差矩陣$\sigma \mathbf{I}$。每個元件都用於建模一個 cluster 的原型。DPMM 模型可以定義為:

    $\begin{array}{llrl}G & \sim \operatorname{DP}\left(G_{0}, \alpha\right) & & \\\phi_{i} & \sim G & & \text { for } i=1, \ldots, N \\\mathbf{h}_{i} & \sim \mathcal{N}\left(\phi_{i}, \sigma \mathbf{I}\right) & & \text { for } i=1, \ldots, N\end{array}\quad\quad\quad(5)$

  其中:

    • $G$ 是由狄利克雷過程 $\operatorname{DP}\left(G_{0}, \alpha\right) $ 得到的高斯分佈;  
    • $\alpha$ 是 $\operatorname{DP}\left(G_{0}, \alpha\right)$ 的濃度引數;  
    • $\phi_{i}$ 是節點表示的高斯取樣的均值;  
    • $ G_{0} $ 是高斯分佈的的先驗均值,本文取 $G_{0}$ 為一個零均值高斯 $\mathcal{N}(\mathbf{0}, \rho \mathbf{I})$;
    • $\rho \mathbf{I}$ 是協方差矩陣;

  接下來,我們使用一個摺疊的 $\text{collapsed Gibbs sampler }$ 用 DPMM來推斷 GMM 的元件。$\text{Gibbs sampler }$對給定高斯分量均值的節點的偽標籤進行迭代取樣,並對給定節點的偽標籤的高斯分量的均值進行取樣。當高斯分量 $\sigma \rightarrow 0$ 的方差變化時,對偽標籤進行取樣的過程變得確定性的。設 $\tilde{K}$ 表示當前迭代步驟中推斷出的原型數量,原型分配更新可以表述為:

    $z_{i}=\underset{k}{\arg \min }\left\{d_{i k}\right\}, \quad \text { for } i=1, \ldots, N$

    $d_{i k}=\left\{\begin{array}{ll} \left\|\mathbf{h}_{i}-\mathbf{c}_{k}\right\|^{2} & \text { for } k=1, \ldots, \tilde{K} \\ \xi & \text { for } k=\tilde{K}+1 \end{array}\right.\quad\quad\quad(6)$

  其中,$d_{i k}$ 是確定節點表示 $\mathbf{h}_{i}$ 的偽標籤的度量。$ \xi $ 是初始化一個新原型的邊際。在實踐中,我們通過對每個資料集進行交叉驗證來選擇 $ \xi $ 的值。根據 $\text{Eq.6}$ 中的公式,將一個節點分配給由最近高斯均值對應的分量建模的原型,除非到最近均值的平方歐氏距離大於 $ \xi $。在獲得偽標籤後,可以通過以下方法計算叢集原型表示:

    ${\large \mathbf{c}_{k}=\frac{\sum\limits _{z_{i}=k} \mathbf{h}_{i}}{\sum\limits_{z_{i}=k} 1}} , \quad \text { for } k=1, \ldots, \tilde{K}\quad\quad\quad(7)$

  注意,我們迭代地更新原型分配和原型表示直到收斂,然後我們將原型的數量 $K$ 設定為推斷出的原型的數量 $\tilde{K}$。原型推理的演算法在附錄中總結為 Algorithm 2。

3.2.2 Prototype Refinement via Label Propagation

  考慮到貝葉斯非引數演算法推斷出的偽標籤可能不準確,我們進一步基於標籤傳播對 Gibbs sampler 生成的偽標籤進行了重新細化。通過這種方法,我們可以平滑噪聲偽標籤,並利用結構知識改進聚類原型表示。

  首先,我們將原型賦值 $\mathcal{Z}$ 轉換為一個單熱的偽標籤矩陣 $\mathbf{Z} \in \mathbb{R}^{N \times K}$ ,其中 $\mathbf{Z}_{i j}=1$ 當且僅當 $z_{i}=k$。根據個性化PageRank(PPR)的想法,$T $ 聚合步後 $\mathbf{Z}^{(T)}$ 的偽標籤更新為:

    $\mathbf{Z}^{(t+1)}=(1-\beta) \tilde{\mathbf{A}}_{s y m} \mathbf{Z}^{(t)}+\beta \mathbf{Z}^{(0)}\quad\quad\quad(8)$

  其中,$ \mathbf{Z}^{(0)}=\mathbf{Z}$ 和 $\beta$ 可以視為 PPR 中的轉移概率。接下來,我們通過設定 $z_{i}=   \arg \max _{k} \mathbf{Z}_{i k}^{(T)}$  $i \in\{1, \ldots, N\}$ ,將傳播的結果 $i \in\{1, \ldots, N\}$  轉換為硬偽標籤。

  在使用標籤傳播對偽標籤 $\mathcal{Z}$ 進行細化後,我們使用每個叢集中節點表示的平均值作為叢集原型表示,由$\mathbf{c}_{k}=\sum_{z_{i}=k} \mathbf{h}_{i} / \sum_{z_{i}=k} 1$ 計算出。

3.2.3 Semantic Contrastive Objective

  給定原型分配 $\mathcal{Z}$ 和原型表示 $\mathbf{C}$,我們的語義對比學習旨在找到網路引數 $\theta$,最大化對數似然定義為:

    $Q(\boldsymbol{\theta})=\sum\limits _{i=1}^{N} \log p\left(\mathbf{x}_{i} \mid \boldsymbol{\theta}, \mathbf{C}\right)\quad\quad\quad(9)$

  其中 $p$ 是概率密度函式。作為 $p\left(\mathbf{x}_{i} \mid \boldsymbol{\theta}, \mathbf{C}\right)= \sum_{k=1}^{K} \log p\left(\mathbf{x}_{i}, z_{i}=k \mid \boldsymbol{\theta}, \mathbf{C}\right)$,我們得到

    $Q(\boldsymbol{\theta})=\sum\limits _{i=1}^{N} \sum\limits _{k=1}^{K} \log p\left(\mathbf{x}_{i}, z_{i}=k \mid \boldsymbol{\theta}, \mathbf{C}\right)\quad\quad\quad(10)$

  $Q(\boldsymbol{\theta})$ 的變分下界由

    $\begin{aligned}Q(\boldsymbol{\theta}) & \geq \sum\limits_{i=1}^{N} \sum\limits_{k=1}^{K} q\left(k \mid \mathbf{x}_{i}\right) \log \frac{p\left(\mathbf{x}_{i}, z_{i}=k \mid \boldsymbol{\theta}, \mathbf{C}\right)}{q\left(k \mid \mathbf{x}_{i}\right)} \\&=\sum\limits_{i=1}^{N} \sum\limits_{k=1}^{K} q\left(k \mid \mathbf{x}_{i}\right) \log p\left(\mathbf{x}_{i}, z_{i}=k \mid \boldsymbol{\theta}, \mathbf{C}\right) -\sum\limits_{i=1}^{N} \sum\limits_{k=1}^{K} q\left(k \mid \mathbf{x}_{i}\right) \log q\left(k \mid \mathbf{x}_{i}\right)\end{aligned}\quad\quad\quad(11)$

  其中,$q\left(k \mid \mathbf{x}_{i}\right)=p\left(z_{i}=k \mid \mathbf{x}_{i}, \boldsymbol{\theta}, \mathbf{C}\right)$ 表示 $z_{i}$ 的後部。由於上面的第二項是一個常數,我們可以通過最小化函 數 $E(\boldsymbol{\theta})$ 來最大化對數似然 $Q(\boldsymbol{\theta})$,如下:

    $E(\boldsymbol{\theta})=-\sum\limits_{i=1}^{N} \sum\limits_{k=1}^{K} q\left(k \mid \mathbf{x}_{i}\right) \log p\left(\mathbf{x}_{i}, z_{i}=k \mid \boldsymbol{\theta}, \mathbf{C}\right)\quad\quad\quad(12)$

  通過讓 $q\left(k \mid \mathbf{x}_{i}\right)=\mathbb{1}_{\left\{z_{i}=k\right\}}$,$E(\boldsymbol{\theta})$ 可以通過 $-\sum_{i=1}^{N} \log p\left(\mathbf{x}_{i}, z_{i} \mid \boldsymbol{\theta}, \mathbf{C}\right)$ 來計算。在 $\mathbf{x}_{i}$ 在不同原型上的先驗分佈均勻的假設下,我們有$p\left(\mathbf{x}_{i}, z_{i} \mid \boldsymbol{\theta}, \mathbf{C}\right) \propto p\left(\mathbf{x}_{i} \mid z_{i}, \boldsymbol{\theta}, \mathbf{C}\right)$ 。由我們的DPMM模型生成的每個原型周圍的分佈是一個高斯分佈。如果我們在節點和原型的表示上應用 $\ell_{2}$ 歸一化,我們可以估計$p\left(\mathbf{x}_{i} \mid z_{i}, \boldsymbol{\theta}, \mathbf{C}\right)$ 通過:

    $p\left(\mathbf{x}_{i} \mid z_{i}, \boldsymbol{\theta}, \mathbf{C}\right)=\frac{\exp \left(\mathbf{h}_{i} \cdot \mathbf{c}_{z_{i}} / \tau_{2}\right)}{\sum\limits _{k=1}^{K} \exp \left(\mathbf{h}_{i} \cdot \mathbf{c}_{k} / \tau_{2}\right)}\quad\quad\quad(13)$

  其中,$ \tau_{2} \propto \sigma^{2}$ 和 $ \sigma$ 為 Eq.5 定義的DPMM模型中高斯分佈的方差。$\mathbf{h}_{i}$ 和 $\mathbf{c}_{z_{i}}$ 是 $\mathbf{x}_{i}$ 和 $z_{i} -th$ 原型的代表。因此,$ E(\boldsymbol{\theta})$ 可以通過最小化損失函式來最小化,如下:

    $\mathcal{L}_{s e m}=-\frac{1}{N} \sum\limits _{i=1}^{N} \log \frac{\exp \left(\mathbf{h}_{i} \cdot \mathbf{c}_{z_{i}} / \tau_{2}\right)}{\sum\limits _{k=1}^{K} \exp \left(\mathbf{h}_{i} \cdot \mathbf{c}_{k} / \tau_{2}\right)}\quad\quad\quad(14)$

  Algorithm 1

  

4.3 Model Learning

4.3.1 Overall Loss

  為了以端到端方式訓練我們的模型,並學習編碼器$f_{\theta}(\cdot)$,我們共同優化了結構和語義對比學習損失。總體目標函式的定義為:

    $\mathcal{L}=\gamma \mathcal{L}_{s t r}+(1-\gamma) \mathcal{L}_{s e m}\quad\quad\quad(15)$

  我們的目標是在訓練中最小化 $\mathcal{L}$,$\gamma$ 是一個平衡因素來控制每個損失的貢獻。

  值得注意的是,在語義對比學習中,計算出的偽標籤 $ \mathcal{Z}$ 可以用於負示例抽樣過程,以避免結構對比學習中的抽樣偏差問題。我們從分配給不同原型的節點中,在 Eq.3 中為每個節點選擇負樣本。對負示例抽樣的詳細分析見附錄C.2。

4.3.2 Model Optimization via EM

  採用EM演算法交替估計後驗分佈 $p\left(z_{i} \mid \mathbf{x}_{i}, \boldsymbol{\theta}, \mathbf{C}\right)$,並優化網路引數$\boldsymbol{\theta}$。我們描述了在我們的方法中應用的 $E-step$ 和 $M-step$ 的細節如下:

E-step

  在這一步中,我們的目標是估計後驗分佈$p\left(z_{i} \mid \mathbf{x}_{i}, \boldsymbol{\theta}, \mathbf{C}\right) $。為了實現這一點,我們修復網路引數$\boldsymbol{\theta}$,估計原型 $\mathbf{C} $ 和原型分配 $\mathcal{Z}$為了訓練編碼器網路的穩定性,我們應用貝葉斯非引數原型推理演算法的節點表示計算動量編碼器$\mathbf{H}^{\prime}=f_{\theta^{\prime}}(\tilde{\mathbf{X}})$,動量編碼器的$\boldsymbol{\theta}^{\prime}$引數$\boldsymbol{\theta}$的移動平均更新:

    $\boldsymbol{\theta}^{\prime}=(1-m) \cdot \boldsymbol{\theta}+m \cdot \boldsymbol{\theta}^{\prime}\quad\quad\quad(16)$

  其中:

    $m \in[0,1)$ 是動量係數

M-step

  給定由 E-step 計算的後驗分佈,我們的目標是通過直接優化語義對比損失函式Lsem來最大化對數似然 $Q(\boldsymbol{\theta})$ 的期望。為了同時執行結構對比學習和語義對比學習,我們優化了一個如 Eq.15 所示的聯合整體損失函式。

  在對未標記的輸入圖進行自我監督的預訓練後,預訓練的編碼器可以直接用於生成各種下游任務的節點表示。

5 Experiments

5.1 資料集

  

5.2 實驗結果

  

  

 

相關文章