torch.einsum 的計算過程

立体风發表於2024-08-09

概論

a = torch.randn(3, 2, 2)
b = torch.randn(3)
c = torch.einsum('...chw,c->...hw', a, b)

上面的 einsum 如何計算的?
簡單說,把 b 廣播為 a 的形狀,然後做矩陣乘法,即逐位相乘運算,注意,不是點積,是逐位的相乘運算。
然後,再把結果逐位相加後,得到結果,同時也去掉了維度c。

運算過程

具體運算細節如下:

為了詳細解釋 c = torch.einsum('...chw,c->...hw', a, b) 的計算過程,我們可以逐步分析每個部分的運算,並透過一個具體的例子說明結果的產生過程。

1. 張量 ab 的形狀與內容

  • a 是一個形狀為 (3, 2, 2) 的張量,假設其值為:
    a = torch.tensor([[[0.1, 0.2],
                       [0.3, 0.4]],
    
                      [[0.5, 0.6],
                       [0.7, 0.8]],
    
                      [[0.9, 1.0],
                       [1.1, 1.2]]])
    
  • b 是一個形狀為 (3,) 的張量,假設其值為:
    b = torch.tensor([2.0, 3.0, 4.0])
    

2. einsum 表示式 '...chw,c->...hw' 解析

  • ...chw:

    • ... 匹配任意數量的前導維度,在本例中沒有前導維度。
    • c 對應的是第一個維度(形狀為3)。
    • h 對應第二個維度(形狀為2)。
    • w 對應第三個維度(形狀為2)。
  • c->...hw:

    • c 對應 b 中的元素,作為縮並維度,它會在計算中被消除(透過乘法與求和操作)。
    • ...hw 表示最終保留高度(h)和寬度(w)維度。

3. 具體計算過程

步驟 1:廣播 b 以匹配 a 的形狀

張量 b(形狀 (3,))會被廣播到形狀 (3, 2, 2),與張量 a 對應的維度對齊:

b_broadcasted = torch.tensor([[[2.0, 2.0],
                               [2.0, 2.0]],

                              [[3.0, 3.0],
                               [3.0, 3.0]],

                              [[4.0, 4.0],
                               [4.0, 4.0]]])

步驟 2:逐元素相乘

a 和廣播後的 b 逐元素相乘,結果為:

result = a * b_broadcasted
result = torch.tensor([[[0.2, 0.4],
                        [0.6, 0.8]],

                       [[1.5, 1.8],
                        [2.1, 2.4]],

                       [[3.6, 4.0],
                        [4.4, 4.8]]])

步驟 3:沿著 c 維度求和

對上面的結果,沿著 c 維度(也就是第一個維度)求和:

  • (0.2, 1.5, 3.6) 這三個數在 c 維度上相加,得到 5.3
  • (0.4, 1.8, 4.0) 這三個數在 c 維度上相加,得到 6.2
  • 以此類推,求得最終結果 c
c = torch.tensor([[5.3, 6.2],
                  [7.1, 8.0]])

4. 總結結果

最終的結果 c 是一個形狀為 (2, 2) 的張量。每個元素是透過 a 中與 b 對應的元素逐元素相乘後,再沿 c 維度(通道維度)求和得到的。

因此,使用 torch.einsum('...chw,c->...hw', a, b) 後,得到的張量 c 對應的是去掉 c 維度後的加權和,僅保留高度 h 和寬度 w 維度的結果。

廣播運算

問:b 廣播為 b_broadcasted ,程式碼如何寫

在 PyTorch 中,可以使用 unsqueezeexpand 方法來將張量 b 廣播為與 a 形狀匹配的 b_broadcasted。具體程式碼如下:

import torch

# 原始張量 a 和 b
a = torch.randn(3, 2, 2)
b = torch.randn(3)

# 擴充套件 b 的維度,使其能夠廣播到與 a 形狀匹配
b_broadcasted = b.unsqueeze(-1).unsqueeze(-1).expand_as(a)

print(b_broadcasted)

步驟解釋:

  1. unsqueeze(-1):將 b 的形狀從 (3,) 變為 (3, 1),即在最後一維新增一個大小為 1 的新維度。
  2. unsqueeze(-1):再一次將形狀從 (3, 1) 變為 (3, 1, 1),在最後再新增一個大小為 1 的維度。
  3. expand_as(a):將 b 廣播到與 a 相同的形狀 (3, 2, 2)

結果示例:

假設 b 的值為 [2.0, 3.0, 4.0],那麼 b_broadcasted 將變為:

tensor([[[2.0, 2.0],
         [2.0, 2.0]],

        [[3.0, 3.0],
         [3.0, 3.0]],

        [[4.0, 4.0],
         [4.0, 4.0]]])

這就是張量 b 在計算過程中如何被廣播為 b_broadcasted 的程式碼實現。

相關文章