手把手教你用DGL框架進行批量圖分類

DGL專欄發表於2019-01-29

圖分類(預測圖的標籤)是圖結構資料裡一類重要的問題。它的應用廣泛,可見於生物資訊學、化學資訊學、社交網路分析、城市計算以及網路安全。隨著近來學界對於神經網路的熱情持續高漲,出現了一批用神經網路做圖分類的工作。比如訓練神經網路來預測蛋白質結構的性質,根據社交網路結構來預測使用者的所屬社群等(Ying et al., 2018, Cangea et al., 2018, Knyazev et al., 2018, Bianchi et al., 2019, Liao et al., 2019, Gao et al., 2019)。

在這個教程裡,我們將一起學習:

  • 如何使用 DGL 批量化處理大小各異的圖資料

  • 訓練神經網路完成一個簡易的圖分類任務

簡易圖分類任務

這裡我們設計了一個簡單的圖分類任務。在 DGL 裡我們實現了一個迷你圖分類資料集(MiniGCDataset)。它由以下 8 類圖結構資料組成。每一類圖包含同樣數量的隨機樣本。任務目標是訓練神經網路模型對這些樣本進行分類。

手把手教你用DGL框架進行批量圖分類

以下是使用 MiniGCDataset 的示例程式碼。我們先建立了一個擁有 80 個樣本的資料集。資料集中每張圖隨機有 10 到 20 個節點。DGL 中所有的資料集類都符合 Sequence 的抽象結構——既可以使用 dataset[i] 來訪問第 i 個樣本。這裡每個樣本包含圖結構以及它對應的標籤。

from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
# 資料集包含了80張圖。每張圖有10-20個節點
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()

執行以上程式碼後可以畫出資料集中第一個樣本的圖結構以及它對應的標籤:

手把手教你用DGL框架進行批量圖分類

打包一個圖的小批量

為了更高效地訓練神經網路,一個常見的做法是將多個樣本打包成小批量(mini-batch)。打包尺寸相同的張量樣本非常簡單。比如說打包兩個尺寸為 2828 的圖片會得到一個 22828 的張量。相較之下,打包圖面臨兩個挑戰:

  • 圖的邊比較稀疏

  • 圖的大小、形狀各不相同

DGL 提供了名為 dgl.batch 的介面來實現打包一個圖批量的功能。其核心思路非常簡單。將 n 張小圖打包在一起的操作可以看成是生成一張含 n 個不相連小圖的大圖。下圖的視覺化從直覺上解釋了 dgl.batch 的功能。

手把手教你用DGL框架進行批量圖分類

可以看到通過 dgl.batch 操作,我們生成了一張大圖,其中包含了一個環狀和一個星狀的連通分量。其鄰接矩陣表示則對應為在對角線上把兩張小圖的鄰接矩陣拼接在一起(其餘部分都為 0)。

以下是使用 dgl.batch 的一個實際例子。我們定義了一個 collate 函式來將 MiniGCDataset 裡多個樣本打包成一個小批量。

import dgl

def collate(samples):
    # 輸入`samples` 是一個列表
    # 每個元素都是一個二元組 (圖, 標籤)
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

正如打包 N 個張量得到的還是張量,dgl.batch 返回的也是一張圖。這樣的設計有兩點好處。首先,任何用於操作一張小圖的程式碼可以被直接使用在一個圖批量上。其次,由於 DGL 能夠並行處理圖中節點和邊上的計算,因此同一批量內的圖樣本都可以被平行計算。

圖分類器

這裡使用的圖分類器和應用在影象或者語音上的分類器類似——先通過多層神經網路計算每個樣本的表示(representation),再通過表示計算出每個類別的概率,最後通過向後傳播計算梯度。一個常見的圖分類器由以下幾個步驟構成:

  1. 通過圖卷積(Graph Convolution)層獲得圖中每個節點的表示。

  2. 使用「讀出」操作(Readout)獲得每張圖的表示。

  3. 使用 Softmax 計算每個類別的概率,使用向後傳播更新引數

下圖展示了整個流程:

手把手教你用DGL框架進行批量圖分類

之後我們將分步講解每一個步驟。

圖卷積

我們的圖卷積操作基本類似圖卷積網路 GCN(具體可以參見我們的關於 GCN 的教程)。圖卷積模型可以用以下公式表示:

手把手教你用DGL框架進行批量圖分類

在這個例子中,我們對這個公式進行了微調:

手把手教你用DGL框架進行批量圖分類

我們將求和替換成求平均可用來平衡度數不同的節點,在實驗中這也帶來了模型表現的提升。

此外,在構建資料集時,我們給每個圖裡所有的節點都加上了和自己的邊(自環)。這保證節點在收集鄰居節點表示進行更新時也能考慮到自己原有的表示。以下是定義圖卷積模型的程式碼。這裡我們使用 PyTorch 作為 DGL 的後端引擎(DGL 也支援 MXNet 作為後端)。

首先,我們使用 DGL 的內建函式定義訊息傳遞:

import dgl.function as fn
import torch
import torch.nn as nn

# 將節點表示h作為資訊發出
msg = fn.copy_src(src='h', out='m')

其次,我們定義訊息累和函式。這裡我們對收到的訊息進行平均。

def reduce(nodes):
    """對所有鄰節點節點特徵求平均並覆蓋原本的節點特徵。"""
    accum = torch.mean(nodes.mailbox['m'], 1)
    return {'h': accum}

之後,我們對收到的訊息應用線性變換和啟用函式

class NodeApplyModule(nn.Module):
    """將節點特徵 hv 更新為 ReLU(Whv+b)."""
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h' : h}

最後,我們把所有的小模組串聯起來成為 GCNLayer。

class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCNLayer, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # 使用 h 初始化節點特徵。
        g.ndata['h'] = feature
        # 使用 update_all介面和自定義的訊息傳遞及累和函式更新節點表示。
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')

讀出和分類

讀出(Readout)操作的輸入是圖中所有節點的表示,輸出則是整張圖的表示。在 Google 的 Neural Message Passing for Quantum Chemistry(Gilmer et al. 2017) 論文中總結過許多不同種類的讀出函式。在這個示例裡,我們對圖中所有節點表示取平均以作為圖的表示:

手把手教你用DGL框架進行批量圖分類

DGL 提供了許多讀出函式介面,以上公式可以很方便地用 dgl.mean(g) 完成。最後我們將圖的表示輸入分類器。分類器對圖表示先做了一個線性變換然後得到每一類在 softmax 之前的 logits。具體程式碼如下:

import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        # 兩層圖卷積層。
        self.layers = nn.ModuleList([
            GCNLayer(in_dim, hidden_dim, F.relu),
            GCNLayer(hidden_dim, hidden_dim, F.relu)])
        # 分類層。
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # 使用節點度數作為初始節點表示。
        h = g.in_degrees().view(-1, 1).float()
        # 圖卷積層。
        for conv in self.layers:
            h = conv(g, h)
        g.ndata['h'] = h
        # 讀出函式。
        graph_repr = dgl.mean_nodes(g, 'h')
        # 分類層。
        return self.classify(graph_repr)

準備和訓練

閱讀到這邊的讀者可以長舒一口氣了。因為之後的訓練過程和其他經典的影象,語音分類問題基本一致。首先我們建立了一個包含 400 張節點數量為 10~20 的合成資料集。其中 320 張圖作為訓練資料集,80 張圖作為測試集。

import torch.optim as optim
from torch.utils.data import DataLoader

# 建立一個訓練資料集和測試資料集
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# 使用 PyTorch 的 DataLoader 和之前定義的 collate 函式。
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)

其次我們建立一個剛剛定義的神經網路模型物件。

# 建立模型
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

訓練過程則是經典的反向傳播和梯度下降

# 建立模型

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

下圖是以上模型訓練的學習曲線

手把手教你用DGL框架進行批量圖分類

在訓練完成後,我們在測試集上驗證模型的表現。出於部署教程的考量,我們限制了模型訓練的時間。如果你花更多時間訓練模型,應該能得到更好的表現(80%-90%)。

我們還製作了一個動畫來展示訓練好的模型預測每張圖真實標籤的概率。可以看到我們剛剛定義的神經網路能夠較為準確地預測出圖樣本的對應標籤:

手把手教你用DGL框架進行批量圖分類

為了更好地理解模型學到的節點和圖的表示,我們使用了 t-SNE 來進行降維和視覺化。

手把手教你用DGL框架進行批量圖分類

手把手教你用DGL框架進行批量圖分類

頂部的兩張小圖分別視覺化了做完 1 層和 2 層圖卷積後的節點表示。不同顏色代表屬於不同類別的圖的節點。可以看到,經過訓練後,屬於同一類別的節點表示更加接近。並且,經過兩層圖卷積後這一聚類效果更明顯。其原因是因為兩層卷積後每個節點能接收到 2 度範圍內的鄰居資訊。

底部的大圖視覺化了每張圖在做 softmax 前的 logits,也就是圖表示。可以看到通過讀出函式後,圖表示能非常好地各自區分開來。這一區分度比節點表示更加明顯。

擴充閱讀

使用神經網路作圖分類還是一個嶄新的領域。這個任務並不簡單,需要模型能將不同的圖結構資料對映到不同的表示,同時要求圖表示可以儲存它們結構和內容上的異同。欲知更多內容,ICLR 2019 上新鮮出爐的 oral paper How Powerful Are Graph Neural Networks? 可能是一個不錯的出發點。

以上程式碼都可以在我們的文件網站下載:https://docs.dgl.ai/tutorials/basics/4_batch.html

對於想要更進一步學習用 DGL 進行批量化圖處理,以下內容可能會有所幫助:

  • Tree LSTM 教程。該模型需要批處理多個句子的語法書結構。

  • Deep Generative Models of Graphs 的 DGL 教程。

  • Junction Tree VAE 的 DGL 示例程式碼。該模型根據分子結構圖來預測分子性質。

關於 DGL 專欄: DGL 是一款全新的面向神經網路的開源框架。通過該專欄,我們 DGL 團隊希望和大家一起學習神經網路的最新進展。同時展示 DGL 的靈活性和高效性。通過系統學習演算法,通過演算法理解系統。

相關文章