einsum函式說明
pytorch文件說明:\(torch.einsum(equation, **operands)\) 使用基於愛因斯坦求和約定的符號,將輸入operands的元素沿指定的維數求和。einsum允許計算許多常見的多維線性代數陣列運算,方法是基於愛因斯坦求和約定以簡寫格式表示它們。主要是省略了求和號,總體思路是在箭頭左邊用一些下標標記輸入operands的每個維度,並在箭頭右邊定義哪些下標是輸出的一部分。通過將operands元素與下標不屬於輸出的維度的乘積求和來計算輸出。其方便之處在於可以直接通過求和公式寫出運算程式碼。
# 矩陣乘法例子引入
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等價操作 torch.mm(a, b)
兩個基本概念,自由索引/自由標(Free indices)和求和索引/啞標(Summation indices):
- 自由索引,出現在箭頭右邊的索引
- 求和索引,只出現在箭頭左邊的索引,表示中間計算結果需要這個維度上求和之後才能得到輸出,
接著是介紹三條基本規則:
- 規則一,equation 箭頭左邊,在不同輸入之間重複出現的索引表示,把輸入張量沿著該維度做乘法操作,比如還是以上面矩陣乘法為例, "ik,kj->ij",k 在輸入中重複出現,所以就是把 a 和 b 沿著 k 這個維度作相乘操作;
- 規則二,只出現在 equation 箭頭左邊的索引,表示中間計算結果需要在這個維度上求和,也就是上面提到的求和索引;
- 規則三,equation 箭頭右邊的索引順序可以是任意的,比如上面的 "ik,kj->ij" 如果寫成 "ik,kj->ji",那麼就是返回輸出結果的轉置,使用者只需要定義好索引的順序,轉置操作會在 einsum 內部完成。
兩條特殊規則:
- equation 可以不寫包括箭頭在內的右邊部分,那麼在這種情況下,輸出張量的維度會根據預設規則推導。就是把輸入中只出現一次的索引取出來,然後按字母表順序排列,比如上面的矩陣乘法 "ik,kj->ij" 也可以簡化為 "ik,kj",根據預設規則,輸出就是 "ij" 與原來一樣;
- equation 中支援 "..." 省略號,用於表示使用者並不關心的索引,詳見下方轉置例子
單運算元
獲取對角線元素diagonal
einsum 可以不做求和。舉個例子,獲取二維方陣的對角線元素,結果放入一維向量。
上面,A 是一維向量,B 是二維方陣。使用 einsum 記法,可以寫作 ii->i
torch.einsum('ii->i', torch.randn(4, 4))
# 以下操作互相等價
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)
跡trace
求解矩陣的跡(trace),即對角線元素的和。
t 是常量,A 是二維方陣。按照前面的做法,省略 ΣΣ,左右兩邊對調,省去矩陣和 t,剩下的就是ii->
或省略箭頭ii
torch.einsum('ii', torch.randn(4, 4))
矩陣轉置
A 和 B 都是二維方陣。einsum 可以表達為 ij->ji
。
torch.einsum('ij -> ji',a)
pytorch 中,還支援省略前面的維度。比如,只轉置最後兩個維度,可以表達為 ...ij->...ji
。下面展示了一個含有四個二維矩陣的三維矩陣,轉置三維矩陣中的每個二維矩陣。
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4])
# 等價操作
A.permute(0,1,3,2)
A.transpose(2,3)
求和
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)
列求和:
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.])
# 等價操作
torch.sum(a, 0) # (dim引數0) means the dimension or dimensions to reduce.
雙運算元
矩陣乘法
第一個學習的 einsum 表示式是,ik,kj->ij
。前面提到過,愛因斯坦求和記法可以理解為懶人求和記法。將上述公式中的 ΣΣ 去掉,並且將左右兩邊對調一下,省去矩陣之後,剩下的就是 ik,kj->ij
了。
torch.einsum('ik,kj->ij', a, b)
# 可用兩個矩陣測試以下矩陣乘法操作互相等價
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b
矩陣-向量相乘
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])
tensor([ 5., 14.])
批量矩陣乘 batch matrix multiplication
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]],
[[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]],
[[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]])
# 等價操作
torch.bmm(As, Bs)
向量內積 dot
a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.)
# 等價操作
torch.dot(a, b)
矩陣內積 dot
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)
哈達瑪積
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])
外積 outer
a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b])
tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])
einsum其他規則和例子判斷:
- 輸入中多次出現的字元,將被用作求和。例子,
kj,ji
完整的表示式是kj,ji->ik
,矩陣乘法再相乘。 - 輸出可以指定,但是輸出中的每個字元必須在輸入中出現至少一次,輸出的每個字元在輸出中只能出現最多一次。例子,
ab->aa
是非法的,ab->c
是非法的,ab->a
是合法的。 - 省略符
...
是用來跳過部分維度。例子,...ij,...jk
表示 batch 矩陣乘法。 - 在輸出沒有指定的情況下,省略符優先順序高於普通字元。例子,
b...a
完整的表示式是b...a->...ab
,可以將一個形狀為(a,b,c)
的矩陣變為形狀為(b,c,a)
的矩陣。 - 允許多個矩陣輸入,表示式中使用逗號分開不同矩陣輸入的下標。例子,
i,i,i
表示將三個一維向量按位相乘,並相加。 - 除了箭頭,其他任何地方都可以加空格。例子,
i j , j k -> ik
是合法的,ij,jk - > ik
是非法的。 - 輸入的表示式,維度需要和輸入的矩陣對上,不能多也不能少。比如一個 shape 為
(4,3,3)
的矩陣,表示式ab->a
是非法的,abc->
是合法的。
實際使用
實現multi headed attention
https://nn.labml.ai/transformers/mha.html
如何優雅地實現多頭自注意力
計算注意力score:
# q k v均為 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解為ibhd,jbhd->ibhj->ijbh
計算attention輸出:
# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k]
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]
參考文獻:
https://zhuanlan.zhihu.com/p/361209187
如何優雅地實現多頭自注意力
https://rockt.github.io/2018/04/30/einsum **