[論文][半監督語義分割]Semi-Supervised Semantic Segmentation with High- and Low-level Consistency

柒七同學發表於2022-02-09

Semi-Supervised Semantic Segmentation with High- and Low-level Consistency

TPAMI 2019
論文原文
code

創新點:

利用兩個分支結構分別處理low-level和high-level的特徵,進行半監督語義分割

網路結構

image
上分支:Semi-Supervised Semantic Segmentation GAN (s4GAN)
下分支:Multi-Label Mean Teacher (MLMT)

s4GAN

訓練segmentation network \(S\)

segmentation network \(S\)的損失函式由以下三部分組成:

  1. Cross-entropy loss
    輸入原圖到segmentation network \(S\)中,對於labeled images,輸出的分割結果\(S(x^l)\)和標籤\(y^l\)對比,計算交叉熵損失\(L_{ce}\)
    image
  2. Feature matching loss
    為了使得分割結果\(S(x^l)\)和標籤\(y^l\)的特徵分佈儘可能一致,本文計算分割結果\(S(x^l)\)和標籤\(y^l\)的特徵分佈差異mean discrepancy,並設計Feature matching loss
    image
    上式中\(D_k\)表示discriminator的第\(k\)
    注:此Feature matching loss適用於有標籤和無標籤的資料
  3. Self-training loss
    本文認為,在訓練過程中generator和discriminator需要達到某種平衡,如果discriminator過於strong,則無法給generator任何有用的學習訊號。因此,對於unlabeled image,本文每次將generator產生的,可以成功欺騙discriminator的分割圖當作真實標籤,用於監督學習。由此可以促使segmentation network(即generator)變強,且一定程度上阻礙discriminator的進步,不希望discriminator過於強大,破壞平衡。
    具體而言,discriminator在s4GAN中用於在image-level判斷一張分割圖是真實標籤(real label),還是segmentation network的輸出(fake label),根據為真實標籤的可能性輸出一個0~1之間的概率值(若為真實標籤,則輸出1)
    文章設定閘值,對於輸出大於閘值的分割圖,作為高質量的預測圖,當作真實標籤,用於監督學習,並計算交叉熵損失
    image

s4GAN總損失:
image

訓練discriminator

discriminator的輸入包含原圖image和對應標籤,訓練discriminator,希望discriminator能給真實標籤打高分,給fake label打低分。具體損失函式和傳統的GAN相同。
image
image(channel wise)

MLMT

該分支包含兩個網路,分別為學生網路和老師網路,訓練時,一張image經過微小的,不同的擾動之後分別輸入學生網路和老師網路,學生網路和老師網路使用online ensemble的weight(老師網路是學生網路學習的目標,老師網路的權重在學生網路的基礎上根據指數平均移動線移動,詳見論文)。本文希望學生網路的輸出和老師網路的輸出儘可能一致,則對於所有image,使用均方誤差來衡量兩個網路輸出的差異,對於labeled image,同時使用類交叉熵函式計算損失
image

Network Fusion

簡單的通過deactivate segmentation networks的輸出中沒有出現在input image中的圖片來融合兩個網路的結果。
對於一張image分割圖的一個類別c的mask,尺寸為\(HxWx1\),(對於每一個畫素?)如果學生網路的輸出(soft label)小於設定的某個閘值,則令segmentation network的輸出為0,否則segmentation network的輸出不變。
image

實驗

資料集:

PASCAL VOC 2012 segmentation benchmark, the PASCAL-Context dataset, and the Cityscapes dataset.

網路具體結構:

segmentation network:

deeplab v2

discriminator:

4層卷積層,通道數分別為\({64,128,256,512}\),卷積核大小為4x4,每個卷積層後面都有一個negative slope of 0.2的Leaky-ReLU層和一個dropout概率為0.5的dropout層(該高概率的dropout layer對於GAN的穩定訓練非常關鍵)。最後一個卷積層後面是一個全域性平均池化層和全連線層,全域性平均池化的輸出用於Feature matching loss的計算

學生網路和老師網路:

ResNet101(在imagenet上預訓練)

實驗結果:

image

image

image

image

image

image

image

image

image

疑問:

  1. 網路融合的目的?
  2. self-train loss的設定(為阻止discriminator變強)?

相關文章