概論
a = torch.randn(3, 2, 2)
b = torch.randn(3)
c = torch.einsum('...chw,c->...hw', a, b)
上面的 einsum 如何計算的?
簡單說,把 b 廣播為 a 的形狀,然後做矩陣乘法,即逐位相乘運算,注意,不是點積,是逐位的相乘運算。
然後,再把結果逐位相加後,得到結果,同時也去掉了維度c。
運算過程
具體運算細節如下:
為了詳細解釋 c = torch.einsum('...chw,c->...hw', a, b)
的計算過程,我們可以逐步分析每個部分的運算,並透過一個具體的例子說明結果的產生過程。
1. 張量 a
和 b
的形狀與內容
a
是一個形狀為(3, 2, 2)
的張量,假設其值為:a = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]], [[0.9, 1.0], [1.1, 1.2]]])
b
是一個形狀為(3,)
的張量,假設其值為:b = torch.tensor([2.0, 3.0, 4.0])
2. einsum
表示式 '...chw,c->...hw'
解析
-
...chw
:...
匹配任意數量的前導維度,在本例中沒有前導維度。c
對應的是第一個維度(形狀為3)。h
對應第二個維度(形狀為2)。w
對應第三個維度(形狀為2)。
-
c->...hw
:c
對應b
中的元素,作為縮並維度,它會在計算中被消除(透過乘法與求和操作)。...hw
表示最終保留高度(h)和寬度(w)維度。
3. 具體計算過程
步驟 1:廣播 b
以匹配 a
的形狀
張量 b
(形狀 (3,)
)會被廣播到形狀 (3, 2, 2)
,與張量 a
對應的維度對齊:
b_broadcasted = torch.tensor([[[2.0, 2.0],
[2.0, 2.0]],
[[3.0, 3.0],
[3.0, 3.0]],
[[4.0, 4.0],
[4.0, 4.0]]])
步驟 2:逐元素相乘
a
和廣播後的 b
逐元素相乘,結果為:
result = a * b_broadcasted
result = torch.tensor([[[0.2, 0.4],
[0.6, 0.8]],
[[1.5, 1.8],
[2.1, 2.4]],
[[3.6, 4.0],
[4.4, 4.8]]])
步驟 3:沿著 c
維度求和
對上面的結果,沿著 c
維度(也就是第一個維度)求和:
- 對
(0.2, 1.5, 3.6)
這三個數在c
維度上相加,得到5.3
。 - 對
(0.4, 1.8, 4.0)
這三個數在c
維度上相加,得到6.2
。 - 以此類推,求得最終結果
c
:
c = torch.tensor([[5.3, 6.2],
[7.1, 8.0]])
4. 總結結果
最終的結果 c
是一個形狀為 (2, 2)
的張量。每個元素是透過 a
中與 b
對應的元素逐元素相乘後,再沿 c
維度(通道維度)求和得到的。
因此,使用 torch.einsum('...chw,c->...hw', a, b)
後,得到的張量 c
對應的是去掉 c
維度後的加權和,僅保留高度 h
和寬度 w
維度的結果。
廣播運算
問:b 廣播為 b_broadcasted ,程式碼如何寫
答
在 PyTorch 中,可以使用 unsqueeze
和 expand
方法來將張量 b
廣播為與 a
形狀匹配的 b_broadcasted
。具體程式碼如下:
import torch
# 原始張量 a 和 b
a = torch.randn(3, 2, 2)
b = torch.randn(3)
# 擴充套件 b 的維度,使其能夠廣播到與 a 形狀匹配
b_broadcasted = b.unsqueeze(-1).unsqueeze(-1).expand_as(a)
print(b_broadcasted)
步驟解釋:
unsqueeze(-1)
:將b
的形狀從(3,)
變為(3, 1)
,即在最後一維新增一個大小為1
的新維度。unsqueeze(-1)
:再一次將形狀從(3, 1)
變為(3, 1, 1)
,在最後再新增一個大小為1
的維度。expand_as(a)
:將b
廣播到與a
相同的形狀(3, 2, 2)
。
結果示例:
假設 b
的值為 [2.0, 3.0, 4.0]
,那麼 b_broadcasted
將變為:
tensor([[[2.0, 2.0],
[2.0, 2.0]],
[[3.0, 3.0],
[3.0, 3.0]],
[[4.0, 4.0],
[4.0, 4.0]]])
這就是張量 b
在計算過程中如何被廣播為 b_broadcasted
的程式碼實現。