【雜學】大模型推理加速 —— KV-cache 技術

KeanShi發表於2024-11-13

如果不熟悉 Transformer 的同學可以點選這裡瞭解

自從《Attention Is All You Need》問世以來,Transformer 已經成為了 LLM 中最基礎的架構,被廣泛使用。KV-cache 是大模型推理加速的關鍵技術之一,已經成為了 Transformer 標配的功能,不過其只能用於 Decoder 結構:由於 Decoder 中有 Mask 機制,推理的時候前面的詞不需要與後面的詞計算 attention score,因此KV 矩陣可以被快取起來用於多次計算。

例如,我們要生成 "I love Tianjin University" 這句話。首先是隻有開頭標記 <s>,計算過程如下圖所示(為了便於理解,我們將 softmax 和 scale 去掉):

最終第一步的注意力 \(\text{Att}_{step1}\) 計算公式為:

\[{\color{red}\text{Att}_1}(Q,K,V)=({\color{red}{Q_1}}K_1^T)\overrightarrow{V_1} \]

此時序列中的詞為 "<s> I",由於有 Mask 機制,第二步計算如下圖所示:

第二步的注意力為:

\[\begin{aligned} \text{Att}_{step2}&=\begin{bmatrix}{\color{red}{Q_1}}K_{1}^{T}&0\\{\color{green}{Q_2}}K_{1}^{T}&{\color{green}{Q_2}}K_{2}^{T}\end{bmatrix}\begin{bmatrix}\overrightarrow{V_{1}}\\\overrightarrow{V_{2}}\end{bmatrix} =\begin{bmatrix}{\color{red}{Q_1}}K_1^T\times\overrightarrow{V1}\\{\color{green}{Q_2}}K_1^T\times\overrightarrow{V1}+{\color{green}{Q_2}}K_2^T\times\overrightarrow{V2}\end{bmatrix} \end{aligned}\]

\(\text{Att}_1\) 是第一行,\(\text{Att}_2\) 是第二行,則有:

\[\begin{aligned}&{\color{red}\text{Att}_1}(Q,K,V)={\color{red}{Q_1}}K_1^T\overrightarrow{V_1}\\&{\color{green}\text{Att}_2}(Q,K,V)={\color{green}{Q_2}}K_1^T\overrightarrow{V_1}+{\color{green}{Q_2}}K_2^T\overrightarrow{V_2}\end{aligned} \]

此時我們可以大膽猜想:

  • \(\text{Att}_k\) 只與 \(Q_k\) 有關
  • 已經計算出的 \(\text{Att}\) 永遠都不會改變

帶著這個猜想,繼續生成下面的詞,容易計算第三步可以得到:

\[\begin{aligned}&{\color{red}\text{Att}_1}(Q,K,V)={\color{red}{Q_1}}K_1^T\overrightarrow{V_1}\\&{\color{green}\text{Att}_2}(Q,K,V)={\color{green}{Q_2}}K_1^T\overrightarrow{V_1}+{\color{green}{Q_2}}K_2^T\overrightarrow{V_2}\\&{\color{blue}\text{Att}_3}(Q,K,V)={\color{blue}{Q_3}}K_1^T\overrightarrow{V_1}+{\color{blue}{Q_3}}K_2^T\overrightarrow{V_2}+{\color{blue}{Q_3}}K_3^T\overrightarrow{V_3}\end{aligned} \]

同樣的,\(\text{Att}_k\) 只與 \(Q_k\) 有關。第四步也相同:

看上面的圖和公式,我們可以歸納出性質:

  1. 樸素的 Attention 計算存在大量冗餘
  2. \(\text{Att}_k\) 只與 \(Q_k\) 有關,即預測詞 \(x_k\) 僅依賴於 \(x_{k-1}\)
  3. \(K\)\(V\) 全程參與計算,可以快取起來
  4. 雖然叫做 KV-cache,但其實真正最佳化掉的是冗餘的 \(Q\)\(\text{Att}\)

當然,這有點類似於動態規劃思想,也存在利用空間換取時間的問題,因此當序列很長時,KV-cache 有可能會出現記憶體爆炸的情況。

下面附上 gpt 的 KV-cache 程式碼,非常簡單,僅僅是做了 concat 操作。不過值得注意的是,attention 的計算並沒有使用 cache。

if layer_past is not None:
        past_key, past_value = layer_past
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)
    
    if use_cache is True:
        present = (key, value)
    else:
        present = None
    
    if self.reorder_and_upcast_attn:
        attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
    else:
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

相關文章