PyTorch 反摺積運算(一)

jiang_huixin發表於2021-01-05

反摺積是一種特殊的正向卷積操作, 通過補零的方式擴大輸入影像的尺寸, 接著翻轉卷積核, 和普通卷積一樣進行正向卷積, 由於前期補充了大量的零, 即便進行了卷積運算, 輸出影像的尺寸依然比輸入影像大, 這樣就達到了向上取樣的目的


下面展示一些例項, 使用 PyTorch 計算反摺積


示例-1: 輸入輸出通道數都是1, 步長也為1

輸入資料 1*1*3*3(Batch 和 Channel 均為 1)

In [1]: import torch

In [2]: from torch import nn

In [3]: torch.manual_seed(0)
Out[3]: <torch._C.Generator at 0x7f17986464b0>

In [4]: x = torch.randint(5, size=(1,1,3,3), dtype=torch.float)

In [5]: x
Out[5]: 
tensor([[[[4., 4., 3.],
          [0., 3., 4.],
          [2., 3., 2.]]]])

卷積核 1*1*2*2

In [6]: w = torch.tensor([[1,2],[0,1]], dtype=torch.float).view(1,1,2,2)

In [7]: w
Out[7]: 
tensor([[[[1., 2.],
          [0., 1.]]]])

建立反摺積層, 不使用偏置(方便計算)

In [8]: layer = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, bias=False)

In [9]: layer.weight = nn.Parameter(w)

In [10]: layer.eval()
Out[10]: ConvTranspose2d(1, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)

手動計算

  1. 翻轉卷積核
PyTorch 反摺積運算(一)
  1. 填充零

對於正常卷積 oup = (inp + 2*padding - kernel_size) / stride + 1, 反摺積可以看做普通卷積的映象操作, 反摺積的輸入對應正向卷積的輸出, 令 oup=3, 得 inp=4(padding=0, kernel_size=2, stride=1), 所以輸出尺寸為 4*4

步長為 1 時, 僅在四周補零, 為了得到 4*4 的輸出, 需要補一圈零

PyTorch 反摺積運算(一)
  1. 正向卷積
PyTorch 反摺積運算(一)

使用 PyTorch 計算

# 無需正向傳播
In [17]: with torch.no_grad():
    ...:     y = layer(x)
    ...: 

In [18]: y
Out[18]: 
tensor([[[[ 4., 12., 11.,  6.],
          [ 0.,  7., 14., 11.],
          [ 2.,  7., 11.,  8.],
          [ 0.,  2.,  3.,  2.]]]])

修改反摺積層, 額外設定 padding=1

# padding=1
In [30]: layer_2 = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=1, bias=False)

In [31]: layer_2.weight = nn.Parameter(w)

In [32]: layer_2.eval()
Out[32]: ConvTranspose2d(1, 1, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1), bias=False)

根據 oup = (inp + 2*padding - kernel_size) / stride + 1, 代入 oup=3, padding=1, kernel_size=2, stride=1, 得 inp=2, 反摺積後特徵圖尺寸為 2*2, 此時無需填充零, 計算過程為

PyTorch 反摺積運算(一)

使用 PyTorch 驗證

In [33]: with torch.no_grad():
    ...:     y = layer_2(x)
    ...: 

In [34]: y
Out[34]: 
tensor([[[[ 7., 14.],
          [ 7., 11.]]]])

相關文章