示例:
import torch box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000], [0.6000, 0.6000, 0.9000, 0.9000], [0.1000, 0.1000, 0.2000, 0.2000]], [[0.1000, 0.2000, 0.5000, 0.3000], [0.6000, 0.6000, 0.9000, 0.9000], [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32) wh = torch.tensor([[[200.], [400.], [200.], [400.]], [[200.], [400.], [200.], [400.]]]).to(torch.float32) print(box.shape) # (2, 3 ,4) print(wh.shape) # (2, 4, 1) result = box @ wh print(result.shape) # (2, 3, 1) print(result) # tensor([[[320.], # [900.], # [180.]], # [[320.], # [900.], # [180.]]])
下面這個示例用到了廣播機制:
import torch box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000], [0.6000, 0.6000, 0.9000, 0.9000], [0.1000, 0.1000, 0.2000, 0.2000]], [[0.1000, 0.2000, 0.5000, 0.3000], [0.6000, 0.6000, 0.9000, 0.9000], [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32) wh = torch.tensor([[[200.], [400.], [200.], [400.]]]).to(torch.float32) print(box.shape) # (2, 3 ,4) print(wh.shape) # (1, 4, 1) 注意這個wh的第0維度的大小是1 result = box @ wh # 這裡在第0維度會使用廣播機制 print(result.shape) # (2, 3, 1) print(result) # tensor([[[320.], # [900.], # [180.]], # [[320.], # [900.], # [180.]]])