einsum函式介紹-張量常用操作

MapleTx發表於2022-05-08

einsum函式說明

pytorch文件說明:\(torch.einsum(equation, **operands)\) 使用基於愛因斯坦求和約定的符號,將輸入operands的元素沿指定的維數求和。einsum允許計算許多常見的多維線性代數陣列運算,方法是基於愛因斯坦求和約定以簡寫格式表示它們。主要是省略了求和號,總體思路是在箭頭左邊用一些下標標記輸入operands的每個維度,並在箭頭右邊定義哪些下標是輸出的一部分。通過將operands元素與下標不屬於輸出的維度的乘積求和來計算輸出。其方便之處在於可以直接通過求和公式寫出運算程式碼。

# 矩陣乘法例子引入
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等價操作 torch.mm(a, b)

兩個基本概念,自由索引/自由標(Free indices)和求和索引/啞標(Summation indices):

  • 自由索引,出現在箭頭右邊的索引
  • 求和索引,只出現在箭頭左邊的索引,表示中間計算結果需要這個維度上求和之後才能得到輸出,

接著是介紹三條基本規則:

  • 規則一,equation 箭頭左邊,在不同輸入之間重複出現的索引表示,把輸入張量沿著該維度做乘法操作,比如還是以上面矩陣乘法為例, "ik,kj->ij",k 在輸入中重複出現,所以就是把 a 和 b 沿著 k 這個維度作相乘操作;
  • 規則二,只出現在 equation 箭頭左邊的索引,表示中間計算結果需要在這個維度上求和,也就是上面提到的求和索引;
  • 規則三,equation 箭頭右邊的索引順序可以是任意的,比如上面的 "ik,kj->ij" 如果寫成 "ik,kj->ji",那麼就是返回輸出結果的轉置,使用者只需要定義好索引的順序,轉置操作會在 einsum 內部完成。

兩條特殊規則:

  • equation 可以不寫包括箭頭在內的右邊部分,那麼在這種情況下,輸出張量的維度會根據預設規則推導。就是把輸入中只出現一次的索引取出來,然後按字母表順序排列,比如上面的矩陣乘法 "ik,kj->ij" 也可以簡化為 "ik,kj",根據預設規則,輸出就是 "ij" 與原來一樣;
  • equation 中支援 "..." 省略號,用於表示使用者並不關心的索引,詳見下方轉置例子

單運算元

獲取對角線元素diagonal

einsum 可以不做求和。舉個例子,獲取二維方陣的對角線元素,結果放入一維向量。

\[A_i = B_{ii} \]

上面,A 是一維向量,B 是二維方陣。使用 einsum 記法,可以寫作 ii->i

torch.einsum('ii->i', torch.randn(4, 4))

# 以下操作互相等價
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)


跡trace

求解矩陣的跡(trace),即對角線元素的和。

\[t = \Sigma_{i=1}^{n} A_{ii} \]

t 是常量,A 是二維方陣。按照前面的做法,省略 ΣΣ,左右兩邊對調,省去矩陣和 t,剩下的就是ii->或省略箭頭ii

torch.einsum('ii', torch.randn(4, 4))

矩陣轉置

\[A_{ij} = B_{ji} \]

A 和 B 都是二維方陣。einsum 可以表達為 ij->ji

torch.einsum('ij -> ji',a)

pytorch 中,還支援省略前面的維度。比如,只轉置最後兩個維度,可以表達為 ...ij->...ji。下面展示了一個含有四個二維矩陣的三維矩陣,轉置三維矩陣中的每個二維矩陣。

A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4])

# 等價操作
A.permute(0,1,3,2)
A.transpose(2,3)

求和

\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j} \]

a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)

列求和:

\[b_{j}=\sum_{i} A_{i j}=A_{i j} \]

a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3.,  5.,  7.])

# 等價操作
torch.sum(a, 0) # (dim引數0) means the dimension or dimensions to reduce.

雙運算元

矩陣乘法

\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj} \]

第一個學習的 einsum 表示式是,ik,kj->ij。前面提到過,愛因斯坦求和記法可以理解為懶人求和記法。將上述公式中的 ΣΣ 去掉,並且將左右兩邊對調一下,省去矩陣之後,剩下的就是 ik,kj->ij 了。

torch.einsum('ik,kj->ij', a, b) 

# 可用兩個矩陣測試以下矩陣乘法操作互相等價
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b) 
c = torch.mm(a, b) 
c = a @ b

矩陣-向量相乘

\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k} \]

a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])

tensor([  5.,  14.])

批量矩陣乘 batch matrix multiplication

\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk} \]

>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

# 等價操作
torch.bmm(As, Bs)

向量內積 dot

\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i} \]

a = torch.arange(3)
b = torch.arange(3,6)  # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.)

# 等價操作
torch.dot(a, b)

矩陣內積 dot

\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j} \]

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)

哈達瑪積

\[C_{i j}=A_{i j} B_{i j} \]

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[  0.,   7.,  16.],
        [ 27.,  40.,  55.]])

外積 outer

\[C_{i j}=a_{i} b_{j} \]

a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b])

tensor([[  0.,   0.,   0.,   0.],
        [  3.,   4.,   5.,   6.],
        [  6.,   8.,  10.,  12.]])

einsum其他規則和例子判斷:

  • 輸入中多次出現的字元,將被用作求和。例子,kj,ji 完整的表示式是 kj,ji->ik,矩陣乘法再相乘。
  • 輸出可以指定,但是輸出中的每個字元必須在輸入中出現至少一次,輸出的每個字元在輸出中只能出現最多一次。例子,ab->aa 是非法的,ab->c 是非法的,ab->a 是合法的。
  • 省略符 ... 是用來跳過部分維度。例子,...ij,...jk 表示 batch 矩陣乘法。
  • 在輸出沒有指定的情況下,省略符優先順序高於普通字元。例子,b...a 完整的表示式是 b...a->...ab,可以將一個形狀為 (a,b,c) 的矩陣變為形狀為 (b,c,a) 的矩陣。
  • 允許多個矩陣輸入,表示式中使用逗號分開不同矩陣輸入的下標。例子,i,i,i 表示將三個一維向量按位相乘,並相加。
  • 除了箭頭,其他任何地方都可以加空格。例子,i j , j k -> ik 是合法的,ij,jk - > ik 是非法的。
  • 輸入的表示式,維度需要和輸入的矩陣對上,不能多也不能少。比如一個 shape 為 (4,3,3) 的矩陣,表示式 ab->a 是非法的,abc-> 是合法的。

實際使用

實現multi headed attention

https://nn.labml.ai/transformers/mha.html
如何優雅地實現多頭自注意力

計算注意力score:

\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d} \]

# q k v均為 [seq_len, batch_size, heads, d_k] 
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解為ibhd,jbhd->ibhj->ijbh

計算attention輸出:

\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V \]

# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k] 

x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]


參考文獻:

https://zhuanlan.zhihu.com/p/361209187
如何優雅地實現多頭自注意力
https://rockt.github.io/2018/04/30/einsum **

相關文章