pytorch多維張量相乘和廣播機制示例

Picassooo發表於2024-10-26

示例:

import torch 
box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]],

        [[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32)

wh = torch.tensor([[[200.],
         [400.],
         [200.],
         [400.]],

        [[200.],
         [400.],
         [200.],
         [400.]]]).to(torch.float32)

print(box.shape)  # (2, 3 ,4)
print(wh.shape)  # (2, 4, 1)

result = box @ wh
print(result.shape)  # (2, 3, 1)
print(result)
# tensor([[[320.],
#          [900.],
#          [180.]],

#         [[320.],
#          [900.],
#          [180.]]])

  

下面這個示例用到了廣播機制:

import torch 
box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]],

        [[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32)

wh = torch.tensor([[[200.],
         [400.],
         [200.],
         [400.]]]).to(torch.float32)

print(box.shape)  # (2, 3 ,4)
print(wh.shape)  # (1, 4, 1)   注意這個wh的第0維度的大小是1

result = box @ wh  # 這裡在第0維度會使用廣播機制
print(result.shape)  # (2, 3, 1)
print(result)
# tensor([[[320.],
#          [900.],
#          [180.]],

#         [[320.],
#          [900.],
#          [180.]]])

  

相關文章