Self-Supervised Spatial-Temporal Bottleneck Attentive Network for Efficient Long-term Traffic Forecasting
用於高效長期交通預測的自監督時空瓶頸注意力網路
期刊會議:ICDE2023
論文地址:https://ieeexplore.ieee.org/document/10184658
程式碼地址:https://github.com/guoshnBJTU/SSTBAN
長期交通預測存在的問題:
- 難以平衡準確率和效率。隨著時間跨度增大,要麼無法捕捉長期動態性,要麼以二次計算複雜度為代價獲取全域性接受域。
- 高質量的訓練資料需求與模型的泛化能力的矛盾。如何提升資料的利用效率值得思考。
SSTBAN採用多工框架,結合自監督學習器對歷史交通資料產生魯棒的潛在表示,從而提高其泛化效能和預測的魯棒性。此外,作者還設計了一個時空瓶頸注意機制,在編碼全域性時空動態的同時降低了計算複雜度。
長期預測需求分析:
與有助於及時決策的短期預測相比,長期預測為旅行者和管理員提供了必要的支援資訊,以最佳化旅行計劃和運輸資源管理。特別是未來幾個小時的流量預測資訊,有助於使用者提前制定路由計劃。
選擇注意力機制的原因:
目前STGNNs分為RNN-based,CNN-based和attention-based方法。RNN-based存在梯度消失問題,不利於長期預測,且序列順序的預測方式使得模型訓練時間隨著預測時間線性增加;CNN-based的kernel大小限制了長期動態性的捕捉能力;Attention-based更靈活,不會受到空間和時間距離的影響,但是存在二次計算複雜度的問題,這是要解決的問題。
資料的利用效率:高質量資料需求與泛化性矛盾
現有方法普遍有較強的高質量資料需求,當訓練資料存在噪聲時,就會導致過擬合或是學習到虛假的關係,泛化能力不佳。於是引入了常用於NLP和CV中的自監督學習。這要正確認知NLP/CV和時空交通預測的區別:
- NLP/CV中的基礎模式,如形狀和語義,在廣泛資料集中是通用的;而交通資料集中則鮮有這樣的共同特徵,比如可能資料的特徵都不一樣。
- NLP只需捕捉序列特徵——即時間,CV只需要捕捉空間特徵,但是STGNN要同時捕捉時空特徵。
貢獻:
- 第一次提出了一種採用自監督學習器的時空交通預測模型,滿足了泛化和魯棒需求。
- 設計了一種時空瓶頸注意力機制,能夠高效捕捉長期時空動態,將時間複雜度由二次方降低至線性。(RNN:?)
- 在九個資料集上進行了實驗,證明了在精度和效率上的優勢。
模型
圖: SSTBAN架構 |
模型包含兩個分支:第一個是時空預測分支,第二個是時空自監督學習分支,因此是個多工框架。
在訓練階段,兩個分支一起工作。在分支一中,原始的資料依次經過ST Encoder、Transformer Attention,最後由ST Forecasting Decoder預測;在分支二中,首先隨機mask掉一些資料,將破損的資料經過ST Encoder來用剩餘的資料提取特徵,經過ST Reconstruction Decoder來補全丟掉的資料,並將補全的資料和分支一的完整資料進行對齊比較(為了避免噪聲的影響,這裡的比較是放在了潛在空間中的)。訓練損失也包含了兩個,一個是預測誤差的MAE,另一個是對齊的MSE。
在兩個分支中,encoder和兩個decoder由一樣的時空瓶頸注意力模組(STBA)和時空嵌入模組(STE)構成。
STBA目的是捕捉長期的時空動態性,且維持低的計算複雜度。
STE目的是提取不同時間切片和節點的獨特性,來彌補基於注意力機制的STBA對順序的不敏感缺點。我們透過端到端的方式訓練空間嵌入\(E_{SP}\in R^{N\times d}\),它在所有時間中共享;透過time-of-day和day-of-week,用one-hot和MLPs得到輸入時間嵌入\(E_{TP}\in R^{P\times d}\)和輸出時間嵌入\(E_{TP}'\in R^{Q\times d}\),它在所有節點中共享;將它們相加得到輸入序列嵌入\(\mathcal{E}\in R^{P\times N\times d}\)和輸入序列嵌入\(\mathcal{E}'\in R^{Q\times N\times d}\)。
時空瓶頸注意力 STBA
圖: 時空瓶頸注意力STBA,時間瓶頸注意力TBA,空間SBA |
圖中\(\mathcal{Z}^{(l-1)}=(\mathcal{H}^{(l-1)}||\mathcal{E})\in\mathbb{R}^{P\times N\times2d}\)。
STBA包含了空間注意力(SBA)和時間注意力(TBA)。它們並沒有直接和其他點相連,而是和參考點相連,而參考點的數量遠小於時間點和空間點。我們還希望參考點能夠編碼通用的全域性資訊。由於整體形狀像瓶頸而得名。
TBA平行地處理每個點的輸入。(這裡的過程還不是很懂。)
STBA具有以下特點:
- 由於參考點的設定,運算複雜度從\(O(N^2)\)降低到了\(O(NN')\),因為\(N'\)是個小的超引數。相比於GCN,STBA不需要預定義的圖結構,同時能動態調整節點間關係強度。
- 參考點起到了編碼全域性模式的作用,可以理解輸入,如用來聚類。
時空預測分支
圖:分支一 時空預測分支結構 |
組成部分:
(1)時空編碼器:由時空瓶頸注意力組成,對映到潛在表徵空間
(2)Transformer attention:將潛在空間下的歷史訊號適配於預測訊號尺寸。為了緩解長期預測存在的比較嚴重的誤差傳播問題,我們透過自適應地融合歷史中的不同特徵,用注意力機制直接把每步的歷史訊號和預測訊號連線起來。即
(3)時空預測解碼器:由若干層時空瓶頸注意力,最後加上全連線層組成。
時空自監督學習分支
這一分支從mask掉部分訊號的不完整資料中,理解時空關係,並在潛在空間中重構缺失的訊號。目的是訓練潛在空間表徵能力。包括如下部分:
(1)Masking:考慮到mask掉單獨一個時間點的資料,很容易透過前後資料算出來,因此在時間或空間維度上mask掉連續的段,以此來學習趨勢模式。Mask策略是,將輸入資料分成若干patch,並將一定比例的patch全部清0。
圖:Masking演算法 |
(2)時空編碼器:和分支一中的一樣。只是,被mask掉的資料不參與時空瓶頸注意力的計算。
(3)時空重構編碼器:輸入(2)提供的殘缺潛在表徵,以及指示mask位置的token向量。由若干時空瓶頸組成,並將重構後的表徵與分支一的完整表徵匹配。
實驗
資料集
分別是Seattle Loop,PEMS04, PEMS08
這個Seattle Loop我還是第一次見。
Loop Seattle 資料集由部署在西雅圖地區高速公路(I-5、I-405、I-90和SR-520)上的感應環路探測器收集,包含來自323個感測器站的交通狀態資料。
圖:資料集資訊 |
超參數列
\(L\):ST Encoder中STBA的數量
\(L'\):STF Decoder中STBA的數量
\(d\):多頭注意力機制的維數
\(h\):多頭注意力機制的頭數
\(l_m\):Masking過程的patch length
\(\alpha_m\):Masking過程的mask率
\(\lambda\):預測損失和對齊損失的權重。越大代表對齊損失佔比越大。
時間和空間參考點的數量都是3。
圖:超引數設定 |
對比試驗
實際上作者選的這些基線模型都比較老,說服力比較差。我在本文的最後放上了自己做的一點對比實驗,可以當作參考。
圖:PEMS對比試驗 |
圖:SeattleLoop對比實驗 |
隨著時間跨度增加,SSTBAN的優勢也在增加。
圖:預測表現與預測長度的關係 |
魯棒測試
作者還進行了以下兩個實驗。可以看到,模型在這兩個方面還是有優點的。遺憾就是,對比的模型只有兩個,且比較古老,依然是缺乏說服力。
注:GMAN AAAI2020,DMSTGCN KDD2021
圖:減少訓練資料 |
圖:隨機新增噪聲 |
消融實驗
將STBA與普通注意力網路進行對比。也覺是說,STBA在減小時間複雜度的同時,還能增加準確率。
圖:消融實驗 |
算力消耗
可以看出,時間消耗和空間小號還是比較小的。不過在實驗中,模型的時間和空間佔用與batch size等引數設定有關,所以這個只能做參考吧。在我用PEMS08做復現,預測48步時,設定batch size=8,視訊記憶體佔用20G左右。
圖:算力實驗 |
復現
以下是我的復現結果。以下實驗每個僅做了一次,並沒有重複實驗,所以僅作參考。
batch size對時間的影響較大。越大,時間越短,但相應的佔用視訊記憶體越多。在SSTBAN用PEMS08預測48步的實驗中,設定batch size=16時,40G視訊記憶體的A100就已經跑不動了,可以看出SSTBAN的空間佔用還是比較大的。SSTBAN的特點是,訓練一個epoch耗時比較長,但是epoch數量少,就觸發早停了。
注:TrendGCN屬於CIKM2023
模型 | 資料集 | 步數 | epoch | MAE | MAPE | RMSE | 時間 |
---|---|---|---|---|---|---|---|
TrendGCN | 08 | 12,12 | 120(batch64) | 15.11 | 9.68 | 24.25 | 0h41 |
TrendGCN | 08 | 24,24 | 120(batch128) | 16.84 | 10.77 | 27.14 | 0h40 |
TrendGCN | 08 | 36,36 | 120(batch64) | 17.70 | 11.95 | 28.63 | 1h30 |
TrendGCN | 08 | 48,48 | 120(batch128) | 18.86 | 12.91 | 29.97 | 1h20 |
模型 | 資料集 | 步數 | epoch | MAE | MAPE | RMSE | 時間 |
---|---|---|---|---|---|---|---|
SSTBAN | 08 | 12,12 | 71(batch32) | 15.36 | 10.79 | 24.26 | 0h50 |
SSTBAN | 08 | 24,24 | 16(batch32) | 15.40 | 10.68 | 26.20 | 1h30 |
SSTBAN | 08 | 36,36 | 18(batch4) | 16.56 | 11.65 | 29.33 | 3h30 |
SSTBAN | 08 | 48,48 | 15(batch8) | 17.29 | 15.10 | 29.04 | 2h40 |
SSTBAN | 04 | 36,36 | 15(batch8) | 21.09 | 15.29 | 37.42 | 2h |