PyTorch 中 torch.matmul() 函式的文件詳解

Lowell_liu發表於2022-03-04

官方文件

torch.matmul() 函式幾乎可以用於所有矩陣/向量相乘的情況,其乘法規則視參與乘法的兩個張量的維度而定。

關於 PyTorch 中的其他乘法函式可以看這篇博文,有助於下面各種乘法的理解。

torch.matmul() 將兩個張量相乘劃分成了五種情形:一維 × 一維、二維 × 二維、一維 × 二維、二維 × 一維、涉及到三維及三維以上維度的張量的乘法。

以下是五種情形的詳細解釋:

  1. 如果兩個張量都是一維的,即 torch.Size([n]) ,此時返回兩個向量的點積。作用與 torch.dot() 相同,同樣要求兩個一維張量的元素個數相同。

    例如:

    >>> vec1 = torch.tensor([1, 2, 3])
    >>> vec2 = torch.tensor([2, 3, 4])
    >>> torch.matmul(vec1, vec2)
    tensor(20)
    >>> torch.dot(vec1, vec2)
    tensor(20)
    
    # 兩個一維張量的元素個數要相同!
    >>> vec1 = torch.tensor([1, 2, 3])
    >>> vec2 = torch.tensor([2, 3, 4, 5])
    >>> torch.matmul(vec1, vec2)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: inconsistent tensor size, expected tensor [3] and src [4] to have the same number of elements, but got 3 and 4 elements respectively
    
  2. 如果兩個引數都是二維張量,那麼將返回矩陣乘積。作用與 torch.mm() 相同,同樣要求兩個張量的形狀需要滿足矩陣乘法的條件,即(n×m)×(m×p)=(n×p)

    例如:

    >>> arg1 = torch.tensor([[1, 2], [3, 4]])
    >>> arg1
    tensor([[1, 2],
            [3, 4]])
    >>> arg2 = torch.tensor([[-1], [2]])
    >>> arg2
    tensor([[-1],
            [ 2]])
    >>> torch.matmul(arg1, arg2)
    tensor([[3],
            [5]])
    >>> torch.mm(arg1, arg2)
    tensor([[3],
            [5]])
    
    >>> arg2 = torch.tensor([[-1], [2], [1]])
    >>> torch.matmul(arg1, arg2)					# 要求滿足矩陣乘法的條件
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2 and 3x1)
    
  3. 如果第一個引數是一維張量,第二個引數是二維張量,那麼在一維張量的前面增加一個維度,然後進行矩陣乘法,矩陣乘法結束後移除新增的維度。文件原文為:“a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.”

    例如:

    >>> arg1 = torch.tensor([-1, 2])
    >>> arg2 = torch.tensor([[1, 2], [3, 4]])
    >>> torch.matmul(arg1, arg2)
    tensor([5, 6])
    
    >>> arg1 = torch.unsqueeze(arg1, 0)			# 在一維張量前增加一個維度
    >>> arg1.shape
    torch.Size([1, 2])
    >>> ans = torch.mm(arg1, arg2)				# 進行矩陣乘法
    >>> ans
    tensor([[5, 6]])
    >>> ans = torch.squeeze(ans, 0)				# 移除增加的維度
    >>> ans
    tensor([5, 6])
    
  4. 如果第一個引數是二維張量(矩陣),第二個引數是一維張量(向量),那麼將返回矩陣×向量的積。作用與 torch.mv() 相同。另外要求矩陣的形狀和向量的形狀滿足矩陣乘法的要求。

    例如:

    >>> arg1 = torch.tensor([[1, 2], [3, 4]])
    >>> arg2 = torch.tensor([-1, 2])
    >>> torch.matmul(arg1, arg2)
    tensor([3, 5])
    
    >>> torch.mv(arg1, arg2)
    tensor([3, 5])
    
  5. 如果兩個引數均至少為一維,且其中一個引數的 ndim > 2,那麼……(一番處理),然後進行批量矩陣乘法。

    這條規則將所有涉及到三維張量及三維以上的張量(下文稱為高維張量)的乘法分為三類:一維張量 × 高維張量、高維張量 × 一維張量、二維及二維以上的張量 × 二維及二維以上的張量。

    1. 如果第一個引數是一維張量,那麼在此張量之前增加一個維度。

      文件原文為:“ If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.”

    2. 如果第二個引數是一維張量,那麼在此張量之後增加一個維度。

      文件原文為:“If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. ”

    3. 由於上述兩個規則,所有涉及到一維張量和高維張量的乘法都被轉變為二維及二維以上的張量 × 二維及二維以上的張量。

      然後除掉最右邊的兩個維度,對剩下的維度進行廣播。原文為:“The non-matrix dimensions are broadcasted.”

      然後就可以進行批量矩陣乘法。

      For example, if input is a (j × 1 × n × n) tensor and other is a (k × n × n) tensor, out will be a (j × k × n × n) tensor.

    舉例如下:

    >>> arg1 = torch.tensor([1, 2, -1, 1])
    >>> arg2 = torch.randint(low=-2, high=3, size=[3, 4, 1])
    >>> torch.matmul(arg1, arg2)
    tensor([[ 5],
            [-1],
            [-1]])
            
    >>> arg2
    tensor([[[ 2],
             [ 2],
             [-1],
             [-2]],
    
            [[-2],
             [ 2],
             [ 1],
             [-2]],
    
            [[ 0],
             [ 0],
             [-1],
             [-2]]])
    

    根據第一條規則,先對 arg1 增加維度:

    >>> arg3 = torch.unsqueeze(arg1, 0)
    >>> arg3
    tensor([[ 1,  2, -1,  1]])
    >>> arg3.shape
    torch.Size([1, 4])
    

    由於 arg2.shape=torch.Size([3, 4, 1]) ,根據廣播的規則,arg3 要被廣播為 torch.Size([3, 1, 4]) ,也就是下面的 arg4

    >>> arg4 = torch.tensor([ [[ 1,  2, -1,  1]], [[ 1,  2, -1,  1]], [[ 1,  2, -1,  1]] ])
    >>> arg4
    tensor([[[ 1,  2, -1,  1]],
    
            [[ 1,  2, -1,  1]],
    
            [[ 1,  2, -1,  1]]])
    >>> arg4.shape
    torch.Size([3, 1, 4])
    

    最後我們使用乘法函式 torch.bmm() 來進行批量矩陣乘法:

    >>> torch.bmm(arg4, arg2)
    tensor([[[ 5]],
    
            [[-1]],
    
            [[-1]]])
    

    由於在第一條規則中對一維張量增加了維度,因此矩陣計算結束後要移除這個維度。移除之後和前面使用 torch.matmul() 的結果相同!

PS:在看文件第五條規則時,起先也非常不明白,試了很多次高維和一維的張量乘法總是提示RuntimeError: mat1 and mat2 shapes cannot be multiplied,然後就嘗試理解這條規則。因為這條規則很長,分成了三個小情形,並且這三個情形並不是一一獨立的,而是前兩個情形經過處理之後最後全都可以轉變成第三個情形。另一個理解的突破口是 prependedappended 這兩個單詞,通過它們的字首可以猜測出:一個是在張量前面增加維度,一個是在張量後面增加維度,然後廣播再進行批量矩陣乘法就驗證出來了!

相關文章