PyTorch 中的乘法:mul()、multiply()、matmul()、mm()、mv()、dot()

Lowell_liu發表於2022-03-03

torch.mul()

函式功能:逐個對 inputother 中對應的元素相乘。

本操作支援廣播,因此 inputother 均可以是張量或者數字。

舉例如下:

>>> import torch
>>> a = torch.randn(3)
>>> a
tensor([-1.7095,  1.7837,  1.1865])
>>> b = 2
>>> torch.mul(a, b)
tensor([-3.4190,  3.5675,  2.3730])     # 這裡將 other 擴充套件成了 input 的形狀
>>> a = 3
>>> b = torch.randn(3, 1)
>>> b
tensor([[-0.7705],
        [ 1.1177],
        [ 1.2447]])
>>> torch.mul(a, b)
tensor([[-2.3116],
        [ 3.3530],
        [ 3.7341]])                     # 這裡將 input 擴充套件成了 other 的形狀
>>> a = torch.tensor([[2], [3]])
>>> a
tensor([[2],
        [3]])                           # a 是 2×1 的張量
>>> b = torch.tensor([-1, 2, 1])
>>> b
tensor([-1,  2,  1])                    # b 是 1×3 的張量
>>> torch.mul(a, b)
tensor([[-2,  4,  2],
        [-3,  6,  3]])

這個例子中,inputoutput 的形狀都不是公共形狀,因此兩個都需要廣播,都變成 2×3 的形狀,然後再逐個元素相乘。

由上述例子可以看出,這種乘法是逐個對應元素相乘,因此 inputoutput 的前後順序並不影響結果,即 torch.mul(a, b) =torch.mul(b, a)

官方文件

torch.multiply()

torch.mul() 的別稱。

torch.dot()

函式功能:計算 inputoutput 的點乘,此函式要求 inputoutput必須是一維的張量(其 shape 屬性中只有一個值)!並且要求兩者元素個數相同

舉例如下:

>>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
tensor(7)

>>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1, 1]))		# 要求兩者元素個數相同
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: inconsistent tensor size, expected tensor [2] and src [3] to have the same number of elements, but got 2 and 3 elements respectively

官方文件

torch.mm()

函式功能:實現線性代數中的矩陣乘法(matrix multiplication):(n×m) × (m×p) = (n×p)

本函式不允許廣播!

舉例如下:

>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 2)
>>> torch.mm(mat1, mat2)
tensor([[-1.1846, -1.8327],
        [ 0.8820,  0.0312]])

官方文件

torch.mv()

函式功能:實現矩陣和向量(matrix × vector)的乘法,要求 input 的形狀為 n×moutputtorch.Size([m])的一維 tensor。

舉例如下:

>>> mat = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> mat
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> vec = torch.tensor([-1, 1, 2])
>>> vec
tensor([-1,  1,  2])
>>> mat.shape
torch.Size([2, 3])
>>> vec.shape
torch.Size([3])
>>> torch.mv(mat, vec)
tensor([ 7, 13])

注意,此函式要求第二個引數是一維 tensor,也即其 ndim 屬性值為 1。這裡我們要區分清楚張量的 shape 屬性和 ndim 屬性,前者表示張量的形狀,後者表示張量的維度。(線性代數中二維矩陣的維度 m×n 通常理解為這裡的形狀)

對於 shape 值為 torch.Size([n])torch.Size(1, n) 的張量,前者的 ndim=1 ,後者的 ndim=2 ,因此前者是可視為線代中的向量,後者可視為線代中的矩陣。

對於 shape 值為 torch.Size([1, n])torch.Size([n, 1]) 的張量,它們同樣在 Pytorch 中被視為矩陣。例如:

>>> column = torch.tensor([[1], [2]])
>>> row = torch.tensor([3, 4])
>>> column.shape
torch.Size([2, 1])				# 矩陣
>>> row.shape
torch.Size([2])					# 一維張量
>>> matrix = torch.randn(1, 3)
>>> matrix.shape
torch.Size([1, 3])				# 矩陣

對於張量(以及線代中的向量和矩陣)的理解可看這篇博文

官方文件

torch.bmm()

函式功能:實現批量的矩陣乘法。

本函式要求 inputoutputndim 均為 3,且前者形狀為 b×n×m,後者形狀為 b×m×p 。可以理解為 input 中包含 b 個形狀為 n×m 的矩陣, output 中包含 b 個形狀為 m×p 的矩陣,然後第一個 n×m 的矩陣 × 第一個 m×p 的矩陣得到第一個 n×p 的矩陣,第二個……,第 b 個……因此最終得到 b 個形狀為 n×p 的矩陣,即最終結果是一個三維張量,形狀為 b×n×p

舉例如下:

>>> batch_matrix_1 = torch.tensor([ [[1, 2], [3, 4], [5, 6]] , [[-1, -2], [-3, -4], [-5, -6]] ])
>>> batch_matrix_1
tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6]],

        [[-1, -2],
         [-3, -4],
         [-5, -6]]])
>>> batch_matrix_1.shape
torch.Size([2, 3, 2])

>>> batch_matrix_2 = torch.tensor([ [[1, 2], [3, 4]], [[1, 2], [3, 4]] ])
>>> bat
batch_matrix_1 batch_matrix_2
>>> batch_matrix_2
tensor([[[1, 2],
         [3, 4]],

        [[1, 2],
         [3, 4]]])
>>> batch_matrix_2.shape
torch.Size([2, 2, 2])

>>> torch.bmm(batch_matrix_1, batch_matrix_2)
tensor([[[  7,  10],
         [ 15,  22],
         [ 23,  34]],

        [[ -7, -10],
         [-15, -22],
         [-23, -34]]])

官方文件

torch.matmul()

torch.matmul() 可以用於 PyTorch 中絕大多數的乘法,在不同的情形下,它與上述各個乘法函式起著相同的作用,具體請看這篇博文

相關文章