深度學習中影像上取樣的方法

ZhiboZhao發表於2021-07-17

深度學習中的影像上取樣方法

所謂上取樣,就是將影像從一個較低的尺寸 \([C, H, W]\) 恢復到一個較大的尺寸 \([C, sH, sW]\),其中 \(s\) 是上取樣倍數,從小圖到大圖這一變換過程也叫影像的超解析度重建。影像超解析度重建是一個研究很深入的領域,對於大部分的應用場景,我們不需要對此做過多研究,通常使用一些簡單且常用的方法對影像進行上取樣進行預處理。在這裡我們就介紹幾種簡單的上取樣方式。

如下圖所示,一般的上取樣方式首先將原始影像的尺寸進行放大,空出來很多需要補充的區域,然後通過一定的插值演算法來計算待補充的區域,從而實現影像的放大。常見的插值演算法主要分為傳統的插值演算法和基於深度學習的插值演算法兩類。

一、傳統的插值演算法

傳統的插值演算法主要包括最鄰近插值 (Nearest interpolation),雙線性插值 (Bilinear interpolation) 和雙三次插值 (Bicubic interpolation) 三種方法。

1.1 最鄰近插值 (Nearest interpolation)

最鄰近插值很好立即,就是選取與待填充位置最近的畫素值作為該位置的值。在 PyTorch 中通過 nn.UpsamplingNearest2d() 來實現。具體的程式碼如下:

input = torch.arange(1,5, dtype=torch.float32).view(1,1,2,2)	# 定義 input 輸入
m = nn.UpsamplingNearest2d(scale_factor=2)	# 建立最鄰近插值例項
m(input)	# 計算插值結果

tensor([[[[1., 1., 2., 2.],
          [1., 1., 2., 2.],
          [3., 3., 4., 4.],
          [3., 3., 4., 4.]]]])

1.2 雙線性插值 (Bilinear interpolation)

簡單地將最鄰近的值作為插值會帶來明顯的棋盤效應,因此一個改進的插值演算法是雙線性插值。計算方法如下:

其中,$P$ 為待計算元素,而 $Q_{11}$, $Q_{12}$,$Q_{21}$,$Q_{22}$ 代表與 $P$ 最相鄰的四個元素。計算的方法分為兩步:

先分別對上下兩行做一次插值計算,得到 \(R_{1}\), \(R_{2}\)

\[f(R_{1}) = \dfrac{x_{2}-x}{x_{2}-x_{1}}f(Q_{11}) + \dfrac{x-x_{1}}{x_{2}-x_{1}}f(Q_{21}) \\ f(R_{2}) = \dfrac{x_{2}-x}{x_{2}-x_{1}}f(Q_{12}) + \dfrac{x-x_{1}}{x_{2}-x_{1}}f(Q_{22}) \]

再對 \(R_{1}\), \(R_{2}\) 做一次插值,得到 \(P\)

\[f(P) = \dfrac{y_{2}-y}{y_{2}-y_{1}}f(R_{1}) + \dfrac{y-y_{1}}{y_{2}-y_{1}}f(R_{2}) \]

具體程式碼如下:

n = nn.UpsamplingBilinear2d(scale_factor=2)	# 建立雙線性插值例項
n(input)	# 計算雙線性插值結果
tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
          [1.6667, 2.0000, 2.3333, 2.6667],
          [2.3333, 2.6667, 3.0000, 3.3333],
          [3.0000, 3.3333, 3.6667, 4.0000]]]])

1.3 雙三次插值 (Bicubic interpolation)

雙三次插值(Bicubic interpolation)也有一些文章會翻譯為三線性插值,本文統一同雙三次插值。其根據離待插值最近的4*4=16個已知值來計算待插值,每個已知值的權重由距離待插值距離決定,距離越近權重越大。具體的計算公式在這裡就不再贅述了,有興趣的可以查閱相關資料。程式碼如下:

k = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True)	# 建立雙三次插值例項
k(input)	# 計算雙三次插值輸出
tensor([[[[1.0000, 1.3148, 1.6852, 2.0000],
          [1.6296, 1.9444, 2.3148, 2.6296],
          [2.3704, 2.6852, 3.0556, 3.3704],
          [3.0000, 3.3148, 3.6852, 4.0000]]]])

二、基於深度學習的插值方法

2.1 反摺積 (Transposed Convolution)

反摺積是一種特殊的卷積,總是可以使用一種卷積來模擬反摺積的過程。然而該方式將引入許多‘0’的行和‘0’的列,導致實現上非常的低效。並且,反摺積只能恢復尺寸,並不能恢復數值,因此經常用在神經網路中作為提供恢復的尺寸,具體的數值往往通過訓練得到。程式碼如下:

# 將一個 [1, 1, 3, 3] 的影像通過反摺積將尺寸上取樣為 [1, 1, 5, 5]
input = torch.arange(1,10,dtype=torch.float32).view(1,1,3,3)
Transposed = nn.ConvTranspose2d(1,1,3,stride=2, padding = 1)

Transposed(input)
Out[122]: 
tensor([[[[-0.1766, -0.3358, -0.2509, -0.2865, -0.3252],
          [ 0.5506, -1.5407,  0.6985, -1.7486,  0.8463],
          [-0.3995, -0.1880, -0.4738, -0.1388, -0.5481],
          [ 0.9942, -2.1643,  1.1420, -2.3722,  1.2899],
          [-0.6225, -0.0403, -0.6968,  0.0089, -0.7711]]]])

2.2 亞畫素上取樣 (Pixel Shuffle)

普通的上取樣採用的臨近畫素填充演算法,主要考慮空間因素,沒有考慮channel因素,上取樣的特徵圖人為修改痕跡明顯,影像分割與GAN生成影像中效果不好。為了解決這個問題,ESPCN中提到了亞畫素上取樣方式。具體原理如下:

根據上圖,可以得出將維度為 \([B, C, H, W]\) 的 feature map 通過亞畫素上取樣的方式恢復到維度 \([B, C, sH, sW]\) 的過程分為兩步:

  1. 首先通過卷積進行特徵提取,將 \([B, C, H, W]=>[B, s^{2}C, H, W]\)
  2. 然後通過Pixel Shuffle 的操作,將 \([B, s^{2}C, H, W] => [B, C, sH, sW]\)

Pixel Shuffle的主要功能就是將這 \(s^{2}\) 個通道的特徵圖組合為新的 \([B, C, sH, sW]\) 的上取樣結果。具體來說,就是將原來一個低分辨的畫素劃分為 \(s^{2}\) 個更小的格子,利用 \(s^{2}\) 個特徵圖對應位置的值按照一定的規則來填充這些小格子。按照同樣的規則將每個低分辨畫素劃分出的小格子填滿就完成了重組過程。在這一過程中模型可以調整 \(s^{2}\) 個shuffle通道權重不斷優化生成的結果。

在ESPCN中,具體的實現過程如下:

class Net(nn.Module):
    def __init__(self, upscale_factor):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))	# 最終將輸入轉換成 [32, 9, H, W]
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)	# 通過 Pixel Shuffle 來將 [32, 9, H, W] 重組為 [32, 1, 3H, 3W]
    def forward(self, x):
        x = torch.tanh(self.conv1(x))
        x = torch.tanh(self.conv2(x))
        x = torch.sigmoid(self.pixel_shuffle(self.conv3(x)))
        return x
    
if __name__ == "__main__":
    model = Net(upscale_factor=3)
    input = torch.arange(1, 10, dtype = torch.float32).view(1,1,3,3)
    output = model(input)
    print(output.size())

# 輸出結果為:
torch.Size([1, 1, 9, 9])

相關文章