Scaled Dot-Product Attention 的公式中為什麼要除以 $sqrt{d_k}$?

赤川鹤鸣發表於2024-10-22

Scaled Dot-Product Attention 的公式中為什麼要除以 \(\sqrt{d_k}\)

在學習 Scaled Dot-Product Attention 的過程中,遇到了如下公式

\[ \mathrm{Attention} (\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{softmax} \left( \dfrac{\mathbf{Q} \mathbf{K}}{\sqrt{d_k}} \right) \mathbf{V} \]

不禁產生疑問,其中的 \(\sqrt{d_k}\) 為什麼是這個數,而不是 \(d_k\) 或者其它的什麼值呢?

Attention Is All You Need 中有一段解釋

We suspect that for large values of \(d_k\), the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by \(\sqrt{d_k}\).

這說明,兩個向量的點積可能很大,導致 softmax 函式的梯度太小,因此需要除以一個因子,但是為什麼是 \(\sqrt{d_k}\) 呢?

文章中的一行註釋提及到

To illustrate why the dot products get large, assume that the components of \(\mathbf{q}\) and \(\mathbf{k}\) are independent random variables with mean \(0\) and variance \(1\). Then their dot product, $\mathbf{q} \cdot \mathbf{k} = \sum_{i=1}^{d_k} q_i k_i $ has mean \(0\) and variance \(d_k\).

本期,我們將基於上文的思路進行完整的推導,以證明 \(\sqrt{d_k}\) 的在其中的作用.

基本假設

假設獨立隨機變數 \(U_1 ,\ U_2 ,\ \dots ,\ U_{d_k}\) 和獨立隨機變數 \(V_1 ,\ V_2 ,\ \dots ,\ V_{d_k}\) 分別服從期望為 \(0\),方差為 \(1\) 的分佈,即

\[E \left(U_i \right) = 0 ,\ \mathrm{Var} \left(U_i \right) = 1 \]

\[E \left(V_i \right) = 0 ,\ \mathrm{Var} \left(V_i \right) = 1 \]

其中 \(i = 1, 2, \dots ,\ d_k\)\(d_k\) 是個常數.

計算 $U_i V_i $ 的方差

由隨機變數方差的定義可得 $ U_i V_i $ 的方差為

\[\begin{align*} \mathrm{Var} \left( U_i V_i \right) &= E \left[ \left( U_i V_i - E \left( U_i V_i \right) \right)^2\right] \\ &= E \left[ \left(U_i V_i \right)^2 - 2U_i V_i E \left( U_i V_i \right) + E^2 \left( U_i V_i \right)\right] \\ &= E \left[ \left( U_i V_i \right)^2 \right] - 2 E \left[ U_i V_i E \left( U_i V_i \right) \right] + E^2 \left(U_i V_i\right) \\ &= E \left( U_i^2 V_i^2 \right) - 2 E \left( U_i V_i \right) E \left( U_i V_i \right) + E^2 \left(U_i V_i\right) \\ &= E \left( U_i^2 V_i^2 \right) - E^2 \left( U_i V_i \right) \end{align*} \]

因為 \(U_i\)\(V_i\) 是獨立的隨機變數,所以

\[E \left( U_i V_i \right) = E \left( U_i \right) E \left( V_i \right) \]

從而

\[\begin{align*} \mathrm{Var} \left( U_i V_i \right) &= E\left(U_i^2\right) E\left(V_i^2\right) - \left(E\left(U_i\right) E\left(V_i\right) \right)^2 \\ &= E\left(U_i^2\right) E\left(V_i^2\right) - E^2\left(U_i\right) E^2\left(V_i\right) \end{align*} \]

又因為 \(E(U_i) = E(V_i) = 0\),所以

\[\mathrm{Var} \left( U_i V_i \right) = E(U_i^2) E(V_i^2) \]

計算 \(E(U_i^2)\)

因為

\[ E \left( U_i \right) = 0 \]

\[\mathrm{Var} \left( U_i \right) = 1 \]

\[\mathrm{Var} \left( U_i \right) = E \left( U_i^2 \right) - E^2 \left( U_i \right) \]

所以

\[E(U_i^2) = 1 \]

同理,

\[E(V_i^2) = 1 \]

計算 \(\mathbf{q} \mathbf{k}\) 的方差

如果 \(\mathbf{q} = \left[U_1, U_2, \cdots, U_{d_k} \right]^T\)\(\mathbf{k} = \left[V_1, V_2, \cdots, V_{d_k} \right]^T\),那麼

\[\mathbf{q} \mathbf{k} = \sum_{i=1}^{d_k} U_i V_i \]

\(\mathbf{q} \mathbf{k}\) 的方差

\[\begin{align*} \mathrm{Var}\left( \mathbf{q} \mathbf{k} \right) &= \mathrm{Var}\left( \sum_{i=1}^{d_k} U_i V_i \right) \\ &= \sum_{i=1}^{d_k} \mathrm{Var} \left( U_i V_i \right) \\ &= \sum_{i=1}^{d_k} E \left(U_i^2\right) E \left(V_i^2\right) \\ &= \sum_{i=1}^{d_k} 1 \cdot 1 \\ &= d_k \end{align*} \]

到這裡就可以解釋為什麼在最後要除以 \(\sqrt{d_k}\),因為

\[\begin{align*} \mathrm{Var}\left( \dfrac{\mathbf{q} \mathbf{k} }{\sqrt{d_k}} \right) &= \dfrac{\mathrm{Var}\left( \mathbf{q} \mathbf{k} \right)}{d_k} \\ &= \dfrac{d_k}{d_k} \\ &= 1 \end{align*} \]

可見這個因子的目的是讓 \(\mathbf{q} \mathbf{k}\) 的分佈也歸一化到期望為 \(0\),方差為 \(1\) 的分佈中,增強機器學習的穩定性.

參考文獻/資料

  • Attention Is All You Need

相關文章