深入理解圖注意力機制

DGL专栏發表於2019-02-19

圖卷積網路 Graph Convolutional Network (GCN) 告訴我們將區域性的圖結構和節點特徵結合可以在節點分類任務中獲得不錯的表現。美中不足的是 GCN 結合鄰近節點特徵的方式和圖的結構依依相關,這侷限了訓練所得模型在其他圖結構上的泛化能力。

Graph Attention Network (GAT) 提出了用注意力機制對鄰近節點特徵加權求和。鄰近節點特徵的權重完全取決於節點特徵,獨立於圖結構。

在這個教程裡我們將:

  • 解釋什麼是 Graph Attention Network

  • 演示用 DGL 實現這一模型

  • 深入理解學習所得的注意力權重

  • 初探歸納學習 (inductive learning)

難度:★★★★✩(需要對圖神經網路訓練和 Pytorch 有基本瞭解)

在 GCN 裡引入注意力機制

GAT 和 GCN 的核心區別在於如何收集並累和距離為 1 的鄰居節點的特徵表示。

在 GCN 裡,一次圖卷積操作包含對鄰節點特徵的標準化求和:

深入理解圖注意力機制

其中 N(i) 是對節點 i 距離為 1 鄰節點的集合。我們通常會加一條連線節點 i 和它自身的邊使得 i 本身也被包括在 N(i) 裡。深入理解圖注意力機制 是一個基於圖結構的標準化常數;σ是一個啟用函式(GCN 使用了 ReLU);W^((l)) 是節點特徵轉換的權重矩陣,被所有節點共享。由於 c_ij 和圖的機構相關,使得在一張圖上學習到的 GCN 模型比較難直接應用到另一張圖上。解決這一問題的方法有很多,比如 GraphSAGE 提出了一種採用相同節點特徵更新規則的模型,唯一的區別是他們將 c_ij 設為了|N(i)|。

圖注意力模型 GAT 用注意力機制替代了圖卷積中固定的標準化操作。以下圖和公式定義瞭如何對第 l 層節點特徵做更新得到第 l+1 層節點特徵:

深入理解圖注意力機制

圖 1:圖注意力網路示意圖和更新公式。

對於上述公式的一些解釋:

  • 公式(1)對 l 層節點嵌入深入理解圖注意力機制做了線性變換,W^((l)) 是該變換可訓練的引數

  • 公式(2)計算了成對節點間的原始注意力分數。它首先拼接了兩個節點的 z 嵌入,注意 || 在這裡表示拼接;隨後對拼接好的嵌入以及一個可學習的權重向量 做點積;最後應用了一個 LeakyReLU 啟用函式。這一形式的注意力機制通常被稱為加性注意力,區別於 Transformer 裡的點積注意力。

  • 公式(3)對於一個節點所有入邊得到的原始注意力分數應用了一個 softmax 操作,得到了注意力權重

  • 公式(4)形似 GCN 的節點特徵更新規則,對所有鄰節點的特徵做了基於注意力的加權求和。

出於簡潔的考量,在本教程中,我們選擇省略了一些論文中的細節,如 dropout, skip connection 等等。感興趣的讀者們歡迎參閱文末連結的模型完整實現。

本質上,GAT 只是將原本的標準化常數替換為使用注意力權重的鄰居節點特徵聚合函式。

GAT 的 DGL 實現

以下程式碼給讀者提供了在 DGL 裡實現一個 GAT 層的總體印象。別擔心,我們會將以下程式碼拆分成三塊,並逐塊講解每塊程式碼是如何實現上面的一條公式。

import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # 公式 (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # 公式 (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # 公式 (2) 所需,邊上的使用者定義函式
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e' : F.leaky_relu(a)}

    def message_func(self, edges):
        # 公式 (3), (4)所需,傳遞訊息用的使用者定義函式
        return {'z' : edges.src['z'], 'e' : edges.data['e']}

    def reduce_func(self, nodes):
        # 公式 (3), (4)所需, 歸約用的使用者定義函式
        # 公式 (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # 公式 (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h' : h}

    def forward(self, h):
        # 公式 (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # 公式 (2)
        self.g.apply_edges(self.edge_attention)
        # 公式 (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

實現公式 (1) 

深入理解圖注意力機制

第一個公式相對比較簡單。線性變換非常常見。在 PyTorch 裡,我們可以通過 torch.nn.Linear 很方便地實現。

實現公式 (2) 

深入理解圖注意力機制

原始注意力權重 e_ij 是基於一對鄰近節點 i 和 j 的表示計算得到。我們可以把注意力權重 e_ij 看成在 i->j 這條邊的資料。因此,在 DGL 裡,我們可以使用 g.apply_edges 這一 API 來呼叫邊上的操作,用一個邊上的使用者定義函式來指定具體操作的內容。我們在使用者定義函式裡實現了公式(2)的操作:

 def edge_attention(self, edges):
        # 公式 (2) 所需,邊上的使用者定義函式
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e' : F.leaky_relu(a)}

公式中的點積同樣藉由 PyTorch 的一個線性變換 attn_fc 實現。注意 apply_edges 會把所有邊上的資料打包為一個張量,這使得拼接和點積可以並行完成。

實現公式 (3) 和 (4)

深入理解圖注意力機制

類似 GCN,在 DGL 裡我們使用 update_all API 來觸發所有節點上的訊息傳遞函式。update_all 接收兩個使用者自定義函式作為引數。message_function 傳送了兩種張量作為訊息:訊息原節點的 z 表示以及每條邊上的原始注意力權重。reduce_function 隨後進行了兩項操作:

  1. 使用 softmax 歸一化注意力權重(公式(3))。

  2. 使用注意力權重聚合鄰節點特徵(公式(4))。

這兩項操作都先從節點的 mailbox 獲取了資料,隨後在資料的第二維(dim = 1 ) 上進行了運算。注意資料的第一維代表了節點的數量,第二維代表了每個節點收到訊息的數量。

 def reduce_func(self, nodes):
        # 公式 (3), (4)所需, 歸約用的使用者定義函式
        # 公式 (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # 公式 (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h' : h}

多頭注意力 (Multi-head attention)

神似卷積神經網路裡的多通道,GAT 引入了多頭注意力來豐富模型的能力和穩定訓練的過程。每一個注意力的頭都有它自己的引數。如何整合多個注意力機制的輸出結果一般有兩種方式:

深入理解圖注意力機制

以上式子中 K 是注意力頭的數量。作者們建議對中間層使用拼接對最後一層使用求平均。

我們之前有定義單頭注意力的 GAT 層,它可作為多頭注意力 GAT 層的組建單元:

class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # 對輸出特徵維度(第1維)做拼接
            return torch.cat(head_outs, dim=1)
        else:
            # 用求平均整合多頭結果
            return torch.mean(torch.stack(head_outs))

在 Cora 資料集上訓練一個 GAT 模型

Cora 是經典的文章引用網路資料集。Cora 圖上的每個節點是一篇文章,邊代表文章和文章間的引用關係。每個節點的初始特徵是文章的詞袋(Bag of words)表示。其目標是根據引用關係預測文章的類別(比如機器學習還是遺傳演算法)。在這裡,我們定義一個兩層的 GAT 模型:

class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # 注意輸入的維度是 hidden_dim * num_heads 因為多頭的結果都被拼接在了
        # 一起。 此外輸出層只有一個頭。
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

我們使用 DGL 自帶的資料模組載入 Cora 資料集。

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh

def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.ByteTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask

模型訓練的流程和 GCN 教程裡的一樣。

import time
import numpy as np
g, features, labels, mask = load_cora_data()

# 建立模型
net = GAT(g, 
          in_dim=features.size()[1], 
          hidden_dim=8, 
          out_dim=7, 
          num_heads=8)
print(net)

# 建立優化器
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# 主流程
dur = []
for epoch in range(30):
    if epoch >=3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >=3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), np.mean(dur)))

視覺化並理解學到的注意力

Cora 資料集

以下表格總結了 GAT 論文以及 dgl 實現的模型在 Cora 資料集上的表現:

深入理解圖注意力機制

可以看到 DGL 能完全復現原論文中的實驗結果。對比圖卷積網路 GCN,GAT 在 Cora 上有 2~3 個百分點的提升。

不過,我們的模型究竟學到了怎樣的注意力機制呢?

由於注意力權重深入理解圖注意力機制與圖上的邊密切相關,我們可以通過給邊著色來視覺化注意力權重。以下圖片中我們選取了 Cora 的一個子圖並且在圖上畫出了 GAT 模型最後一層的注意力權重。我們根據圖上節點的標籤對節點進行了著色,根據注意力權重的大小對邊進行了著色(可參考圖右側的色條)。 

深入理解圖注意力機制

圖 2:Cora 資料集上學習到的注意力權重

乍看之下模型似乎學到了不同的注意力權重。為了對注意力機制有一個全域性觀念,我們衡量了注意力分佈的熵。對於節點 i,{α_ij }_(j∈N(i)) 構成了一個在 i 鄰節點上的離散概率分佈。它的熵被定義為:

深入理解圖注意力機制

直觀地說,熵低代表了概率高度集中,反之亦然。熵為 0 則所有的注意力都被放在一個點上。均勻分佈具有最高的熵(log N(i))。在理想情況下,我們想要模型習得一個熵較低的分佈(即某一、兩個節點比其它節點重要的多)。注意由於節點的入度不同,它們注意力權重的分佈所能達到的最大熵也會不同。

基於圖中所有節點的熵,我們畫了所有頭注意力的直方圖。 

深入理解圖注意力機制

圖 3:Cora 資料集上學到的注意力權重直方圖。

作為參考,下圖是在所有節點的注意力權重都是均勻分佈的情況下得到的直方圖。 

深入理解圖注意力機制

出人意料的,模型學到的節點注意力權重非常接近均勻分佈(換言之,所有的鄰節點都獲得了同等重視)。這在一定程度上解釋了為什麼在 Cora 上 GAT 的表現和 GCN 非常接近(在上面表格裡我們可以看到兩者的差距平均下來不到 2%)。由於沒有顯著區分節點,注意力並沒有那麼重要。

這是否說明了注意力機制沒什麼用?不!在接下來的資料集上我們觀察到了完全不同的現象。

蛋白質互動網路 (PPI)

PPI(蛋白質間相互作用)資料集包含了 24 張圖,對應了不同的人體組織。節點最多可以有 121 種標籤(比如蛋白質的一些性質、所處位置等)。因此節點標籤被表示為有 121 個元素的二元張量。資料集的任務是預測節點標籤。

我們使用了 20 張圖進行訓練,2 張圖進行驗證,2 張圖進行測試。平均下來每張圖有 2372 個節點。每個節點有 50 個特徵,包含定位基因集合、特徵基因集合以及免疫特徵。至關重要的是,測試用圖在訓練過程中對模型完全不可見。這一設定被稱為歸納學習

我們比較了 dgl 實現的 GAT 和 GCN 在 10 次隨機訓練中的表現。模型的超引數驗證集上進行了優化。在實驗中我們使用了 micro f1 score 來衡量模型的表現。

深入理解圖注意力機制

在訓練過程中,我們使用了 BCEWithLogitsLoss 作為損失函式。下圖繪製了 GAT 和 GCN 的學習曲線;顯然 GAT 的表現遠優於 GCN。 

深入理解圖注意力機制

圖 4:PPI 資料集上 GCN 和 GAT 學習曲線比較。

像之前一樣,我們可以通過繪製節點注意力分佈之熵的直方圖來有一個統計意義上的直觀瞭解。以下我們基於一個 3 層 GAT 模型中不同模型層不同注意力頭繪製了直方圖。

第一層學到的注意力

深入理解圖注意力機制 第二層學到的注意力

深入理解圖注意力機制

最後一層學到的注意力 

深入理解圖注意力機制

作為參考,下圖是在所有節點的注意力權重都是均勻分佈的情況下得到的直方圖。

深入理解圖注意力機制

可以很明顯地看到,GAT 在 PPI 上確實學到了一個尖銳的注意力權重分佈。與此同時,GAT 層與層之間的注意力也呈現出一個清晰的模式:在中間層隨著層數的增加註意力權重變得愈發集中;最後的輸出層由於我們對不同頭結果做了平均,注意力分佈再次趨近均勻分佈。

不同於在 Cora 資料集上非常有限的收益,GAT 在 PPI 資料集上較 GCN 和其它圖模型的變種取得了明顯的優勢(根據原論文的結果在測試集上的表現提升了至少 20%)。我們的實驗揭示了 GAT 學到的注意力顯著區別於均勻分佈。雖然這值得進一步的深入研究,一個由此而生的假設是 GAT 的優勢在於處理更復雜領域結構的能力。

擴充閱讀

到目前為止我們演示瞭如何用 DGL 實現 GAT。簡介起見,我們忽略了 dropout, skip connection 等一些細節。這些細節很常見且獨立於 DGL 相關的概念。有興趣的讀者歡迎參閱完整的程式碼實現。

關於 DGL 專欄: DGL 是一款全新的面向圖神經網路的開源框架。通過該專欄,我們 DGL 團隊希望和大家一起學習圖神經網路的最新進展。同時展示 DGL 的靈活性和高效性。通過系統學習演算法,通過演算法理解系統。

相關文章