學習筆記:SSTBAN 用於長期交通預測的自監督時空瓶頸注意力

white514發表於2024-03-23

Self-Supervised Spatial-Temporal Bottleneck Attentive Network for Efficient Long-term Traffic Forecasting
用於高效長期交通預測的自監督時空瓶頸注意力網路
期刊會議:ICDE2023
論文地址:https://ieeexplore.ieee.org/document/10184658
程式碼地址:https://github.com/guoshnBJTU/SSTBAN

長期交通預測存在的問題:

  • 難以平衡準確率和效率。隨著時間跨度增大,要麼無法捕捉長期動態性,要麼以二次計算複雜度為代價獲取全域性接受域。
  • 高質量的訓練資料需求與模型的泛化能力的矛盾。如何提升資料的利用效率值得思考。
    SSTBAN採用多工框架,結合自監督學習器對歷史交通資料產生魯棒的潛在表示,從而提高其泛化效能和預測的魯棒性。此外,作者還設計了一個時空瓶頸注意機制,在編碼全域性時空動態的同時降低了計算複雜度。

長期預測需求分析:
與有助於及時決策的短期預測相比,長期預測為旅行者和管理員提供了必要的支援資訊,以最佳化旅行計劃和運輸資源管理。特別是未來幾個小時的流量預測資訊,有助於使用者提前制定路由計劃。

選擇注意力機制的原因:
目前STGNNs分為RNN-based,CNN-based和attention-based方法。RNN-based存在梯度消失問題,不利於長期預測,且序列順序的預測方式使得模型訓練時間隨著預測時間線性增加;CNN-based的kernel大小限制了長期動態性的捕捉能力;Attention-based更靈活,不會受到空間和時間距離的影響,但是存在二次計算複雜度的問題,這是要解決的問題。

資料的利用效率:高質量資料需求與泛化性矛盾
現有方法普遍有較強的高質量資料需求,當訓練資料存在噪聲時,就會導致過擬合或是學習到虛假的關係,泛化能力不佳。於是引入了常用於NLP和CV中的自監督學習。這要正確認知NLP/CV和時空交通預測的區別:

  • NLP/CV中的基礎模式,如形狀和語義,在廣泛資料集中是通用的;而交通資料集中則鮮有這樣的共同特徵,比如可能資料的特徵都不一樣。
  • NLP只需捕捉序列特徵——即時間,CV只需要捕捉空間特徵,但是STGNN要同時捕捉時空特徵。

貢獻:

  • 第一次提出了一種採用自監督學習器的時空交通預測模型,滿足了泛化和魯棒需求。
  • 設計了一種時空瓶頸注意力機制,能夠高效捕捉長期時空動態,將時間複雜度由二次方降低至線性。(RNN:?)
  • 在九個資料集上進行了實驗,證明了在精度和效率上的優勢。

模型

圖: SSTBAN架構

模型包含兩個分支:第一個是時空預測分支,第二個是時空自監督學習分支,因此是個多工框架。

在訓練階段,兩個分支一起工作。在分支一中,原始的資料依次經過ST Encoder、Transformer Attention,最後由ST Forecasting Decoder預測;在分支二中,首先隨機mask掉一些資料,將破損的資料經過ST Encoder來用剩餘的資料提取特徵,經過ST Reconstruction Decoder來補全丟掉的資料,並將補全的資料和分支一的完整資料進行對齊比較(為了避免噪聲的影響,這裡的比較是放在了潛在空間中的)。訓練損失也包含了兩個,一個是預測誤差的MAE,另一個是對齊的MSE。

在兩個分支中,encoder和兩個decoder由一樣的時空瓶頸注意力模組(STBA)和時空嵌入模組(STE)構成。

STBA目的是捕捉長期的時空動態性,且維持低的計算複雜度。
STE目的是提取不同時間切片和節點的獨特性,來彌補基於注意力機制的STBA對順序的不敏感缺點。我們透過端到端的方式訓練空間嵌入\(E_{SP}\in R^{N\times d}\),它在所有時間中共享;透過time-of-day和day-of-week,用one-hot和MLPs得到輸入時間嵌入\(E_{TP}\in R^{P\times d}\)和輸出時間嵌入\(E_{TP}'\in R^{Q\times d}\),它在所有節點中共享;將它們相加得到輸入序列嵌入\(\mathcal{E}\in R^{P\times N\times d}\)和輸入序列嵌入\(\mathcal{E}'\in R^{Q\times N\times d}\)

時空瓶頸注意力 STBA

圖: 時空瓶頸注意力STBA,時間瓶頸注意力TBA,空間SBA

圖中\(\mathcal{Z}^{(l-1)}=(\mathcal{H}^{(l-1)}||\mathcal{E})\in\mathbb{R}^{P\times N\times2d}\)

STBA包含了空間注意力(SBA)和時間注意力(TBA)。它們並沒有直接和其他點相連,而是和參考點相連,而參考點的數量遠小於時間點和空間點。我們還希望參考點能夠編碼通用的全域性資訊。由於整體形狀像瓶頸而得名。

TBA平行地處理每個點的輸入。(這裡的過程還不是很懂。)

STBA具有以下特點:

  • 由於參考點的設定,運算複雜度從\(O(N^2)\)降低到了\(O(NN')\),因為\(N'\)是個小的超引數。相比於GCN,STBA不需要預定義的圖結構,同時能動態調整節點間關係強度。
  • 參考點起到了編碼全域性模式的作用,可以理解輸入,如用來聚類。

時空預測分支

圖:分支一 時空預測分支結構

組成部分:
(1)時空編碼器:由時空瓶頸注意力組成,對映到潛在表徵空間
(2)Transformer attention:將潛在空間下的歷史訊號適配於預測訊號尺寸。為了緩解長期預測存在的比較嚴重的誤差傳播問題,我們透過自適應地融合歷史中的不同特徵,用注意力機制直接把每步的歷史訊號和預測訊號連線起來。即

\[\mathcal{H^{\prime}}_{:,v}^{(0)}=\mathrm{MHSA}(\mathcal{E}_{:,v}^{\prime},\mathcal{E}_{:,v},\mathcal{H}_{:,v}^{(L)})\in\mathbb{R}^{Q\times d} \]

(3)時空預測解碼器:由若干層時空瓶頸注意力,最後加上全連線層組成。

時空自監督學習分支

這一分支從mask掉部分訊號的不完整資料中,理解時空關係,並在潛在空間中重構缺失的訊號。目的是訓練潛在空間表徵能力。包括如下部分:
(1)Masking:考慮到mask掉單獨一個時間點的資料,很容易透過前後資料算出來,因此在時間或空間維度上mask掉連續的段,以此來學習趨勢模式。Mask策略是,將輸入資料分成若干patch,並將一定比例的patch全部清0。

圖:Masking演算法

(2)時空編碼器:和分支一中的一樣。只是,被mask掉的資料不參與時空瓶頸注意力的計算。
(3)時空重構編碼器:輸入(2)提供的殘缺潛在表徵,以及指示mask位置的token向量。由若干時空瓶頸組成,並將重構後的表徵與分支一的完整表徵匹配。

實驗

資料集
分別是Seattle Loop,PEMS04, PEMS08
這個Seattle Loop我還是第一次見。

Loop Seattle 資料集由部署在西雅圖地區高速公路(I-5、I-405、I-90和SR-520)上的感應環路探測器收集,包含來自323個感測器站的交通狀態資料。

圖:資料集資訊

超參數列
\(L\):ST Encoder中STBA的數量
\(L'\):STF Decoder中STBA的數量
\(d\):多頭注意力機制的維數
\(h\):多頭注意力機制的頭數
\(l_m\):Masking過程的patch length
\(\alpha_m\):Masking過程的mask率
\(\lambda\):預測損失和對齊損失的權重。越大代表對齊損失佔比越大。
時間和空間參考點的數量都是3。

圖:超引數設定

對比試驗

實際上作者選的這些基線模型都比較老,說服力比較差。我在本文的最後放上了自己做的一點對比實驗,可以當作參考。

圖:PEMS對比試驗
圖:SeattleLoop對比實驗

隨著時間跨度增加,SSTBAN的優勢也在增加。

圖:預測表現與預測長度的關係

魯棒測試

作者還進行了以下兩個實驗。可以看到,模型在這兩個方面還是有優點的。遺憾就是,對比的模型只有兩個,且比較古老,依然是缺乏說服力。
注:GMAN AAAI2020,DMSTGCN KDD2021

圖:減少訓練資料
圖:隨機新增噪聲

消融實驗

將STBA與普通注意力網路進行對比。也覺是說,STBA在減小時間複雜度的同時,還能增加準確率。

圖:消融實驗

算力消耗

可以看出,時間消耗和空間小號還是比較小的。不過在實驗中,模型的時間和空間佔用與batch size等引數設定有關,所以這個只能做參考吧。在我用PEMS08做復現,預測48步時,設定batch size=8,視訊記憶體佔用20G左右。

圖:算力實驗

復現

以下是我的復現結果。以下實驗每個僅做了一次,並沒有重複實驗,所以僅作參考。

batch size對時間的影響較大。越大,時間越短,但相應的佔用視訊記憶體越多。在SSTBAN用PEMS08預測48步的實驗中,設定batch size=16時,40G視訊記憶體的A100就已經跑不動了,可以看出SSTBAN的空間佔用還是比較大的。SSTBAN的特點是,訓練一個epoch耗時比較長,但是epoch數量少,就觸發早停了。

注:TrendGCN屬於CIKM2023

模型 資料集 步數 epoch MAE MAPE RMSE 時間
TrendGCN 08 12,12 120(batch64) 15.11 9.68 24.25 0h41
TrendGCN 08 24,24 120(batch128) 16.84 10.77 27.14 0h40
TrendGCN 08 36,36 120(batch64) 17.70 11.95 28.63 1h30
TrendGCN 08 48,48 120(batch128) 18.86 12.91 29.97 1h20
模型 資料集 步數 epoch MAE MAPE RMSE 時間
SSTBAN 08 12,12 71(batch32) 15.36 10.79 24.26 0h50
SSTBAN 08 24,24 16(batch32) 15.40 10.68 26.20 1h30
SSTBAN 08 36,36 18(batch4) 16.56 11.65 29.33 3h30
SSTBAN 08 48,48 15(batch8) 17.29 15.10 29.04 2h40
SSTBAN 04 36,36 15(batch8) 21.09 15.29 37.42 2h

相關文章