論文解讀(DGI)《DEEP GRAPH INFOMAX》

CBlair發表於2021-09-19

  論文標題:DEEP GRAPH INFOMAX
  論文方向:影像領域
  論文來源:2019 ICLR
  論文連結:https://arxiv.org/abs/1809.10341
  論文程式碼:https://github.com/PetarV-/DGI


摘要

  1.  DGI,一種以無監督的方式學習圖結構資料中節點表示的一般方法。
  2.  DGI 依賴於最大限度地擴大圖增強表示和目前提取到的圖資訊之間的互資訊
  3.  與大多數以前使用 GCN 進行無監督學習的方法相比,DGI不依賴於隨機遊走目標,並且很容易適用於直推式學習和歸納式學習。

1 介紹

  神經網路推廣到圖形結構輸入的困難之處:大多數圖表資料是未標記的。

  隨機遊走的限制:隨機遊走目標以犧牲結構資訊為代價過分強調鄰近資訊,並且效能高度依賴於超引數的選擇。目前還不清楚隨機遊走目標是否真的提供了任何有用的訊號。

  本文提出了一種用於無監督圖學習的替代目標,這種目標是基於互資訊,而不是隨機遊走。在概率論和資訊理論中,兩個隨機變數的互資訊(Mutual Information,簡稱MI)是指變數間相互依賴性的量度。近年來基於互資訊的代表性工作是 Mutual Information Neural Estimation (MINE),其中提出了一種 Deep InfoMax (DMI) 方法來學習高維資料的表示 DMI 訓練一個編碼模型來最大化高階全域性表示和輸入的區域性部分的互資訊。這鼓勵編碼器攜帶出現在所有位置的資訊型別(因此是全域性相關的),例如類標籤的情況。


2 相關工作

2.1 對比方法

  對於無監督學習一類重要的方法就是對比學習,通過訓練編碼器使它在特徵表示中更具判別性來捕獲感興趣的和不感興趣的統計依賴性。例如,對比方法可以使用評分函式,訓練編碼器來增加“真實”輸入的分數,並減少“假”輸入的分數,以此判別真實資料和假資料。有很多方法可以對一個表示進行打分,但在圖形文獻中,最常見的技術是使用分類,儘管也會使用其他的打分函式。DGI在這方面也是對比性的,因為DGI目標是基於對區域性-全域性對和負抽樣配對的分類。

2.2 抽樣戰略

  對比方法的一個關鍵實現細節是如何繪製正負樣本。關於無監督圖表示學習的先前工作依賴於區域性對比損失(強制近端節點具有相似的嵌入)。從語言建模的角度來看,正樣本通常對應於在圖中短時間的隨機遊走中一起出現的節點對,有效地將節點視為單詞,將隨機遊走視為句子。最近有的方法提出使用節點錨定取樣作為替代。這些方法的負取樣主要是基於隨機對的抽樣。

2.3 預測編碼

  對比預測編碼 Contrastive predictive coding (CPC) 是另一種基於互資訊最大化的深度表示的學習方法。CPC 也是一種對比學習方法,它使用條件密度的估計(以噪聲對比估計的形式)作為評分函式。然而,與 DGI 不同的是,CPC是預測性的:對比目標有效地訓練了輸入的結構指定部分(例如,相鄰節點對之間或節點與其鄰居之間)之間的預測器。DGI 不同之處在於同時對比一個圖的全域性/區域性部分,其中全域性變數是從所有的區域性變數計算出來的。

3 DGI Methodology

  在本節中,我們將以自上而下的方式介紹DGI方法:首先是對我們特定的無監督學習設定的抽象概述,然後是對我們的方法優化的目標函式的闡述,最後是在單圖設定中列舉我們過程的所有步驟。

3.1 基於圖的無監督學習

  我們假設一個通用的基於圖的無監督機器學習設定:
  首先給出一組節點特徵, $X=\left\{\vec{x}_{1}, \overrightarrow{x_{2}}, \ldots, \overrightarrow{x_{N}}\right\}$ , 其中 $ N$ 是圖中的節點數, $ \vec{x}_{i} \in \mathbb{R}^{F}$ 代表節點 $i$ 的特徵表示。鄰接矩陣 $ A \in \mathbb{R}^{N \times N}$ , 在本文中預設所有處理的圖是無權圖, 同時鄰接矩陣儲存的值為 $0$  或 $1$。
  模型的目的是學習一個編碼器,$ \varepsilon: \mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{N \times F^{\prime}} $,可以形式化的表示為 $ \mathcal{E}(\boldsymbol{X}, \boldsymbol{A})=\boldsymbol{H}=\left\{\overrightarrow{h_{1}}, \overrightarrow{h_{2}}, \ldots, \overrightarrow{h_{N}}\right\}$ ,其中 $ H$ 代表高階表示, 並且每個節點 $i$ 滿足 $ \overrightarrow{h_{i}} \in \mathbb{R}^{F^{\prime}} $ 。所得到的節點特徵的高階表示可以用於各種下游任務,例如節點分類任務。
  在這裡,我們將重點討論圖卷積編碼器,它通過不斷聚合目標節點周邊的鄰居來完成特徵學習。它所產生的 $ \vec{h}_{i}$ 總結了以節點為中心的圖的一個 patch,而不僅僅是節點本身。在接下來的內容中,我們通常將 $ \vec{h}_{i}$ 稱為 patch representations 來強調這一點。

3.2 區域性-全域性互資訊最大化

   DGI 的核心思想在於通過最大化區域性互資訊來訓練編碼器——即 DGI 尋求獲得節點(即區域性)表示,以捕獲整個圖的全域性資訊(表示為summary vector,$\vec{s}$)。

  為了得到 圖級別的 summary vector $ \vec{s} $,作者提出了一種 readout 函式,$ \mathcal{R}: \mathbb{R}^{N \times F} \rightarrow \mathbb{R}^{F}$ ,利用它將獲得的 patch representations 總結為圖級別的表示。上述過程可以總結為 $ \vec{s}=\mathcal{R}(\mathcal{E}(\boldsymbol{X}, \boldsymbol{A}))$

  作為最大化區域性互資訊的指標,我們使用了一個 discriminator,$ \mathcal{D}: \mathbb{R}^{F} \times \mathbb{R}^{F} \rightarrow \mathbb{R}$, 這樣 $\mathcal{D}\left(\vec{h}_{i}, \vec{s}\right) $ 表示分配給這個 patch-summary 對的概率分數(對於包含在 summary 中的 patch 應該更高) 。

  $\mathcal{D}$ 的負樣本由 $ (\boldsymbol{X}, \boldsymbol{A})$ 的 summary vector $ \vec{s}$ 與一個可選擇的圖 $  (\widetilde{\boldsymbol{X}}, \widetilde{\boldsymbol{A}}) $  的 patch representations $ \vec{h}_{j}$ 提供。在多圖的資料集中,$ (\widetilde{\boldsymbol{X}}, \widetilde{\boldsymbol{A}}) $  可以通過訓練集的其他元素獲得。但是,對於單個圖,需要一個顯式(隨機 ) corruption function,$ \mathcal{C}: \mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{M \times F} \times \mathbb{R}^{M \times M} $  來生成負樣本的圖  $(\widetilde{\boldsymbol{X}}$,$\widetilde{\boldsymbol{A}}) $ 。 上述過程可以表述為 $ (\widetilde{\boldsymbol{X}}, \widetilde{\boldsymbol{A}})=\mathcal{C}(\boldsymbol{X}, \boldsymbol{A}) $

  負樣本抽樣程式的選擇將決定著作為這種最大化的副產品所希望捕獲的具體結構資訊的種類。

  對於目標,我們遵循 Deep InfoMax,使用帶有標準二值交叉熵 (BCE) 損失的橾聲對比型目標函式(正樣本和負樣本之間):

     $\mathcal{L}=\frac{1}{N+M}\left(\sum \limits _{i=1}^{N} \mathbb{E}_{(\mathbf{X}, \mathbf{A})}\left[\log \mathcal{D}\left(\vec{h}_{i}, \vec{s}\right)\right]+\sum \limits_{j=1}^{M} \mathbb{E}_{(\tilde{\mathbf{X}}, \tilde{\mathbf{A}})}\left[\log \left(1-\mathcal{D}\left(\overrightarrow{\widetilde{h}}_{j}, \vec{s}\right)\right)\right]\right)$

3.3 DGI概述

      

  假設單圖設定(即 $(\boldsymbol{X}, \boldsymbol{A}) $ 作為輸入, DGI 的步驟:
  1. 通過 corruption function 得到負樣本例項: $ (\widetilde{X}, \widetilde{\boldsymbol{A}}) \sim \mathcal{C}(\boldsymbol{X}, \boldsymbol{A}) $ 。
  2. 通過編碼器獲得輸入圖的 patch representations $\overrightarrow{h_{i}}: \boldsymbol{H}=\mathcal{E}(\boldsymbol{X}, \boldsymbol{A})=\left\{\overrightarrow{h_{1}}, \overrightarrow{h_{2}}, \ldots, \overrightarrow{h_{N}}\right\} $
  3. 通過編碼器獲得負樣本的 patch representations $\vec{h}_{j}: \widetilde{H}=\mathcal{E}(\widetilde{X}, \widetilde{A})=\left\{\vec{h}_{1}, \vec{h}_{2}, \ldots, \widetilde{h}_{M}\right\} $
  4. 通過 Readout 函式傳遞輸入圖的 patch representations 來得到圖級別的 summary vector: $ \vec{s}=\mathcal{R}(\boldsymbol{H})$ 。
  5. 通過梯度下降法最小化目標函式式 (1),更新引數 $\mathcal{E}, \mathcal{R}, \mathcal{D}$。


4 實驗

4.1 資料集

  我們評估了 DGI 編碼器在各種節點分類任務(直推式學習 [ transductive ] 和歸納式學習 [ inductive ])上學習的表示的好處,獲得了有競爭力的結果。在每種情況下,DGI都被用來以完全無監督的方式學習 patch representations,然後評估這些表示的節點級分類效用。這是通過直接使用這些表示來訓練和測試一個簡單的線性(邏輯迴歸)分類器來實現的。
      

  1.  在 Cora、Citeseer 和 Pubmed 引文網路上對研究論文進行主題分類。
  2.  以Reddit帖子為模型預測社交網路的社群結構。
  3.  對蛋白質-蛋白質相互作用(PPI)網路中的蛋白質作用進行分類,需要對未見網路進行歸納。

4.2 實驗設定

  對於三個實驗設定(直推式學習、大圖上的歸納式學習和多圖上的歸納式學習)中的每一個,我們使用了與該設定相適應的不同編碼器和 corruption function。

4.2.1 直推式學習

  直推式學習 Transductive learning

  編碼器是一層圖卷積網路(GCN)模型,具有以下傳播規則:

    $\mathcal{E}(\mathbf{X}, \mathbf{A})=\sigma\left(\hat{\mathbf{D}}^{-\frac{1}{2}} \hat{\mathbf{A}} \hat{\mathbf{D}}^{-\frac{1}{2}} \mathbf{X} \boldsymbol{\Theta}\right)$

  其中, $\hat{A}=A+I_{N} $ 代表加上自環的鄰接矩陣, $\hat{D}$ 代表相應的度矩陣,滿足  $\hat{D}_{i i}=\sum_{j} \hat{A}_{i j}$ 對於非線性啟用函式 $\sigma$ ,選擇 PReLU(parametric ReLU)。$\Theta \in R^{F \times F^{\prime}} $ 是應用於每個節點的可學習線性變換。

  對於 corruption function  C ,直接採用 $ \widetilde{A}=A$,但是 $ \widetilde{X}$ 是由原本的特徵矩陣 $X$ 經過隨機變換得到的。也就是說,損壞的圖(corrupted graph)由與原始圖完全相同的節點組成,但它們位於圖中的不同位置,因此將得到不同的鄰近表示。

4.2.2 大圖上的歸納式學習

  歸納式學習 Inductive learning 

  對於歸納學習,不再在編碼器中使用 GCN 更新規則(因為學習的濾波器依賴於固定的和已知的鄰接矩陣);相反,我們應用平均池( mean-pooling)傳播規則,GraphSAGE-GCN:

     $\operatorname{MP}(\mathbf{X}, \mathbf{A})=\hat{\mathbf{D}}^{-1} \hat{\mathbf{A}} \mathbf{X} \Theta$

   $\widehat{D} ^{-1}$ 實際上執行的是標準化的和(因此是 mean-pooling)。儘管上式明確指定了鄰接矩陣和度矩陣,但並不需要它們:因為 Const-GAT 模型中使用的持續關注機制可以觀察到相同的歸納行為。

   對於 Reddit 資料庫,DGI 的編碼器是一個帶有跳躍連線的三層均值池模型:

    $\widetilde{\mathrm{MP}}(\mathbf{X}, \mathbf{A})=\sigma\left(\mathbf{X} \Theta^{\prime} \| \operatorname{MP}(\mathbf{X}, \mathbf{A})\right) \quad \mathcal{E}(\mathbf{X}, \mathbf{A})=\widetilde{\mathrm{MP}}_{3}\left(\widetilde{\mathrm{MP}}_{2}\left(\widetilde{\mathrm{MP}}_{1}(\mathbf{X}, \mathbf{A}), \mathbf{A}\right), \mathbf{A}\right)$

  這裡 || 是 featurewise concatenation 。由於資料集的規模很大,它將不能完全適合 GPU記憶體。因此,採用 子抽樣(subsampling)方法,首先選擇小批量的節點,然後,通過對具有替換的節點鄰域進行抽樣,得到以每個節點為中心的子圖。具體來說,DGI 在第一層、第二層和第三層分別取樣 10、10 和 25 個鄰居,這樣每次取樣的 patch 有 1 + 10 + 100 + 2500 = 2611 個節點。只進行了推導中心節點 i 的 patch 表示 $h_I$  所必需的計算。這些表示然後被用來為 minibatch(圖2)匯出總結向量 $\overrightarrow{s} $ 。在整個訓練過程中使用了 256 個節點的 minibatch 。

       

  圖2中,摘要向量 $\vec{s} $ 是通過組合幾個子取樣的鄰近表示 $\vec{h}_{i} $ 得到的。

4.2.3 多圖上的歸納式學習

  例如 PPI 資料集,編碼器是一個帶有密集跳過連線的三層均值池模型

     $\mathbf{H}_{1}=\sigma\left(\operatorname{MP}_{1}(\mathbf{X}, \mathbf{A})\right)$

     $\mathbf{H}_{2}=\sigma\left(\mathbf{M P}_{2}\left(\mathbf{H}_{1}+\mathbf{X} \mathbf{W}_{\text {skip }}, \mathbf{A}\right)\right)$

     $\mathcal{E}(\mathbf{X}, \mathbf{A})=\sigma\left(\mathbf{M P}_{3}\left(\mathbf{H}_{2}+\mathbf{H}_{1}+\mathbf{X} \mathbf{W}_{\text {skip }}, \mathbf{A}\right)\right)$

   其中,$W_{skip}$ 是一個可學習的投影矩陣。

  在這個多圖設定中,DGI 選擇使用隨機抽樣的訓練圖作為負樣本(即,DGI 的破壞函式只是從訓練集中抽樣一個不同的圖)。作者發現該方法是最穩定的,因為該資料集中超過 40% 的節點具有全零特徵(all-zero features)。為了進一步擴大負樣本池,作者還將 dropout 應用於取樣圖的輸入特徵。作者發現,在將學習到的嵌入資訊提供給邏輯迴歸模型之前,將其標準化是有益的。

4.2.4 Readout,discriminator 的細節

  在所有三個實驗設定中,作者使用了相同的readout函式和discriminator體系結構。

  對於 Readout Function,作者使用所有節點特徵的簡單平均值:

     $\mathcal{R}(\mathbf{H})=\sigma\left(\frac{1}{N} \sum \limits _{i=1}^{N} \vec{h}_{i}\right)$

   作者通過應用一個簡單的雙線性評分函式對圖級別的 summarize-patch representation 對進行評分:

     $\mathcal{D}\left(\vec{h}_{i}, \vec{s}\right)=\sigma\left(\vec{h}_{i}^{T} \mathbf{W} \vec{s}\right)$

  其中,  $W$  是一個可學習的評分權重引數, $\sigma$ 是邏輯 Sigmoid 非線性, 用於將分數轉換為 $(\vec{h}_{i}, \vec{s})$ 為正對的概率。

4.3 結果

      

       

   根據分類準確性(在 transductive tasks)或 micro-averaged $F_1$ score(在歸納任務)的結果總結。在第一列中,我們突出顯示了訓練期間每個方法可用的資料型別(X:特徵,A:鄰接矩陣,Y:標籤)。"GCN" 對應於以監督方式訓練的兩層 DGI 編碼器。

      


5 參考

1 Deep Graph Infomax

 

 

 

 

知識點

 知識點:

  Q:互資訊(Mutual Information)

  互資訊(Mutual Information)是度量兩個事件集合之間的相關性(mutual dependence),它是資訊理論裡一種有用的資訊度量,它可以看成是一個隨機變數中包含的關於另一個隨機變數的資訊量,或者說是一個隨機變數由於已知另一個隨機變數而減少的不肯定性。互資訊最常用的單位是bit。互資訊指的是兩個隨機變數之間的關聯程度,即給定一個隨機變數後,另一個隨機變數不確定性的削弱程度,因而互資訊取值最小為0,意味著給定一個隨機變數對確定一另一個隨機變數沒有關係,最大取值為隨機變數的熵,意味著給定一個隨機變數,能完全消除另一個隨機變數的不確定性。

  直觀上,互資訊度量 X 和 Y 共享的資訊:它度量知道這兩個變數其中一個,對另一個不確定度減少的程度。例如,如果 X 和 Y 相互獨立,則知道 X 不對 Y 提供任何資訊,反之亦然,所以它們的互資訊為零。在另一個極端,如果 X 是 Y 的一個確定性函式,且 Y 也是 X 的一個確定性函式,那麼傳遞的所有資訊被 X 和 Y 共享:知道 X 決定 Y 的值,反之亦然。因此,在此情形互資訊與 Y(或 X)單獨包含的不確定度相同,稱作 Y(或 X)的熵。而且,這個互資訊與 X 的熵和 Y 的熵相同。

  Q:什麼是 patch?

  在 CNN 學習訓練過程中,不是一次來處理一整張圖片,而是先將圖片劃分為多個小的塊,核心/過濾器 kernel  每次只檢視影像的一個塊,這一個小塊就稱為 patch,然後過濾器移動到影像的另一個patch,以此類推。

  當將 CNN 過濾器應用到影像時,它會一次檢視一個 patch 。

  CNN 核心/過濾器 一次只處理一個 patch,而不是整個影像。這是因為我們希望過濾器處理影像的小塊以便檢測特徵(邊緣等)。這也有一個很好的正則化屬性,因為我們估計的引數數量較少,而且這些引數必須在每個影像的許多區域以及所有其他訓練影像的許多區域都是“好”的。

  所以 patch 就是核心 kernel 的輸入。這時核心的大小便是 patch 的大小。

    

   如圖,主動脈弓和心臟,綠色部分相同,而黃色部分不同。傳統的CNN演算法,區分效果不佳。在 Multi-Instance Multi-Stage Deep Learning for Medical Image Recognition 這篇文章中,作者針對這種場景提出瞭解決方法。

    $\begin{array}{l} L_{1}(\mathbf{W})=\sum_{\mathbf{x}_{m} \in \mathcal{T}}-\log \left(\mathbf{P}\left(l_{m} \mid \mathbf{X}_{m} ; \mathbf{W}\right)\right) \\ L_{2}(\mathbf{W})=\sum_{\mathbf{X}_{m} \in \mathcal{T}}-\log \left(\max _{\mathbf{x}_{m n} \in \mathcal{L}\left(\mathbf{X}_{m}\right)} \mathbf{P}\left(l_{m} \mid \mathbf{x}_{m n} ; \mathbf{W}\right)\right) \end{array}$

  這樣訓練出的網路,就會對有區分度的patch敏感,而對無區分度的無感。

    

   一個CNN層生成一箇中間表示。該表示被傳遞到下一層。如果下一層是CNN,則應用完全相同的“patch”概念,並以完全相同的方式進行計算,即使中間表示不是您或我可以識別為“影像”的東西。

  Q:什麼是 macro-F1,micro-F1

  macro-F1 和 micro-F1,巨集觀F1值和微觀F1值,考慮的是在多標籤(Multi-label)情況下,分類效果的評估方式。

  比如 Multi-label 性別男或女(0/1)以及是否是學生(0/1);當然 Multi-class也可以通過一定的編碼方式轉化為 Multi-label,如原始類別 1,2,3,4,獨熱編碼後可用四元向量表示 [0,1,0,0] 即表示類標 2。

  macro-F1 和 micro-F1 正是基於分類目的的多樣性,將只適用於 Binary 分類的 F1 值推廣了:

 

相關文章