語義分割網路 U-Net 詳解

AIBigbull2050發表於2019-08-28


Unet 背景介紹

Unet 發表於 2015 年,屬於 FCN 的一種變體,想了解 FCN 可以看我的另一篇 FCN 全卷積網路論文閱讀及程式碼實現 。Unet 的初衷是為了解決生物醫學影像方面的問題,由於效果確實很好後來也被廣泛的應用在語義分割的各個方向,比如衛星影像分割,工業瑕疵檢測等。

Unet 跟 FCN 都是 Encoder-Decoder 結構,結構簡單但很有效。Encoder 負責特徵提取,你可以將自己熟悉的各種特徵提取網路放在這個位置。由於在醫學方面,樣本收集較為困難,作者為了解決這個問題,應用了影像增強的方法,在資料集有限的情況下獲得了不錯的精度。

Unet 網路結構與細節

  • Encoder
語義分割網路 U-Net 詳解

如上圖,Unet 網路結構是對稱的,形似英文字母 U 所以被稱為 Unet。整張圖都是由藍/白色框與各種顏色的箭頭組成,其中, 藍/白色框表示 feature map;藍色箭頭表示 3x3 卷積,用於特徵提取;灰色箭頭表示 skip-connection,用於特徵融合;紅色箭頭表示池化 pooling,用於降低維度;綠色箭頭表示上取樣 upsample,用於恢復維度;青色箭頭表示 1x1 卷積,用於輸出結果

可能你會問為啥是 5 層而不是 4 層或者 6 層,emmm,這應該去問作者本人,可能對於當時作者拿到的資料集來說,這個層數的表現更好,但不代表所有的資料集這個結構都適合。我們該多關注這種 Encoder-Decoder 的設計思想,具體實現則應該因資料集而異。

Encoder 由卷積操作和下采樣操作組成,文中所用的卷積結構統一為 3x3 的卷積核,padding 為 0 ,striding 為 1。沒有 padding 所以每次卷積之後 feature map 的 H 和 W 變小了,在 skip-connection 時要注意 feature map 的維度(其實也可以將 padding 設定為 1 避免維度不對應問題),pytorch 程式碼:

nn.Sequential(nn.Conv2d(in_channels, out_channels, 3),

nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True))

上述的兩次卷積之後是一個 stride 為 2 的 max pooling,輸出大小變為 1/2 *(H, W):

語義分割網路 U-Net 詳解

pytorch 程式碼:

nn.MaxPool2d(kernel_size=2, stride=2)

上面的步驟重複 5 次,最後一次沒有 max-pooling,直接將得到的 feature map 送入 Decoder。

  • Decoder

feature map 經過 Decoder 恢復原始解析度,該過程除了卷積比較關鍵的步驟就是 upsampling 與 skip-connection。

Upsampling 上取樣常用的方式有兩種:1. FCN 中介紹的反摺積;2.  插值。這裡介紹文中使用的插值方式。在插值實現方式中,bilinear 雙線性插值的綜合表現較好也較為常見 。

雙線性插值的計算過程沒有需要學習的引數,實際就是套公式,這裡舉個例子方便大家理解(例子介紹的是引數 align_corners 為 Fasle 的情況)。

語義分割網路 U-Net 詳解

例子中是將一個 2x2 的矩陣透過插值的方式得到 4x4 的矩陣,那麼將 2x2 的矩陣稱為源矩陣,4x4 的矩陣稱為目標矩陣。雙線性插值中,目標點的值是由離他最近的 4 個點的值計算得到的,我們首先介紹如何找到目標點周圍的 4 個點,以 P2 為例。

語義分割網路 U-Net 詳解

第一個公式,目標矩陣到源矩陣的座標對映:

語義分割網路 U-Net 詳解

為了找到那 4 個點,首先要找到目標點在源矩陣中的 相對位置,上面的公式就是用來算這個的。P2 在目標矩陣中的座標是 (0, 1),對應到源矩陣中的座標就是 (-0.25, 0.25)。座標裡面居然有小數跟負數,不急我們一個一個來處理。我們知道雙線性插值是從座標周圍的 4 個點來計算該座標的值,(-0.25, 0.25) 這個點周圍的 4 個點是(-1, 0), (-1, 1), (0, 0), (0, 1)。為了找到負數座標點,我們將源矩陣擴充套件為下面的形式,中間紅色的部分為源矩陣。

語義分割網路 U-Net 詳解

我們規定 f(i, j) 表示 (i, j)座標點處的畫素值,對於計算出來的對應的座標,我們統一寫成 (i+u, j+v) 的形式。那麼這時 i=-1, u=0.75, j=0, v=0.25。把這 4 個點單獨畫出來,可以看到目標點 P2 對應到源矩陣中的 相對位置

語義分割網路 U-Net 詳解

第二個公式,也是最後一個。

f(i + u, j + v) = (1 - u) (1 - v) f(i, j) + (1 - u) v f(i, j + 1) + u (1 - v) f(i + 1, j) + u v f(i + 1, j + 1)

目標點的畫素值就是周圍 4 個點畫素值的加權和,明顯可以看出離得近的權值比較大例如 (0, 0) 點的權值就是 0.75*0.75,離得遠的如 (-1, 1) 權值就比較小,為 0.25*0.25,這也比較符合常理吧。把值帶入計算就可以得到 P2 點的值了,結果是 12.5 與程式碼吻合上了,nice。

pytorch 裡使用 bilinear 插值:

nn.Upsample(scale_factor=2, mode='bilinear')

CNN 網路要想獲得好效果,skip-connection 基本必不可少。Unet 中這一關鍵步驟融合了底層資訊的位置資訊與深層特徵的語義資訊,pytorch 程式碼:

torch.cat([low_layer_features, deep_layer_features], dim=1)

這裡需要注意的是 ,FCN 中深層資訊與淺層資訊融合是透過對應畫素相加的方式,而 Unet 是透過拼接的方式。

那麼這兩者有什麼區別呢,其實 在 ResNet 與 DenseNet 中也有一樣的區別,Resnet 使用了對應值相加,DenseNet 使用了拼接。 個人理解在相加的方式下,feature map 的維度沒有變化,但每個維度都包含了更多特徵,對於普通的分類任務這種不需要從 feature map 復原到原始解析度的任務來說,這是一個高效的選擇;而拼接則保留了更多的維度/位置 資訊,這使得後面的 layer 可以在淺層特徵與深層特徵自由選擇,這對語義分割任務來說更有優勢。

小結

Unet 基於 Encoder-Decoder 結構,透過拼接的方式實現特徵融合,結構簡明且穩定,如果你有語義分割的問題,尤其在樣本資料量不大的情況下,十分推薦一試。


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946223/viewspace-2655246/,如需轉載,請註明出處,否則將追究法律責任。

相關文章