An Explanation of Named Tensors

Posted by 立体风 on 2024-08-10
import torch

Merge the three RGB color channels into a single grayscale channel.

Define the variables, using random numbers as a stand-in for real data.

img_t = torch.randn(3, 5, 5)
batch_t = torch.randn(2, 3, 5, 5)
weights = torch.randn(3)
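weights here is just random noise for the demo; in a real RGB-to-grayscale conversion it would hold fixed perceptual coefficients, for example the ITU-R BT.601 luma weights (shown only as an illustration, they are not used below):

# hypothetical real-world weights: ITU-R BT.601 luma coefficients (illustration only)
bt601_weights = torch.tensor([0.2989, 0.5870, 0.1140])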

The naive approach

Take the mean over the channel dimension as the grayscale value.

img_gray_naive = img_t.mean(-3)
batch_gray_naive = batch_t.mean(-3)
img_gray_naive.shape, batch_gray_naive.shape
(torch.Size([5, 5]), torch.Size([2, 5, 5]))
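The channel dimension is addressed as -3 (counting from the end) rather than 0 so that the same call works for both the unbatched (3, 5, 5) image and the batched (2, 3, 5, 5) tensor; a quick check:

# dim -3 is the channel dimension with or without a leading batch dimension
img_t.shape[-3], batch_t.shape[-3]
(3, 3)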

The weighted approach

The naive approach effectively gives every channel the same weight of 1/3. The weights tensor holds the true per-channel weights; each value in img_t must be multiplied by its channel's weight to get its true contribution to the grayscale value.
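As a small sanity check of the claim above (not part of the original code), a weighted sum with every channel weighted 1/3 reproduces the plain mean:

uniform = torch.full((3, 1, 1), 1 / 3)
torch.allclose(img_t.mean(-3), (img_t * uniform).sum(-3))
True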

First, expand weights so that it has the same dimensional structure as img_t.

unsqueeze_weights = weights.unsqueeze(-1).unsqueeze(-1)
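The two unsqueeze calls append two trailing dimensions of size 1, so the weights now have shape (3, 1, 1) and can broadcast against the (channels, rows, columns) layout of img_t:

unsqueeze_weights.shape
torch.Size([3, 1, 1])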

Next, multiply img_t by unsqueeze_weights element-wise (this is broadcast multiplication, not matrix multiplication), applying the weights to img_t.

img_weights = img_t * unsqueeze_weights
batch_weights = batch_t * unsqueeze_weights
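Broadcasting stretches the (3, 1, 1) weights across the row and column dimensions, and across the leading batch dimension for batch_t, so the products keep the shapes of the inputs:

img_weights.shape, batch_weights.shape
(torch.Size([3, 5, 5]), torch.Size([2, 3, 5, 5]))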

Finally, sum over the channel dimension to obtain the final grayscale values.

img_gray_weighted = img_weights.sum(-3)
batch_gray_weighted = batch_weights.sum(-3)

Checking the shapes shows that the color channel dimension is gone.

img_gray_weighted.shape, batch_gray_weighted.shape
(torch.Size([5, 5]), torch.Size([2, 5, 5]))

A more advanced approach: the Einstein summation convention

img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
img_gray_weighted_fancy.shape, batch_gray_weighted_fancy.shape
(torch.Size([5, 5]), torch.Size([2, 5, 5]))
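The subscripts '...chw,c->...hw' read: keep any leading dimensions as-is (...), multiply along the shared channel index c, and sum that index away, leaving only h and w. As a sanity check, the result should match the explicit broadcast-and-sum from above:

torch.allclose(img_gray_weighted, img_gray_weighted_fancy), torch.allclose(batch_gray_weighted, batch_gray_weighted_fancy)
(True, True)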

The named tensor approach

Use the refine_names function to name each dimension of a tensor.

weights_named = weights.refine_names(..., 'channels')
img_named = img_t.refine_names(..., 'channels', 'rows', 'columns')
batch_named = batch_t.refine_names(..., 'channels', 'rows', 'columns')
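refine_names assigns the given names from the last dimension backwards, and ... leaves any earlier dimensions (such as the batch dimension of batch_t) unnamed; the names attribute shows the result:

img_named.names, batch_named.names
(('channels', 'rows', 'columns'), (None, 'channels', 'rows', 'columns'))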

Align the one-dimensional weights_named tensor with the multi-dimensional tensor img_named.

weights_aligned = weights_named.align_as(img_named)
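align_as inserts singleton dimensions wherever a name from img_named is missing in weights_named, the same reshaping that unsqueeze did by hand, but driven by the dimension names:

weights_aligned.names, weights_aligned.shape
(('channels', 'rows', 'columns'), torch.Size([3, 1, 1]))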

Multiply element-wise, then sum over the channel dimension, referring to it by its name.

img_gray_named = (img_named * weights_aligned).sum('channels')
batch_gray_named = (batch_named * weights_aligned).sum('channels')
img_gray_named.shape, batch_gray_named.shape, weights_aligned.shape
(torch.Size([5, 5]), torch.Size([2, 5, 5]), torch.Size([3, 1, 1]))
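Named tensors are still a prototype feature and not every operation accepts them; if a downstream op rejects named dimensions, the names can be dropped again with rename(None) (a small sketch):

# drop all dimension names to get back a plain tensor
img_gray_plain = img_gray_named.rename(None)
img_gray_plain.names
(None, None)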
