torch 處理 tensor 張量的廣播,使用 einsum 函式,摘錄一段使用程式碼,並分析用法
# In[6]:
img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
batch_gray_weighted_fancy.shape
# Out[6]:
torch.Size([2, 5, 5])
這段程式碼利用 einsum
函式來進行張量運算,並將每個通道的影像加權轉換為灰度影像。einsum
的使用使得程式碼簡潔且易於理解。:
程式碼解讀
-
img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
img_t
是一個影像張量,其形狀假設為(3, 5, 5)
,表示3個通道的5x5影像(如RGB影像)。weights
是一個權重張量,其形狀為(3,)
,對應於每個通道的權重。'...chw'
表示任意形狀的張量(...
),其中c
是通道維度,h
是高度維度,w
是寬度維度。因為沒有前導維度,img_t
的形狀具體為(3, 5, 5)
。'c->...hw'
表示將通道維度c
與weights
的權重值相乘,並將其求和後,留下高度和寬度維度(h, w
)。最終結果是一個沒有通道維度的二維張量,其形狀為(5, 5)
,即灰度影像。
總結:
img_gray_weighted_fancy
是一個5x5的灰度影像,原來的RGB影像被weights
加權後生成。 -
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
batch_t
是一個批次影像張量,假設形狀為(2, 3, 5, 5)
,表示2個樣本、每個樣本有3個通道(如RGB),每個通道是5x5的影像。- 依然使用
weights
作為權重張量,形狀為(3,)
,對應每個通道的權重。 '...chw'
這裡的...
表示前導的維度,即批次維度,這個索引表示式允許處理不同大小的張量。'c->...hw'
指示要將通道維度c
加權求和,結果保留批次、高度和寬度維度。對於批次影像,每個樣本都會生成一個灰度影像。
總結:
batch_gray_weighted_fancy
的形狀為(2, 5, 5)
,表示2個樣本對應的5x5灰度影像。 -
batch_gray_weighted_fancy.shape
- 列印
batch_gray_weighted_fancy
的形狀,得到torch.Size([2, 5, 5])
。
總結:最終輸出的形狀表明,處理後的批次影像變為2個5x5的灰度影像,每個樣本的RGB影像都透過通道加權轉換為了灰度影像。
- 列印
總覽
einsum
表示式'...chw,c->...hw'
的含義是:對張量中指定的通道維度c
進行加權求和,去除通道維度,保留批次、行和列的維度。- 這段程式碼透過
einsum
函式簡化了將彩色影像(多通道)轉換為灰度影像的過程,避免了顯式的廣播和求和操作,極大地提升了程式碼的簡潔性和可讀性。
提問
上述程式碼"...chw,c->...hw"中,為什麼兩次使用 ... ,若變為"chw,c->...hw"可以嗎?
回答:
在 einsum
表示式中,...
是一個萬用字元,表示任意數量的前導維度。這裡的 ...
可以匹配0個或多個維度。表示式中的 ...chw,c->...hw
使用 ...
是為了處理任意數量的前導維度,而不僅僅是一個具體的維度。
為什麼使用 ...
兩次?
-
'...chw,c->...hw'
中的...chw
:- 這裡的
...
可以匹配batch_t
的前導維度。例如,如果batch_t
的形狀是(2, 3, 5, 5)
,那麼...
匹配的就是2
。 chw
表示通道、高度、寬度,其中c
是通道維度,h
是高度維度,w
是寬度維度。
- 這裡的
-
'...chw,c->...hw'
中的...hw
:- 表示結果張量中保留前導維度
...
和高度、寬度hw
,但去除了通道維度c
。
- 表示結果張量中保留前導維度
這意味著 ...
萬用字元可以讓這個表示式更通用,無論張量前面有多少個維度,它都可以適用。比如:
- 對於一個形狀為
(3, 5, 5)
的影像張量,...chw
會匹配(3, 5, 5)
,...hw
會匹配(5, 5)
。 - 對於一個形狀為
(2, 3, 5, 5)
的批次影像張量,...chw
會匹配(2, 3, 5, 5)
,...hw
會匹配(2, 5, 5)
。
如果改為 'chw,c->...hw'
可以嗎?
不可以。原因如下:
-
沒有萬用字元
...
來匹配前導維度:- 如果你使用
'chw,c->...hw'
,這裡chw
固定表示只有3個維度:通道、高度、寬度。 - 這個表示式不能匹配有前導維度的情況(如批次維度),即它只能處理形狀精確為
(3, h, w)
的張量。
- 如果你使用
-
無法處理批次張量:
- 如果
batch_t
的形狀是(2, 3, 5, 5)
,chw
不能正確處理這個形狀,因為它缺少了批次維度(2)
,結果會導致錯誤。 ...
萬用字元的使用允許einsum
表示式能夠自動適應任意數量的前導維度(例如批次維度),使得程式碼更加通用。
- 如果
總結
使用 ...
萬用字元兩次的原因是為了讓這個 einsum
表示式適用於不同的張量形狀,無論是否有前導維度(如批次維度)。這使得程式碼更加靈活和通用。如果你去掉其中的一個 ...
,表示式的適用性將變得更窄,無法處理具有額外維度的張量。
問:...chw,c->...hw 變成 ...chw,c->hw,會有什麼後果
答:
如果將 einsum
表示式從 ...chw,c->...hw
改為 ...chw,c->hw
,會導致運算中前導維度(如批次維度)被丟棄或移除,從而改變了結果張量的形狀和語義。下面詳細解釋其後果。
原始表示式:...chw,c->...hw
...chw
:匹配任意數量的前導維度(...
),以及通道、高度、寬度。c->...hw
:消除了通道維度c
,保留了前導維度...
和高度、寬度hw
。
對於形狀為 (2, 3, 5, 5)
的 batch_t
張量(其中 2
是批次維度),運算後得到形狀為 (2, 5, 5)
的張量。前導維度(批次維度)被保留。
修改後的表示式:...chw,c->hw
...chw
:同樣匹配任意數量的前導維度以及通道、高度、寬度。c->hw
:這裡的hw
只保留了高度和寬度維度,完全忽略了前導維度...
。
後果:
-
前導維度丟失:
- 前導維度
...
(例如批次維度)被丟棄。 - 如果輸入張量有前導維度(如批次維度
2
),這些維度在結果中將不再存在。
- 前導維度
-
結果形狀:
- 假設
batch_t
的形狀是(2, 3, 5, 5)
:...chw
中的...
匹配2
(批次維度)。- 由於結果只保留
hw
(高度和寬度),最終輸出的形狀將是(5, 5)
,即沒有了批次維度。
- 這意味著,無論輸入張量的前導維度是什麼,最終結果都是將每個批次的影像摺疊成一個單獨的高度-寬度影像,導致批次資訊丟失。
- 假設
-
語義混亂:
- 在深度學習處理中,保留批次維度通常是非常重要的。批次維度的丟失意味著無法將結果與原始輸入資料一一對應。
- 如果你期望處理一組影像並保留每張影像的結果,但由於誤操作丟失了批次維度,那麼在後續步驟中處理這些結果會非常混亂。
結論
將 einsum
表示式從 ...chw,c->...hw
改為 ...chw,c->hw
會導致丟失前導維度(如批次維度)。如果前導維度被丟棄,輸出將不再保留批次的結構資訊,這在許多情況下可能是不可取的。通常你需要保留前導維度,除非你的具體應用場景明確不需要它們。