BurstAttention:可對非常長的序列進行高效的分散式注意力計算

deephub發表於2024-03-23

提高llm中注意力機制效率的努力主要集中在兩種方法上:最佳化單裝置計算和儲存能力,如FlashAttention,以及利用多裝置的分散式系統,如RingAttention。

FlashAttention透過使用靜態隨機儲存器(SRAM)來儲存中間狀態,而不是依賴於高頻寬儲存器(HBM)來提高注意力計算速度。

而RingAttention透過將長序列劃分為子序列並將其分佈在多個裝置上進行並行處理來處理長序列。

雖然它們都提高了處理速度和效率,如果將它們組合起來使用是否可以有更大的提高呢?理論上是這樣,但是在分散式環境中直接組合這兩種方法無法充分利用它們的優勢,並且存在相容性問題。

而最新的研究BurstAttention可以將2者結合,作為RingAttention和FlashAttention之間的橋樑。

BurstAttention是一個創新的框架,它最佳化了跨裝置的計算和通訊,增強了記憶體使用,最小化了通訊開銷,提高了快取效率。

https://avoid.overfit.cn/post/5aacdef85b104ff0a9faea9ad84f2a95

相關文章