PyTorch Geometric Temporal 介紹 —— 資料結構和RGCN的概念

多事鬼間人 發表於 2022-11-28
資料結構 PyTorch

Introduction

PyTorch Geometric Temporal is a temporal graph neural network extension library for PyTorch Geometric.

PyTorch Geometric Temporal 是基於PyTorch Geometric的對時間序列圖資料的擴充套件。

Data Structures: PyTorch Geometric Temporal Signal

定義:在PyTorch Geometric Temporal中,邊、邊特徵、節點被歸為圖結構Graph,節點特徵被歸為訊號Single,對於特定時間切片或特定時間點的時間序列圖資料被稱為快照Snapshot。

PyTorch Geometric Temporal定義了數個Temporal Signal Iterators用於時間序列圖資料的迭代。

Temporal Signal Iterators資料迭代器的引數是由描述圖的各個物件(edge_index,node_feature,...)的列表組成,列表的索引對應各時間節點。

按照圖結構的時間序列中的變換部分不同,圖結構包括但不限於為以下幾種:

  • Static Graph with Temporal Signal
    靜態的邊和邊特徵,靜態的節點,動態的節點特徵
  • Dynamic Graph with Temporal Signal
    動態的邊和邊特徵,動態的節點和節點特徵
  • Dynamic Graph with Static Signal
    動態的邊和邊特徵,動態的節點,靜態的節點特徵

理論上來說,任意描述圖結構的物件都可以根據問題定為靜態或動態,所有物件都為靜態則為傳統的GNN問題。

實際上,在PyTorch Geometric Temporal定義的資料迭代器中,靜態和動態的差別在於是以陣列的列表還是以單一陣列的形式輸入,以及在輸出時是按索引從列表中讀取還是重複讀取單一陣列。

如在StaticGraphTemporalSignal的原始碼中_get_edge_index _get_features分別為:

# https://pytorch-geometric-temporal.readthedocs.io/en/latest/_modules/torch_geometric_temporal/signal/static_graph_temporal_signal.html#StaticGraphTemporalSignal

def _get_edge_index(self):
	if self.edge_index is None:
		return self.edge_index
	else:
		return torch.LongTensor(self.edge_index)
		
def _get_features(self, time_index: int):
	if self.features[time_index] is None:
    	return self.features[time_index]
    else:
    	return torch.FloatTensor(self.features[time_index])

對於Heterogeneous Graph的資料迭代器,其與普通Graph的差異在於對於每個類別建立鍵值對組成字典,其中的值按靜態和動態定為列表或單一陣列。

Recurrent Graph Convolutional Layers

Define $\ast_G $ as graph convolution, \(\odot\) as Hadamard product

\[\begin{aligned} &z = \sigma(W_{xz}\ast_Gx_t+W_{hz}\ast_Gh_{t-1}),\\ &r = \sigma(W_{xr}\ast_Gx_t+W_{hr}\ast_Gh_{t-1}),\\ &\tilde h = \text{tanh}(W_{xh}\ast_Gx_t+W_{hh}\ast_G(r\odot h_{t-1})),\\ &h_t = z \odot h_{t-1} + (1-z) \odot \tilde h \end{aligned} \]

From https://arxiv.org/abs/1612.07659

具體的函式實現見 https://pytorch-geometric-temporal.readthedocs.io/en/latest/modules/root.html#

與RNN的比較

\[\begin{aligned} &z_t = \sigma(W_{xz}x_t+b_{xz}+W_{hz}h_{t-1}+b_{hz}),\\ &r_t = \sigma(W_{xr}x_t+b_{xr}+W_{hr}h_{t-1}+b_{hr}),\\ &\tilde h_t = \text{tanh}(W_{xh}x_t+b_{xh}+r_t(W_{hh}h_{t-1}+b_{hh})),\\ &h_t = z*h_{t-1} + (1-z)*\tilde h \end{aligned} \]

From https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU

對於傳統GRU的解析 https://zhuanlan.zhihu.com/p/32481747

在普通資料的Recurrent NN中,對於每一條時間序列資料會獨立的計算各時間節點會根據上一時間節點計算hidden state。但在時間序列圖資料中,每個snapshot被視為一個整體計算Hidden state matrix \(H \in \mathbb{R}^{\text{Num(Nodes)}\times \text{Out_Channels}_H}\) 和Cell state matrix(對於LSTM)\(C \in \mathbb{R}^{\text{Num(Nodes)}\times \text{Out_Channels}_C}\)

與GCN的比較

相較於傳統的Graph Convolution Layer,RGCN將圖卷積計算的擴充套件到RNN各個狀態的計算中替代原本的引數矩陣和特徵的乘法計算。