授人以魚不如授人以漁,原汁原味的知識才更富有精華,本文只是對張量基本操作知識的理解和學習筆記,看完之後,想要更深入理解,建議去 pytorch 官方網站,查閱相關函式和操作,英文版在這裡,中文版在這裡。本文的程式碼是在
pytorch1.7
版本上測試的,其他版本一般也沒問題。
一,張量的基本操作
Pytorch
中,張量的操作分為結構操作和數學運算,其理解就如字面意思。結構操作就是改變張量本身的結構,數學運算就是對張量的元素值完成數學運算。
- 常使用的張量結構操作:維度變換(
tranpose
、view
等)、合併分割(split
、chunk
等)、索引切片(index_select
、gather
等)。 - 常使用的張量數學運算:標量運算、向量運算、矩陣運算。
二,維度變換
2.1,squeeze vs unsqueeze 維度增減
squeeze()
:對 tensor 進行維度的壓縮,去掉維數為1
的維度。用法:torch.squeeze(a)
將 a 中所有為 1 的維度都刪除,或者a.squeeze(1)
是去掉a
中指定的維數為1
的維度。unsqueeze()
:對資料維度進行擴充,給指定位置加上維數為1
的維度。用法:torch.unsqueeze(a, N)
,或者a.unsqueeze(N)
,在a
中指定位置N
加上一個維數為1
的維度。
squeeze
用例程式如下:
a = torch.rand(1,1,3,3)
b = torch.squeeze(a)
c = a.squeeze(1)
print(b.shape)
print(c.shape)
程式輸出結果如下:
torch.Size([3, 3])
torch.Size([1, 3, 3])
unsqueeze
用例程式如下:
x = torch.rand(3,3)
y1 = torch.unsqueeze(x, 0)
y2 = x.unsqueeze(0)
print(y1.shape)
print(y2.shape)
程式輸出結果如下:
torch.Size([1, 3, 3])
torch.Size([1, 3, 3])
2.2,transpose vs permute 維度交換
torch.transpose()
只能交換兩個維度,而 .permute()
可以自由交換任意位置。函式定義如下:
transpose(dim0, dim1) → Tensor # See torch.transpose()
permute(*dims) → Tensor # dim(int). Returns a view of the original tensor with its dimensions permuted.
在 CNN
模型中,我們經常遇到交換維度的問題,舉例:四個維度表示的 tensor:[batch, channel, h, w]
(nchw
),如果想把 channel
放到最後去,形成[batch, h, w, channel]
(nhwc
),如果使用 torch.transpose()
方法,至少要交換兩次(先 1 3
交換再 1 2
交換),而使用 .permute()
方法只需一次操作,更加方便。例子程式如下:
import torch
input = torch.rand(1,3,28,32) # torch.Size([1, 3, 28, 32]
print(b.transpose(1, 3).shape) # torch.Size([1, 32, 28, 3])
print(b.transpose(1, 3).transpose(1, 2).shape) # torch.Size([1, 28, 32, 3])
print(b.permute(0,2,3,1).shape) # torch.Size([1, 28, 28, 3]
三,索引切片
3.1,規則索引切片方式
張量的索引切片方式和 numpy
、python 多維列表幾乎一致,都可以透過索引和切片對部分元素進行修改。切片時支援預設引數和省略號。例項程式碼如下:
>>> t = torch.randint(1,10,[3,3])
>>> t
tensor([[8, 2, 9],
[2, 5, 9],
[3, 9, 9]])
>>> t[0] # 第 1 行資料
tensor([8, 2, 9])
>>> t[2][2]
tensor(9)
>>> t[0:3,:] # 第1至第3行,全部列
tensor([[8, 2, 9],
[2, 5, 9],
[3, 9, 9]])
>>> t[0:2,:] # 第1行至第2行
tensor([[8, 2, 9],
[2, 5, 9]])
>>> t[1:,-1] # 第2行至最後行,最後一列
tensor([9, 9])
>>> t[1:,::2] # 第1行至最後行,第0列到最後一列每隔兩列取一列
tensor([[2, 9],
[3, 9]])
以上切片方式相對規則,對於不規則的切片提取,可以使用 torch.index_select
, torch.take
, torch.gather
, torch.masked_select
。
3.2,gather 和 torch.index_select 運算元
gather
運算元的用法比較難以理解,在翻閱了官方文件和網上資料後,我有了一些自己的理解。
1,gather
是不規則的切片提取運算元(Gathers values along an axis specified by dim. 在指定維度上根據索引 index 來選取資料)。函式定義如下:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
引數解釋:
input
(Tensor) – the source tensor.dim
(int) – the axis along which to index.index
(LongTensor) – the indices of elements to gather.
gather
運算元的注意事項:
- 輸入
input
和索引index
具有相同數量的維度,即input.shape = index.shape
- 對於任意維數,只要
d != dim
,index.size(d) <= input.size(d),即對於可以不用索引維數d
上的全部資料。 - 輸出
out
和 索引index
具有相同的形狀。輸入和索引不會相互廣播。
對於 3D tensor,output
值的定義如下:
gather
的官方定義如下:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
透過理解前面的一些定義,相信讀者對 gather
運算元的用法有了一個基本瞭解,下面再結合 2D 和 3D tensor 的用例來直觀理解運算元用法。
(1),對於 2D tensor 的例子:
>>> import torch
>>> a = torch.arange(0, 16).view(4,4)
>>> a
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
>>> index = torch.tensor([[0, 1, 2, 3]]) # 選取對角線元素
>>> torch.gather(a, 0, index)
tensor([[ 0, 5, 10, 15]])
output
值定義如下:
# 按照 index = tensor([[0, 1, 2, 3]])順序作用在行上索引依次為0,1,2,3
a[0][0] = 0
a[1][1] = 5
a[2][2] = 10
a[3][3] = 15
(2),索引更復雜的 2D tensor 例子:
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> t
tensor([[1, 2],
[3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1, 1],
[ 4, 3]])
output
值的計算如下:
output[i][j] = input[i][index[i][j]] # if dim = 1
output[0][0] = input[0][index[0][0]] = input[0][0] = 1
output[0][1] = input[0][index[0][1]] = input[0][0] = 1
output[1][0] = input[1][index[1][0]] = input[1][1] = 4
output[1][1] = input[1][index[1][1]] = input[1][0] = 3
總結:可以看到 gather
是透過將索引在指定維度 dim
上的值替換為 index
的值,但是其他維度索引不變的情況下獲取 tensor
資料。直觀上可以理解為對矩陣進行重排,比如對每一行(dim=1)的元素進行變換,比如 torch.gather(a, 1, torch.tensor([[1,2,0], [1,2,0]]))
的作用就是對 矩陣 a
每一行的元素,進行 permtute(1,2,0)
操作。
2,理解了 gather
再看 index_select
就很簡單,函式作用是返回沿著輸入張量的指定維度的指定索引號進行索引的張量子集。函式定義如下:
torch.index_select(input, dim, index, *, out=None) → Tensor
函式返回一個新的張量,它使用資料型別為 LongTensor
的 index
中的條目沿維度 dim
索引輸入張量。返回的張量具有與原始張量(輸入)相同的維數。 維度尺寸與索引長度相同; 其他尺寸與原始張量中的尺寸相同。例項程式碼如下:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
四,合併分割
4.1,torch.cat 和 torch.stack
可以用 torch.cat
方法和 torch.stack
方法將多個張量合併,也可以用 torch.split
方法把一個張量分割成多個張量。torch.cat
和 torch.stack
有略微的區別,torch.cat
是連線,不會增加維度,而 torch.stack
是堆疊,會增加一個維度。兩者函式定義如下:
# Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
torch.cat(tensors, dim=0, *, out=None) → Tensor
# Concatenates a sequence of tensors along **a new** dimension. All tensors need to be of the same size.
torch.stack(tensors, dim=0, *, out=None) → Tensor
torch.cat
和 torch.stack
用法例項程式碼如下:
>>> a = torch.arange(0,9).view(3,3)
>>> b = torch.arange(10,19).view(3,3)
>>> c = torch.arange(20,29).view(3,3)
>>> cat_abc = torch.cat([a,b,c], dim=0)
>>> print(cat_abc.shape)
torch.Size([9, 3])
>>> print(cat_abc)
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[10, 11, 12],
[13, 14, 15],
[16, 17, 18],
[20, 21, 22],
[23, 24, 25],
[26, 27, 28]])
>>> stack_abc = torch.stack([a,b,c], axis=0) # torch中dim和axis引數名可以混用
>>> print(stack_abc.shape)
torch.Size([3, 3, 3])
>>> print(stack_abc)
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[10, 11, 12],
[13, 14, 15],
[16, 17, 18]],
[[20, 21, 22],
[23, 24, 25],
[26, 27, 28]]])
>>> chunk_abc = torch.chunk(cat_abc, 3, dim=0)
>>> chunk_abc
(tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]),
tensor([[10, 11, 12],
[13, 14, 15],
[16, 17, 18]]),
tensor([[20, 21, 22],
[23, 24, 25],
[26, 27, 28]]))
4.2,torch.split 和 torch.chunk
torch.split()
和 torch.chunk()
可以看作是 torch.cat()
的逆運算。split()
作用是將張量拆分為多個塊,每個塊都是原始張量的檢視。split()
函式定義如下:
"""
Splits the tensor into chunks. Each chunk is a view of the original tensor.
If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
"""
torch.split(tensor, split_size_or_sections, dim=0)
chunk()
作用是將 tensor
按 dim
(行或列)分割成 chunks
個 tensor
塊,返回的是一個元組。chunk()
函式定義如下:
torch.chunk(input, chunks, dim=0) → List of Tensors
"""
Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.
Parameters:
input (Tensor) – the tensor to split
chunks (int) – number of chunks to return
dim (int) – dimension along which to split the tensor
"""
例項程式碼如下:
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
>>> torch.chunk(a, 2, dim=1)
(tensor([[0],
[2],
[4],
[6],
[8]]),
tensor([[1],
[3],
[5],
[7],
[9]]))
五,卷積相關運算元
5.1,上取樣方法總結
上取樣大致被總結成了三個類別:
- 基於線性插值的上取樣:最近鄰演算法(
nearest
)、雙線性插值演算法(bilinear
)、雙三次插值演算法(bicubic
)等,這是傳統影像處理方法。 - 基於深度學習的上取樣(轉置卷積,也叫反摺積
Conv2dTranspose2d
等) Unpooling
的方法(簡單的補零或者擴充操作)
計算效果:最近鄰插值演算法 < 雙線性插值 < 雙三次插值。計算速度:最近鄰插值演算法 > 雙線性插值 > 雙三次插值。
5.2,F.interpolate 取樣函式
Pytorch 老版本有
nn.Upsample
函式,新版本建議用torch.nn.functional.interpolate
,一個函式可實現定製化需求的上取樣或者下采樣功能,。
F.interpolate()
函式全稱是 torch.nn.functional.interpolate()
,函式定義如下:
def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811
# type: (Tensor, Optional[int], Optional[List[float]], str, Optional[bool], Optional[bool]) -> Tensor
pass
引數解釋如下:
input
(Tensor):輸入張量資料;size
: 輸出的尺寸,資料型別為 tuple: ([optional D_out], [optional H_out], W_out),和scale_factor
二選一。scale_factor
:在高度、寬度和深度上面的放大倍數。資料型別既可以是 int——表明高度、寬度、深度都擴大同一倍數;也可是
tuple`——指定高度、寬度、深度等維度的擴大倍數。mode
: 上取樣的方法,包括最近鄰(nearest
),線性插值(linear
),雙線性插值(bilinear
),三次線性插值(trilinear
),預設是最近鄰(nearest
)。align_corners
: 如果設為True
,輸入影像和輸出影像角點的畫素將會被對齊(aligned),這隻在mode = linear, bilinear, or trilinear
才有效,預設為False
。
例子程式如下:
import torch.nn.functional as F
x = torch.rand(1,3,224,224)
y = F.interpolate(x * 2, scale_factor=(2, 2), mode='bilinear').squeeze(0)
print(y.shape) # torch.Size([3, 224, 224)
5.3,nn.ConvTranspose2d 反摺積
轉置卷積(有時候也稱為反摺積,個人覺得這種叫法不是很規範),它是一種特殊的卷積,先 padding
來擴大影像尺寸,緊接著跟正向卷積一樣,旋轉卷積核 180 度,再進行卷積計算。