torch.einsum 的用法例項

立体风發表於2024-08-09

torch 處理 tensor 張量的廣播,使用 einsum 函式,摘錄一段使用程式碼,並分析用法

# In[6]:
img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
batch_gray_weighted_fancy.shape

# Out[6]:
torch.Size([2, 5, 5])

這段程式碼利用 einsum 函式來進行張量運算,並將每個通道的影像加權轉換為灰度影像。einsum 的使用使得程式碼簡潔且易於理解。:

程式碼解讀

  1. img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)

    • img_t 是一個影像張量,其形狀假設為 (3, 5, 5),表示3個通道的5x5影像(如RGB影像)。
    • weights 是一個權重張量,其形狀為 (3,),對應於每個通道的權重。
    • '...chw' 表示任意形狀的張量(...),其中 c 是通道維度,h 是高度維度,w 是寬度維度。因為沒有前導維度,img_t 的形狀具體為 (3, 5, 5)
    • 'c->...hw' 表示將通道維度 cweights 的權重值相乘,並將其求和後,留下高度和寬度維度(h, w)。最終結果是一個沒有通道維度的二維張量,其形狀為 (5, 5),即灰度影像。

    總結img_gray_weighted_fancy 是一個5x5的灰度影像,原來的RGB影像被 weights 加權後生成。

  2. batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)

    • batch_t 是一個批次影像張量,假設形狀為 (2, 3, 5, 5),表示2個樣本、每個樣本有3個通道(如RGB),每個通道是5x5的影像。
    • 依然使用 weights 作為權重張量,形狀為 (3,),對應每個通道的權重。
    • '...chw' 這裡的 ... 表示前導的維度,即批次維度,這個索引表示式允許處理不同大小的張量。
    • 'c->...hw' 指示要將通道維度 c 加權求和,結果保留批次、高度和寬度維度。對於批次影像,每個樣本都會生成一個灰度影像。

    總結batch_gray_weighted_fancy 的形狀為 (2, 5, 5),表示2個樣本對應的5x5灰度影像。

  3. batch_gray_weighted_fancy.shape

    • 列印 batch_gray_weighted_fancy 的形狀,得到 torch.Size([2, 5, 5])

    總結:最終輸出的形狀表明,處理後的批次影像變為2個5x5的灰度影像,每個樣本的RGB影像都透過通道加權轉換為了灰度影像。

總覽

  • einsum 表示式 '...chw,c->...hw' 的含義是:對張量中指定的通道維度 c 進行加權求和,去除通道維度,保留批次、行和列的維度。
  • 這段程式碼透過 einsum 函式簡化了將彩色影像(多通道)轉換為灰度影像的過程,避免了顯式的廣播和求和操作,極大地提升了程式碼的簡潔性和可讀性。

提問

上述程式碼"...chw,c->...hw"中,為什麼兩次使用 ... ,若變為"chw,c->...hw"可以嗎?

回答:

einsum 表示式中,... 是一個萬用字元,表示任意數量的前導維度。這裡的 ... 可以匹配0個或多個維度。表示式中的 ...chw,c->...hw 使用 ... 是為了處理任意數量的前導維度,而不僅僅是一個具體的維度。

為什麼使用 ... 兩次?

  1. '...chw,c->...hw' 中的 ...chw

    • 這裡的 ... 可以匹配 batch_t 的前導維度。例如,如果 batch_t 的形狀是 (2, 3, 5, 5),那麼 ... 匹配的就是 2
    • chw 表示通道、高度、寬度,其中 c 是通道維度,h 是高度維度,w 是寬度維度。
  2. '...chw,c->...hw' 中的 ...hw

    • 表示結果張量中保留前導維度 ... 和高度、寬度 hw,但去除了通道維度 c

這意味著 ... 萬用字元可以讓這個表示式更通用,無論張量前面有多少個維度,它都可以適用。比如:

  • 對於一個形狀為 (3, 5, 5) 的影像張量,...chw 會匹配 (3, 5, 5)...hw 會匹配 (5, 5)
  • 對於一個形狀為 (2, 3, 5, 5) 的批次影像張量,...chw 會匹配 (2, 3, 5, 5)...hw 會匹配 (2, 5, 5)

如果改為 'chw,c->...hw' 可以嗎?

不可以。原因如下:

  1. 沒有萬用字元 ... 來匹配前導維度

    • 如果你使用 'chw,c->...hw',這裡 chw 固定表示只有3個維度:通道、高度、寬度。
    • 這個表示式不能匹配有前導維度的情況(如批次維度),即它只能處理形狀精確為 (3, h, w) 的張量。
  2. 無法處理批次張量

    • 如果 batch_t 的形狀是 (2, 3, 5, 5)chw 不能正確處理這個形狀,因為它缺少了批次維度 (2),結果會導致錯誤。
    • ... 萬用字元的使用允許 einsum 表示式能夠自動適應任意數量的前導維度(例如批次維度),使得程式碼更加通用。

總結

使用 ... 萬用字元兩次的原因是為了讓這個 einsum 表示式適用於不同的張量形狀,無論是否有前導維度(如批次維度)。這使得程式碼更加靈活和通用。如果你去掉其中的一個 ...,表示式的適用性將變得更窄,無法處理具有額外維度的張量。

問:...chw,c->...hw 變成 ...chw,c->hw,會有什麼後果

答:
如果將 einsum 表示式從 ...chw,c->...hw 改為 ...chw,c->hw,會導致運算中前導維度(如批次維度)被丟棄或移除,從而改變了結果張量的形狀和語義。下面詳細解釋其後果。

原始表示式:...chw,c->...hw

  • ...chw:匹配任意數量的前導維度(...),以及通道、高度、寬度。
  • c->...hw:消除了通道維度 c,保留了前導維度 ... 和高度、寬度 hw

對於形狀為 (2, 3, 5, 5)batch_t 張量(其中 2 是批次維度),運算後得到形狀為 (2, 5, 5) 的張量。前導維度(批次維度)被保留。

修改後的表示式:...chw,c->hw

  • ...chw:同樣匹配任意數量的前導維度以及通道、高度、寬度。
  • c->hw:這裡的 hw 只保留了高度和寬度維度,完全忽略了前導維度 ...

後果:

  1. 前導維度丟失

    • 前導維度 ...(例如批次維度)被丟棄。
    • 如果輸入張量有前導維度(如批次維度 2),這些維度在結果中將不再存在。
  2. 結果形狀

    • 假設 batch_t 的形狀是 (2, 3, 5, 5)
      • ...chw 中的 ... 匹配 2(批次維度)。
      • 由於結果只保留 hw(高度和寬度),最終輸出的形狀將是 (5, 5),即沒有了批次維度。
    • 這意味著,無論輸入張量的前導維度是什麼,最終結果都是將每個批次的影像摺疊成一個單獨的高度-寬度影像,導致批次資訊丟失。
  3. 語義混亂

    • 在深度學習處理中,保留批次維度通常是非常重要的。批次維度的丟失意味著無法將結果與原始輸入資料一一對應。
    • 如果你期望處理一組影像並保留每張影像的結果,但由於誤操作丟失了批次維度,那麼在後續步驟中處理這些結果會非常混亂。

結論

einsum 表示式從 ...chw,c->...hw 改為 ...chw,c->hw 會導致丟失前導維度(如批次維度)。如果前導維度被丟棄,輸出將不再保留批次的結構資訊,這在許多情況下可能是不可取的。通常你需要保留前導維度,除非你的具體應用場景明確不需要它們。

相關文章