深入淺出圖神經網路 GCN程式碼實戰

DazTricky發表於2021-07-14

GCN程式碼實戰

書中5.6節的GCN程式碼實戰做的是最經典Cora資料集上的分類,恰當又不恰當的類比Cora之於GNN就相當於MNIST之於機器學習。

有關Cora的介紹網上一搜一大把我就不贅述了,這裡說一下Cora這個資料集對應的是怎麼樣的。

Cora有2708篇論文,之間有引用關係共5429個,每篇論文作為一個節點,引用關係就是節點之間的邊。每篇論文有一個1433維的特徵來表示某個詞是否在文中出現過,也就是每個節點有1433維的特徵。最後這些論文被分為7類。

所以在Cora上訓練的目的就是學習節點的特徵及其與鄰居的關係,根據已知的節點分類對未知分類的節點的類別進行預測。

知道這些應該就OK了,下面來看程式碼。

資料處理

註釋裡自己都寫了程式碼引用自PyG我覺得就掃幾眼就行了,因為現在常用的資料集兩個GNN輪子(DGL和PyG)裡都有,現在基本都是直接用,很少自己下原始資料再處理了,所以略過。

GCN層定義

回顧第5章中GCN層的定義:

\[X'=\sigma(\tilde L_{sym}XW) \]

所以對於一層GCN,就是對輸入\(X\),乘一個引數矩陣\(W\),再乘一個算好歸一化後的“拉普拉斯矩陣”即可。

來看程式碼:

class GraphConvolution(nn.Module):
    def __init__(self, input_dim, output_dim, use_bias=True):
        super(GraphConvolution, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = use_bias
        self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(output_dim))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)
        if self.use_bias:
            init.zeros_(self.bias)

    def forward(self, adjacency, input_feature):
        support = torch.mm(input_feature, self.weight)
        output = torch.sparse.mm(adjacency, support)
        if self.use_bias:
            output += self.bias
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.input_dim) + ' -> ' \
            + str(self.output_dim) + ')'

定義了一層GCN的輸入輸出維度和偏置,對於GCN層來說,每一層有自己的\(W\)\(X\)是輸入給的,\(\tilde L_{sym}\)是資料集算的,所以只需要定義一個weight矩陣,注意一下維度就行。

傳播的時候只要按照公式\(X'=\sigma(\tilde L_{sym}XW)\)進行一下矩陣乘法就好,注意一個trick:\(\tilde L_{sym}\)是稀疏矩陣,所以先矩陣乘法得到\(XW\),再用稀疏矩陣乘法計算\(\tilde L_{sym}XW\)運算效率上更好。

GCN模型定義

知道了GCN層的定義之後堆疊GCN層就可以得到GCN模型了,兩層的GCN就可以取得很好的效果(過深的GCN因為過度平滑的問題會導致準確率下降):

class GcnNet(nn.Module):
    def __init__(self, input_dim=1433):
        super(GcnNet, self).__init__()
        self.gcn1 = GraphConvolution(input_dim, 16)
        self.gcn2 = GraphConvolution(16, 7)
    
    def forward(self, adjacency, feature):
        h = F.relu(self.gcn1(adjacency, feature))
        logits = self.gcn2(adjacency, h)
        return logits

這裡設定隱藏層維度為16,調到32,64,...都是可以的,我自己試的結果來說沒有太大的區別。從隱藏層到輸出層直接將輸出維度設定為分類的維度就可以得到預測分類。

傳播的時候相比於每一層的傳播只需要加上啟用函式,這裡選用ReLU

訓練

定義模型、損失函式(交叉熵)、優化器:

model = GcnNet(input_dim).to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), 
                       lr=LEARNING_RATE, 
                       weight_decay=WEIGHT_DACAY)

具體的訓練函式註釋已經解釋的很清楚:

def train():
    loss_history = []
    val_acc_history = []
    model.train()
    train_y = tensor_y[tensor_train_mask]
    for epoch in range(EPOCHS):
        logits = model(tensor_adjacency, tensor_x)  # 前向傳播
        train_mask_logits = logits[tensor_train_mask]   # 只選擇訓練節點進行監督
        loss = criterion(train_mask_logits, train_y)    # 計算損失值
        optimizer.zero_grad()
        loss.backward()     # 反向傳播計算引數的梯度
        optimizer.step()    # 使用優化方法進行梯度更新
        train_acc, _, _ = test(tensor_train_mask)     # 計算當前模型訓練集上的準確率
        val_acc, _, _ = test(tensor_val_mask)     # 計算當前模型在驗證集上的準確率
        # 記錄訓練過程中損失值和準確率的變化,用於畫圖
        loss_history.append(loss.item())
        val_acc_history.append(val_acc.item())
        print("Epoch {:03d}: Loss {:.4f}, TrainAcc {:.4}, ValAcc {:.4f}".format(
            epoch, loss.item(), train_acc.item(), val_acc.item()))
    
    return loss_history, val_acc_history

對應的測試函式:

def test(mask):
    model.eval()
    with torch.no_grad():
        logits = model(tensor_adjacency, tensor_x)
        test_mask_logits = logits[mask]
        predict_y = test_mask_logits.max(1)[1]
        accuarcy = torch.eq(predict_y, tensor_y[mask]).float().mean()
    return accuarcy, test_mask_logits.cpu().numpy(), tensor_y[mask].cpu().numpy()

注意模型得到的分類不是one-hot的,而是對應不同種類的預測概率,所以要test_mask_logits.max(1)[1]取概率最高的一個作為模型預測的類別。

這些都寫好之後直接執行訓練函式即可。有需要還可以對train_lossvalidation_accuracy進行畫圖,書上也給出了相應的程式碼,比較簡單不再贅述。

相關文章