直接理解轉置卷積(Transposed convolution)的各種情況

頎周發表於2020-10-29

  使用GAN生成影像必不可少的層就是上取樣,其中最常用的就是轉置卷積(Transposed Convolution)。如果把卷積操作轉換為矩陣乘法的形式,轉置卷積實際上就是將其中的矩陣進行轉置,從而產生逆向的效果。所謂效果僅僅在於特徵圖的形狀,也就是說,如果卷積將特徵圖從形狀a對映到形狀b,其對應的轉置卷積就是從形狀b對映回形狀a,而其中的值並不一一對應,是不可逆的。另外,不要把逆卷積(Deconvolution)和轉置卷積混淆,逆卷積的目標在於構建輸入特徵圖的稀疏編碼(Sparse coding),並不是以上取樣為目的的。但是轉置卷積的確是來源於逆卷積,關於逆卷積與轉置卷積的論文請看[1][2]。

  下面直接對轉置卷積的各種情況進行舉例,從而全面理解轉置卷積在Pytorch中的運算機制。使用Pytorch而不是TF的原因在於,TF中的padding方式只有兩種,即valid與same,並不能很好地幫我們理解原理。而且TF和Pytorch插入0值的方式有些差異,雖然在模型層面,你只需關注模型輸入輸出的形狀,隱層的微小差異可以通過訓練來抵消,但是為了更好得把握模型結構,最好還是使用Pytorch。

  對於Pytorch的nn.ConvTranspose2d()的引數,下面的討論不考慮膨脹度dilation,預設為1;output_padding就是在最終的輸出特徵外面再加上幾層0,所以也不討論,預設為0;為了便於理解,bias也忽略不計,設為False;不失一般性,輸入輸出的channels都設為1。除了對將卷積轉換成矩陣乘法的理解外,理解難點主要在於stride和padding的變化對轉置卷積產生的影響,因此下面我們主要變化kernel_size、stride、padding三個引數來分析各種情況。

  舉例之前要注意,轉換為矩陣的形式是由卷積的結果得到的,矩陣形式本身是不能直接獲得的。要注意這個因果關係,轉換為矩陣形式是為了便於理解,以及推導轉置卷積。

例項分析

kernel_size = 2, stride = 1, padding = 0

  首先是kernel_size = 2,stride=1,padding=0的情況,如下圖: 

  圖中上半部分表示將卷積轉換為矩陣乘法的形式。在卷積中,我們是輸入一個3x3的特徵圖,輸出2x2的特徵圖,矩陣乘法形式如上圖上中部分所示;轉置卷積就是將這個矩陣乘法反過來,如上圖下中部分所示。然後將下中部分的矩陣乘法轉換為卷積的形式,即可得到轉置卷積的示意圖如上圖右下部分所示。

kernel_size = 2, stride = 1, padding = 1

  然後是kernel_size = 2,stride=1,padding=1的情況(因為第一張圖中已有,虛線與註釋都不加了):

  與上一張圖的主要不同之處在於轉置卷積將卷積結果的最外層去掉,這是因為padding=1,也正符合與卷積相反的操作。也就是說,padding越大,轉置卷積就會去掉越多的外層,輸出就會越小。

kernel_size = 3, stride = 1, padding = 1

  為了分析轉置卷積的卷積核與卷積的卷積核的區別,這次把kernel_size變為3,如下圖:

  可以看出,轉置卷積的先將輸入padding 2層,用於抵消卷積核帶來的規模上的減小,從而將輸出擴增到相對應卷積操作的輸入大小。然後,我們可以發現,卷積核是輸入的卷積核的逆序。也就是說,我們輸入函式中的是1~9的方陣,而它實際作為卷積核的是9~1的方陣。最後,因為padding=1,這對於卷積操作是向外加一層0,而對於逆卷積,就是去掉最外面的一層,所以得到最終3x3的結果。

kernel_size = 2, stride = 2, padding = 1

  最後,分析stride對轉置卷積的影響,將stride設為2,如下圖:

  分析在圖中都已寫明。你可能會奇怪,為什麼這裡轉置卷積最終輸出與卷積的輸入形狀不同,這是因為卷積的padding並沒有被全都用上(只計算了一邊),而轉置卷積最後卻把兩邊的padding都去掉了,所以造成了卷積與轉置卷積不對應的情況。

總結

  經過對以上各種例項的分析,對於某個$kernel \,size=k,stride=s,padding=p$的轉置卷積,如果輸入寬高都為$n$,則輸出寬高為

 $\begin{aligned} m&=ns-(s-1)+2(k-1)-(k-1)-2p\\ &=(n-1)s-2p+k   \\ \end{aligned}$

  實際上,卷積與轉置卷積除了輸入輸出的形狀上相反以外,沒有別的聯絡,所以我們只要會計算轉置卷積輸出的形狀即可。

  以上圖都是用excel作的,已上傳至部落格園檔案,需要的可以下載(點選連結)。

參考文獻 

  [1] Zeiler M D, Krishnan D, Taylor G W, et al. Deconvolutional networks[C]. Computer Vision and Pattern Recognition, 2010.

  [2] Zeiler M D, Fergus R. Visualizing and Understanding Convolutional Networks[C]. European Conference on Computer Vision, 2013.

相關文章