Semi-Supervised Semantic Segmentation with High- and Low-level Consistency
創新點:
利用兩個分支結構分別處理low-level和high-level的特徵,進行半監督語義分割
網路結構
上分支:Semi-Supervised Semantic Segmentation GAN (s4GAN)
下分支:Multi-Label Mean Teacher (MLMT)
s4GAN
訓練segmentation network \(S\)
segmentation network \(S\)的損失函式由以下三部分組成:
- Cross-entropy loss
輸入原圖到segmentation network \(S\)中,對於labeled images,輸出的分割結果\(S(x^l)\)和標籤\(y^l\)對比,計算交叉熵損失\(L_{ce}\)
- Feature matching loss
為了使得分割結果\(S(x^l)\)和標籤\(y^l\)的特徵分佈儘可能一致,本文計算分割結果\(S(x^l)\)和標籤\(y^l\)的特徵分佈差異mean discrepancy,並設計Feature matching loss
上式中\(D_k\)表示discriminator的第\(k\)層
注:此Feature matching loss適用於有標籤和無標籤的資料 - Self-training loss
本文認為,在訓練過程中generator和discriminator需要達到某種平衡,如果discriminator過於strong,則無法給generator任何有用的學習訊號。因此,對於unlabeled image,本文每次將generator產生的,可以成功欺騙discriminator的分割圖當作真實標籤,用於監督學習。由此可以促使segmentation network(即generator)變強,且一定程度上阻礙discriminator的進步,不希望discriminator過於強大,破壞平衡。
具體而言,discriminator在s4GAN中用於在image-level判斷一張分割圖是真實標籤(real label),還是segmentation network的輸出(fake label),根據為真實標籤的可能性輸出一個0~1之間的概率值(若為真實標籤,則輸出1)
文章設定閘值,對於輸出大於閘值的分割圖,作為高質量的預測圖,當作真實標籤,用於監督學習,並計算交叉熵損失
s4GAN總損失:
訓練discriminator
discriminator的輸入包含原圖image和對應標籤,訓練discriminator,希望discriminator能給真實標籤打高分,給fake label打低分。具體損失函式和傳統的GAN相同。
(channel wise)
MLMT
該分支包含兩個網路,分別為學生網路和老師網路,訓練時,一張image經過微小的,不同的擾動之後分別輸入學生網路和老師網路,學生網路和老師網路使用online ensemble的weight(老師網路是學生網路學習的目標,老師網路的權重在學生網路的基礎上根據指數平均移動線移動,詳見論文)。本文希望學生網路的輸出和老師網路的輸出儘可能一致,則對於所有image,使用均方誤差來衡量兩個網路輸出的差異,對於labeled image,同時使用類交叉熵函式計算損失
Network Fusion
簡單的通過deactivate segmentation networks的輸出中沒有出現在input image中的圖片來融合兩個網路的結果。
對於一張image分割圖的一個類別c的mask,尺寸為\(HxWx1\),(對於每一個畫素?)如果學生網路的輸出(soft label)小於設定的某個閘值,則令segmentation network的輸出為0,否則segmentation network的輸出不變。
實驗
資料集:
PASCAL VOC 2012 segmentation benchmark, the PASCAL-Context dataset, and the Cityscapes dataset.
網路具體結構:
segmentation network:
deeplab v2
discriminator:
4層卷積層,通道數分別為\({64,128,256,512}\),卷積核大小為4x4,每個卷積層後面都有一個negative slope of 0.2的Leaky-ReLU層和一個dropout概率為0.5的dropout層(該高概率的dropout layer對於GAN的穩定訓練非常關鍵)。最後一個卷積層後面是一個全域性平均池化層和全連線層,全域性平均池化的輸出用於Feature matching loss的計算
學生網路和老師網路:
ResNet101(在imagenet上預訓練)
實驗結果:
疑問:
- 網路融合的目的?
- self-train loss的設定(為阻止discriminator變強)?