Adversarial Learning for Semi-Supervised Semantic Segmentation
摘要
創新點:我們提出了一種使用對抗網路進行半監督語義分割的方法。
在傳統的GAN網路中,discriminator大多是用來進行輸入影像的真偽分類(Datasets裡面sample的圖片打高分,generator產生的圖片打低分),而本文設計了一種全卷積的discriminator,用於區分輸入標籤圖中各個畫素(pixel-wise)的分類結果是ground truth或是segmentation network給出的。本文證明了所提出的discriminator可以通過耦合模型的對抗損失和標準交叉熵損失來提高語義分割的準確性。此外,全卷積鑑別器通過發現未標記影像預測結果中的可信區域,實現半監督學習,從而提供額外的監督訊號。
網路模型
對於labeled images:
-
image \(x_n\)輸入segmentation network,得到分割結果 \(S(x_n)\)
-
分割結果\(S(x_n)\)和該圖片對應真實標籤\(Y_n\)比較,計算交叉熵損失\(L_{ce}\)
-
分割結果\(S(x_n)\)送入discriminator中求 \(L_{adv}\)
-
使用\(S(x_n)\)和真實標籤\(Y_n\)訓練discriminator:分別將\(S(x_n)\)和真實標籤\(y\)輸入discriminator,讓discriminator分辨輸入標籤的每個畫素是來自是ground truth還是segmentation network(即輸入的每個畫素為來自於\(S(x_n)\)還是真實標籤\(Y_n\))
discriminator的輸入為\(S(x_n)\)或真實標籤\(Y_n\),尺寸為\(HxWxC\),其中\(C\)為語義分割的類別數;輸出尺寸為\(HxWx1\),畫素值代表這個pixel來自於真實標籤\(Y_n\)的概率(如果discriminator認為該畫素100%是來自真實標籤\(Y_n\),則該位置畫素值為1)
損失函式為:
- 注:上式中,當輸入為\(S(x_n)\)時,\(y_n = 0\),當輸入為\(Y_n\)時,\(y_n = 1\)
對於unlabeled image:
-
將image \(x_n\)輸入segmentation network,得到輸出\(S(x_n)\),尺寸為\(HxWxC\),每個維度上的值代表該畫素取這個類別的概率值。對輸入進行 one-hot encode,得到 \(\hat{Y_n}\)
編碼過程:
-
用\(\hat{Y_n}\)和\(S(x_n)\)進行交叉熵損失計算
-
將\(S(x_n)\)通過訓練後的discriminator,得到\(D(S(x_n))\),尺寸為\(HxWx1\),並設定閘值,通過指示函式對輸出進行二值化(對於輸出中畫素值大於閘值的畫素,認為是可信的,以突出正確的區域)
無標籤部分的損失函式為:
實際中\(T_{semi}\)的取值為0.1~0.3
訓練總損失:
Tips:
- 在訓練過程中首先用labeled image進行5000iteration的訓練(segmentation network和discriminator交替update)
- 此後隨機sample,每個batch裡面都可能有labeled image和unlabeled image,各自按照自己的步驟訓練
- discriminator只用每個batch裡面的labeled image進行訓練
具體網路結構
Segmentation network:
首先採用DeepLab-v2 中的ResNet-101作為backbone進行預訓練,並去掉最後一個分類層,將最後兩個卷積層的步幅從2修改為1,從而使輸出特徵圖的解析度有效地達到輸入影像大小的1/8。為了擴大感受野,我們將擴充套件後的卷積分別應用於步幅為2和4的conv4和conv5層。此外,我們在最後一層使用了Atrous Spatial Pyramid Pooling (ASPP)。最後,我們應用一個上取樣層和softmax輸出來匹配輸入影像的大小。
Discriminator:
實驗
Table4: 訓練資料集為pascal VOC標準的1464張圖片,SBD中的圖片作為無標籤資料進行訓練