論文解讀(MPNN)Neural Message Passing for Quantum Chemistry

希望每天漲粉發表於2021-10-17

  論文標題:DEEP GRAPH INFOMAX
  論文來源:ICML 2017
  論文連結:https://arxiv.org/abs/1704.01212

1 介紹

  本文的目標是證明:「能夠應用於化學預測任務的模型可以直接從分子圖中學習到分子的特徵,並且不受到圖同構的影響。」

  本文提出的 MPNN 是一種用於圖上監督學習的框架。為此,作者將應用於圖上的監督學習框架稱之為訊息傳遞神經網路(MPNN),這種框架是從目前比較流行的支援圖資料的神經網路模型中抽象出來的一些共性,抽象出來的目的在於理解它們之間的關係。

  本文以 QM9 作為 benchmark 資料集,該資料集由 $130k$ 個分子組成,每個分子有 $13$個特徵,這些特徵是通過一種計算昂貴的量子力學模擬方法(DFT)近似生成的,相當於 $13$ 個迴歸任務。這些任務似乎代表了許多重要的化學預測問題,並且目前對許多現有方法來說是困難的。

  本文給出的一個例子是利用 MPNN 框架代替計算代價昂貴的 DFT 來預測有機分子的量子特性:

        

 

  本文提出的模型的效能度量採用兩種形式:

  • DFT近似的平均估計誤差;
  • 化學界已經確立的目標誤差,稱為“化學精度”。

2 訊息傳遞神經網路(MPNN)

  本文首先通過八篇文獻來舉例驗證 MPNN 框架的通配性。
  為簡便起見,以處理無向圖 $G$ 為例,節點 $v$ 的特徵為 $x_{v}$ 和邊的特徵 $e_{v w}$ 。
  前向傳遞有兩個階段,一個訊息傳遞階段(message passing phase)一個讀出階段(readout phase)。
  訊息傳遞階段執行 $T$ 個時間步並且依賴訊息函式 $M_t$ 以及節點更新函式 $U_t$。在訊息傳遞階段,每個節點的隱藏狀態 $h_{v}^{t}$ 都會根據訊息 $m_{v}^{t+1}$ 調整。
    $\begin{aligned}m_{v}^{t+1} &=\sum \limits _{w \in N(v)} M_{t}\left(h_{v}^{t}, h_{w}^{t}, e_{v w}\right) \\h_{v}^{t+1} &=U_{t}\left(h_{v}^{t}, m_{v}^{t+1}\right)\end{aligned}$
  其中,$N(v)$ 表示圖 $G$ 中 $v$ 的鄰居。
  讀出階段根據一些  讀出函式 R 計算整個圖的特徵向量:
    $\hat{y}=R\left(\left\{h_{v}^{T} \mid v \in G\right\}\right)$
  訊息函式 $M_t$、向量更新函式  $U_t$ 和 讀出函式 $R$ 都是可學習的可微函式。$R$作用於節點的狀態集合,同時對節點的排列不敏感,這樣才能保證 MPNN 對圖同構保持不變。
  此外,也可通過引入的隱藏層狀態來學習圖中每一條邊的特徵,並且同樣可以用上面的等式進行學習和更新。

3 論文文獻總結

  MPNN 通過定義訊息函式、更新函式和讀出函式來適配不同種模型。

  Paper 1 : Convolutional Networks for Learning Molecular Fingerprints, Duvenaud et al. (2015)

  訊息傳遞函式為:

    $M\left(h_{v}, h_{w}, e_{v w}\right)=\left(h_{w}, e_{v w}\right)$

  其中 $(., .) $ 表示拼接 (concat) ;
  節點更新函式為:

    $U_{t}\left(h_{v}^{t}, m_{v}^{t+1}\right)=\sigma\left(H_{t}^{d e g(v)} m_{v}^{t+1}\right)$

  其中 $ \sigma$ 為 sigmoid 函式, $ \operatorname{deg}(v) $  表示節點 $ v$ 的度, $ H_{t}^{N}$ 是一個可學習的矩陣,$ \mathrm{t}$ 為時間步, $ \mathrm{N}$ 為節點度;讀出函式 $ \mathrm{R} $ 將先前所有隱藏層的狀態 $ h_{v}^{t}$ 進行連線:

    $R=f\left(\sum \limits _{v, t} \operatorname{softmax}\left(W_{t} h_{v}^{t}\right)\right)$

  其中 $f$ 是一個神經網路,$ W_{t}$ 是一個可學習的讀出矩陣。
  在訊息傳遞階段可能會存在一些問題,如最終的訊息向量分別對連通的節點和連通的邊求和 $m_{v}^{t+1}=\left(\sum h_{w}^{t}, \sum e_{v w}\right) $。可見,該模型實現的訊息傳遞無法識別節點和邊之間的相關性。


  Paper 2 : Gated Graph Neural Networks (GG-NN), Li et al. (2016)

  訊息傳遞函式為:

    $M_{t}\left(h_{v}^{t}, h_{w}^{t}, e_{v w}\right)=A_{e_{v w}} h_{w}^{t}$

  其中 $A_{e_{v w}}$  是 $e_{v w}$  的一個可學習矩陣,每條邊都會對應那麼一個矩陣。
  更新函式為:

    $U_{t}\left(h_{v}^{t}, m_{v}^{t+1}\right)=G R U\left(h_{v}^{t}, m_{v}^{t+1}\right)$

  其中 $GRU$  為門控制單元 (Gate Recurrent Unit) 。使用了權值捆綁(weight tying),所以在每一個時間步 $\mathrm{t}$  下都會使用相同的更新函式。

  讀出函式 $\mathrm{R}$  為:

    $R=\sum \limits_{v \in V} \sigma\left(i\left(h_{v}^{(T)}\right), h_{v}^{0}\right) \odot\left(j\left(h_{v}^{(T)}\right)\right)$

  其中 $i$  和 $j$  為神經網路, $\odot$ 即哈達瑪積,表示元素相乘。


  Paper 3 : Interaction Networks, Battaglia et al. (2016)

  該論文考慮圖中節點的結構和圖的結構,也考慮每個時間步下的節點級的影響。所以這裡的更新函式的輸入是 $\left(h_{v}, x_{v}, m_{v}\right)$ ,其中 $x_v $ 是一個外部向量,表示對頂點 $v$ 的一些外部影響。
  訊息傳遞函式:

    $M\left(h_{v}, h_{w}, e_{v w}\right)$ 是一個以 $\left(h_{v}, h_{w}, e_{v w}\right)$ 為輸入的神經網路。

  節點更新函式:

    $U\left(h_{v}, x_{v}, m_{v}\right)$ 是一個以 $\left(h_{v}, x_{v}, m_{v}\right)$ 為輸入的神經網路。

  讀出函式 $\mathrm{R}$(圖級別的輸出):

    $R=f\left(\sum_{v \in G} h_{v}^{T}\right)$ ,其中 $\mathrm{f}$ 是一個神經網路,輸入是最終的隱藏層狀態的和。原論文中 $T=1$ 。


  Paper 4 : Molecular Graph Convolutions, Kearnes et al. (2016)
  該論文與其他 MPNN 稍有不同,主要區別在於考慮了邊表示 $e_{v, w}^{t}$ ,並且在訊息傳遞階段會進行更新。

  訊息傳遞函式用的是節點的訊息:

    $M_{t}\left(h_{v}^{t}, h_{w}^{t}, e_{v w}^{t}\right)=e_{v w}^{t}$

  節點的更新函式:

    $U_{t}\left(h_{v}^{t}, m_{v}^{t+1}\right)=\alpha\left(W_{1}\left(\alpha\left(W_{0} h_{v}^{t}\right), m_{v}^{t+1}\right)\right)$

  其中 $ (., .) $ 表示拼接 (concat), $ \alpha$ 為 $ \operatorname{ReLU}$ 啟用函式, $ W_{0}$,$W_{1}$ 為可學習權重矩陣。
  邊狀態的更新定義為:

    $e_{v w}^{t+1} =U_{t}^{\prime}\left(e_{v w}^{t}, h_{v}^{t}, h_{w}^{t}\right) =\alpha\left(W_{4}\left(\alpha\left(W_{2}, e_{v w}^{t}\right), \alpha\left(W_{3}\left(h_{v}^{t}, h_{w}^{t}\right)\right)\right)\right)$

  其中,$W_{i}$ 為可學習權重矩陣。


  Paper 5 : Deep Tensor Neural Networks, Schutt et al. (2017)

  訊息函式:

    $M_{t}=\tanh \left(W^{f c}\left(\left(W^{c f} h_{w}^{t}+b_{1}\right) \odot\left(W^{d f} e_{v w}+b_{2}\right)\right)\right)$

  其中 $ W^{f c}, W^{c f}, W^{d f}$ 為矩陣, $ b_{1}, b_{2}$ 為偏置向量;
  更新函式:

    $U_{t}\left(h_{v}^{t}, m_{v}^{t+1}\right)=h_{v}^{t}+m_{v}^{t+1}$

  讀出函式(通過單層隱藏層接受每個節點並且求和後輸出):

    $R=\sum_{v} N N\left(h_{v}^{T}\right)$


  Paper 6 : Laplacian Based Methods, Bruna et al. (2013); Defferrard et al. (2016); Kipf \& Welling (2016)
  基於拉普拉斯矩陣的方法將影像中的卷積運算擴充套件到網路圖 $G$ 的鄰接矩陣 $A$ 中。

  在 Bruna et al. (2013); Defferrard et al. (2016)的工作中:
  訊息函式:

    $M_{t}\left(h_{v}^{t}, h_{w}^{t}\right)=C_{v w}^{t} h_{w}^{t}$

  其中,矩陣 $C_{v w}^{t}$  為拉普拉斯矩陣 $L$  的特徵向量組成的矩陣;

  更新函式:

    $U_{t}\left(h_{v}^{t}, m_{v}^{t+1}\right)=\sigma\left(m_{v}^{t+1}\right)$

   在 Kipf & Welling (2016) 的工作中:
  訊息函式:

    $M_{t}\left(h_{v}^{t}, h_{w}^{t}\right)=C_{v w} h_{w}^{t}$

  其中, $C_{v w}=(\operatorname{deg}(v) \operatorname{deg}(w))^{-1 / 2} A_{v w} $;
  更新函式:

    $U_{v}^{t}\left(h_{v}^{t}, m_{v}^{t+1}\right)=\operatorname{Re} L U\left(W^{t} m_{v}^{t+1}\right)$

  上述模型都是 MPNN 框架的不同例項,作者呼籲大家應致力於將這一框架應用於某個實際應用,並根據不同情況對關鍵部分進行修改,從而引導模型的改進,這樣才能最大限度的發揮模型的能力。

4 MPNN 變種

4.1 Message Functions

  作者將 MPNN 框架應用於分子預測領域,提出了 MPNN 的變種,並以 QM9 資料集為例進行實驗。任務是根據分子結構預測分子所屬類別。

  作者主要是基於 GG-NN 來探索 MPNN 的多種改進方式(不同的訊息函式、輸出函式等)。

  下文中以 $d$ 代表節點特徵的維度,以 $n$ 代表圖的節點數量。同樣適用於有向圖,入邊和出邊有分別的資訊通道,那麼節點 $v$ 的資訊 $m_{v}$ 由 $m_{v}^{i n}$ 和 $m_{v}^{out }$ 拼接而成。在無向圖中,可以將無向圖的邊看做兩條邊,一條入邊和一條出邊,有相同的標籤,那麼資訊通道的大小是 $2 d$ 而不是 $d$ 。
  模型的輸入是每個節點的特徵向量 $x_{v}$  以及鄰接矩陣 $A$  ,鄰接矩陣 $A$  具有向量分量,表示分子中的不同化學鍵以及兩個原子之間的成對空間距離。初始狀態 $h_{v}^{0}$  是原子輸入特徵集合  $x_{v}$  ,並且需要 padding 到維度  $d$。在實驗中的每個時間步 $t$ 都要進行權重共 享 , 並且更新函式 GRU。

  訊息函式:
  GG-NN 採用的訊息函式,採用矩陣相乘的方式(GG-NN 的邊有離散的標籤):

    $M\left(h_{v}, h_{w}, e_{v w}\right)=A_{e_{v w}} h_{w}$

   現假設邊有一個特徵向量 $e_{v w} $(為相容邊的特性)。  

    $M\left(h_{v}, h_{w}, e_{v w}\right)=A\left(e_{v w}\right) h_{w}$

  其中, $A\left(e_{v w}\right)$ 是一個神經網路,將邊的向量 $e_{v w}$  對映到 $\mathrm{d} \times \mathrm{d}$  維矩陣。
  上述兩種訊息函式的特點是,從節點 $v$  到節點 $w$  的函式僅與隱藏層狀態 $h_{v}$  和邊向量 $e_{v w}$  有關,而和隱藏狀態 $h_{v}^{t}$  無關。實際上,節點訊息同時依賴於源節點 $v$  和目標節點 $w$  的話,網路的訊息通道將會得到更有效的利用。所以也可以嘗試去使用一種訊息函式的變種:

    $m_{v w}=f\left(h_{w}^{t}, h_{v}^{t}, e_{v w}\right)$

  其中, $f$  為神經網路。

  對於有向圖, 一共有兩個訊息函式  $M^{i n}$  和  $M^{out }$ , 對於邊  $e_{v w}$  應用哪個訊息函式取決於邊的方向。

4.2 Virtual Graph Elements

  本文作者探索了兩種不同的訊息傳遞方式。

  • 為沒有連線的節點新增一個虛擬的邊,這樣訊息便具有更長的傳播距離;
  • 使用潛在的“主”節點(master node),這個節點可以通過特殊的邊來連線到圖中任意一個節點。主節點充當了一個全域性的暫存空間,每個節點都會在訊息傳遞過程中通過主節點進行讀取和寫入。同時允許主節點具有自己的節點維度,以及內部更新函式(GRU)的單獨權重。目的同樣是為了在傳播階段傳播很長的距離。

4.3 Readout Functions

  作者嘗試了兩種讀出函式:

  考慮 GG-NN 中的讀出函式:

    $R=\sum \limits_{v \in V} \sigma\left(i\left(h_{v}^{(T)}\right), h_{v}^{0}\right) \odot\left(j\left(h_{v}^{(T)}\right)\right)$

  考慮 set2set 模型。set2set 模型是專門為在集合運算而設計的,並且相比簡單累加節點的狀態來說具有更強的表達能力。模型首先通過線性對映將資料對映到元組 $ \left(h_{v}^{t}, x_{v}\right)$ ,並將投影元組作為輸入 $ T=\left\{\left(h_{v}^{T}, x_{v}\right)\right\}$,然後經過 $\mathrm{M}$  步計算後, set2set 模型會生成一 個與節點順序無關的 Graph-level 的 embeedding 向量,從而得到我們的輸出向量。

4.4 Multiple Towers

  考慮 MPNN 的伸縮性。
  對一個稠密圖來說,訊息傳遞階段的每一個時間步的時間複雜度為  $O\left(n^{2} d^{2}\right)$,其中 $\mathrm{n}$  為節點數,$  \mathrm{d}$  為向量維度,顯然計算複雜度還是較高的。
  處理的方法是將向量維度為 $d$ 的 $h_{v}^{t}$ 拆分成 $k$  份,就變成了 $k$  個 $\mathrm{d} / \mathrm{k}$  維向量 $h_{v}^{t, k} $,並在每個 $h_{v}^{t, k}$ 傳播過程中分別進行傳播和更新,最後再進行合併。

    $\left(h_{v}^{t, 1}, h_{v}^{t, 2}, \cdots, h_{v}^{t, k}\right)=g\left(\tilde{h}_{v}^{t, 1}, \tilde{h}_{v}^{t, 2}, \cdots, \tilde{h}_{v}^{t, k}\right)$

  $g$ 代表神經網路, $(x, y, \cdots) $ 代表拼接,$g$ 在所有節點上共享。這樣就保持了節點排列不變性,同時允許圖的不同副本在傳播階段相互通訊

  此時子向量時間複雜度為 $O\left(n^{2}(d / k)^{2}\right)$,考慮 $\mathrm{k}$  個子向量的時間複雜度為 $O\left(n^{2} d^{2} / k\right)$  。

5 輸入表示

  介紹 GNN 的輸入。一個分子有很多特徵,如下圖所示:
       

  對於鄰接矩陣,作者模型嘗試了三種邊表示形式:

  • 化學圖 (Chemical Graph) :在不考慮距離的情況下,鄰接矩陣的值是離散的鍵型別:單鍵,雙鍵,三鍵或芳香鍵;
  • 距離分桶(Distance bins):基於矩陣乘法的訊息函式的前提假設是邊資訊是離散的,因此作者將鍵的距離分為 10 個 bin, 比如說 $[2,6]$ 中均勻劃分 8 個 bin,$[0,2]$  為 1 個 bin, $[6,+\infty]$  為 1 個 bin;
  • 原始距離特徵(Raw distance feature):也可以同時考慮距離和化學鍵的特徵,這時每條邊都有自己的特徵向量,此時鄰接矩陣的每個 例項都是一個 5 維向量,第一維是距離,其餘 4 維是四種不同的化學鍵。

6 實驗

  實驗以 QM-9 資料集為例,包含 130462 個分子,以 MAE 為評估指標。

  下圖為現有演算法和作者改進的演算法之間的對比:

       

   下圖為不考慮空間資訊的結果:

      
  下圖為考慮多塔模型和結果:
       

7 總結

  作者從眾多模型中總結出 MPNN 框架,並且通過實驗表明,具有訊息函式、更新函式和讀出函式的 MPNN 具有良好的歸納能力,可以用於預測分析特性,優於目前的 Baseline,並且無需進行復雜的特徵工程。此外,實驗結果也揭示了全域性主節點和利用 set2set 模型的重要性,多塔模型也使得 MPNN 更具伸縮性,方便應用於大型圖中。

  看完點個關注唄!!(總結不易)

相關文章