理解pytorch幾個高階選擇函式(如gather)

Alex發表於2020-09-30

1. 引言

  最近在刷開源的Pytorch版動手學深度學習,裡面談到幾個高階選擇函式,如index_select,masked_select,gather等。這些函式大多很容易理解,但是對於gather函式,確實有些難理解,官方文件開始也看得一臉懵,感覺不太直觀。下面談談我對這幾個函式的一些理解。

2. 維度的理解

  對於numpy和pytorch,其陣列在做維度運算上剛開始可能會給人一種直觀上的誤解,以numpy求矩陣某個維度的最大值為例(pytorch的理解也是一樣的)

import numpy as np
a = np.arange(1, 13).reshape(3, 4)
"""
result:
a = [[1, 2, 3, 4],
      [5, 6, 7, 8,],
      [9, 10, 11, 12]]
"""

# 對a維度0求最大值
a.max(axis = 0)
"""
result:
[9, 10, 11, 12]
"""

# 對a維度1求最大值
a.max(axis = 1)
"""
result:
[4, 8, 12]
"""

  如果對a矩陣在維度0上找最大值,根據我們直觀上的經驗應該是[4, 8, 12]。即從[1, 2, 3, 4]找到4,從[5, 6, 7, 8]找到8,從[9, 10, 11, 12]找到12。但是從上面結果來看,numpy運算卻給了我們直觀上認為是列最大值的結果[9, 10, 11, 12]。
  實際numpy(pytorch)運算應該理解為往給定的維度進行移動運算。還是以維度0為例,維度0上有3個向量,分別為[1, 2, 3, 4],[5, 6, 7, 8]和[9, 10, 11, 12]。往維度0移動,即[1, 2, 3, 4]和[5, 6, 7, 8]逐元素計算最大值,得到[5, 6, 7, 8],再和[9, 10, 11, 12]運算得到結果[9, 10, 11, 12]。

維度運算圖1
  另外,對於維度為3的陣列,在numpy和pytorch中,應該把維度0理解為通道數,維度1和維度2才是對應高和寬。如果是3維陣列對應著用於多輸入通道和單輸出通道的卷積核(維度為U x V x D),那麼4維陣列就對應著用於多輸入通道和多輸出通道的卷積核(維度為U x V x D x P),此時,維度0則為多通道卷積核數量的方向,維度1為通道數,維度2和3才是分別對應高和寬。
維度運算圖2

3. gather函式

pytorch和numpy中許多函式都涉及維度運算,gather也不例外,但是它相對於其他函式更難理解。依然先來看一個例子

import torch
a = torch.arange(1, 16).reshape(5, 3)
"""
result:
a = [[1, 2, 3],
      [4, 5, 6],
      [7, 8, 9],
      [10, 11, 12],
      [13, 14, 15]]
"""

# 定義兩個index
b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]])

# axis=0
output1 = a.gather(0, b)
"""
result:
[[1, 5, 9],
[7, 11, 15],
[1, 8, 15]]
"""

# axis=1
output2 = a.gather(1, c)
"""
result:
[[2, 3, 1, 3, 2],
[5, 6, 5, 4, 4]]
"""

上面的例子看起來可能有點複雜,我們來一步步的分析它,先從gather維度為0開始講起。

  1. a.gather(0, b)分為3個部分,a是需要被提取元素的矩陣,0代表的是提取的維度為0,b是提取元素的索引
    • 其中規定b和a是同維張量,即a是2維張量,b也必須是2維張量
  2. 0除了代表往維度0的方向提取元素外,還有一個特權---提取結果output可以在這個維度上的長度與a不同。打個比方,a現在的shape為(5, 3),那麼提取結果output1的shape可以是(1,3),(2, 3),甚至(n, 3)。具體維度0的長度到底為多少由b來決定。
  3. 根據0的特權,導致了給定的b張量除了維度0外,其他的維度大小必須和a一樣。其中張量b實際上包含以下兩個資訊
    • b可以利用除用於gather的維度(此處為維度0)外的維度來定位出唯一一個向量,也就是a[:, ?](三維度也是同理的,有a[:, ?1, ?2]),?的取值範圍為a同維度的index。
    • 對於上述定位出的向量,通過b中的元素來定位提取向量中的哪一個元素。
    • 上面說得可能有點抽象,實際上b中的每個元素都能在a中提取出一個元素。舉個具體點的例子,按照上面所說的,b[0, 0]可以提取a中的一個元素。對於b[0,0],除了維度0外,可以通過維度1來定位出唯一一個向量a[:, 0]。因為b[0, 0]的元素為0,即提取的是a[:, 0]的第0個元素---1,並將其作為output1[0, 0]的提取結果。
      下圖給出了維度0和維度1,gather運算的圖示
gather 2維度
對於3維或者更高維度的張量gather的原理也是一樣的
gather 2維度

4. index_select函式

其他的高階選擇函式都比較容易理解,這裡簡單的提一下。torch.index_select主要是根據傳入的tensor來往給定的axis方向來選取張量

import torch
a = torch.arange(9).reshape(3, 3)
torch.index_select(a, 0, torch.tensor([0, 2]))
"""
result:
[[0, 1, 2],
[6, 7, 8]]
"""

5. masked_select函式

實際上就是通過掩碼條件來選擇元素,像torch.masked_select(x, x>0.5),實際上是和x[x>0.5]等價的,最後返回的是一維張量

import torch
a = torch.rand(5, 3)

# 結果和a[a > 0.5]等價
torch.masked_select(a, a>0.5)

6. nonzero函式

找到非零元素的index

import torch
a = torch.eye(3)
torch.nonzero(a)

"""
result: 對應著非零元素的index
[[0, 0],
[1, 1],
[2, 2]]
"""

相關文章