- 文章轉自:微信公眾號「機器學習煉丹術」
- 作者:煉丹兄(已授權)
- 聯絡方式:微信cyx645016617
- 論文名稱:“Learning From Synthetic Data: Addressing Domain Shift for Segmentation”
「前言」:最近好久沒更新公眾號了,我一不小心陷入了一個誤區:我以為自己看的文章足夠多了,用之前的風格遷移和GAN的知識來解決一個domain adaptive的問題,一頓亂拳並沒有打死老師傅,反而自己累個夠嗆。然後找到這樣一篇不錯的DA framework,來認真學習一下章法,假期結束重新用章法組合拳再來會會。
0 綜述
不同於以往的對抗模型或者是超畫素資訊來實現這個領域遷移,本文使用的是對抗生成網路GAN來將兩個領域的特徵空間拉近。
本文提出的是語義分割的領域自適應演算法。論文特別關注的問題是:目標領域沒有label。
傳統的DA方法包含最小化某些可以衡量source和target兩個分佈的距離函式。兩種常見的度量是:
- 最大均值差(Maximum Mean Discrepancy, MMD)
- 通過對抗學習,使用DCNN來學習distance metric
本文的主要貢獻在於提出了一種基於生成模型的特徵空間源分佈與目標分佈對齊演算法。
1 method
從圖片中來初步判斷,其實是比較好理解的:
- 首先,我猜測其做域遷移,可能是仿照GAN領域中做風格遷移的辦法;
- 圖片中總共有4個網路,F網路應該是特徵提取網路,C網路是做分割的網路,G網路是把F提取的特徵再還原成原圖的網路,D網路是做分類的網路,和一般GAN不同的是,D中做四個分類,是True source,True target, False source, False targe. 類似於把cycleGAN中的兩個二分類的discriminator合併了。
2 細節
原始圖片定義為\(X\),source domain的圖片定義為\(X^s\),target domain的圖片定義為\(X^t\).
- base network. 架構類似於預訓練的VGG16,被分成了兩個部分:特徵提取部分叫做F網路,做畫素分割的叫做C網路。
- G網路是用來從F生成的embedding特徵中,重建原始影像的;D網路不僅要分別出圖片是否是real or fake,還會做一個分割任務,類似於C網路。這個分割任務僅僅針對source domain,因為target domain不存在標籤。
現在我們假定已經準備好了資料和標籤\({X^s,Y^s}\):
- 首先經過F提取出來feature expression,\(F(X^s)\)
- C網路生成分割的標籤\(\hat{Y}^s\)
- G網路重建圖片\(\hat{X}^s\)
基於最近的相關的成功的研究,不再在G的輸入中顯式的concatenate一個隨機變數,而是在Generator中使用dropout layer
3 損失
作者提出了很多的對抗損失:
- 在一個domin內的損失有:
- Discriminator損失,分辨src-real和src-fake;
- Discriminator損失,分辨tgt-real和tgt-fake;
- Generator損失,讓fake source可以被discriminator判斷成src-real的損失;
- 在不同domain的損失:
- F網路的損失,可以讓fake source的輸入被判斷為real target;
- F網路的損失,可以讓fake target的輸入被判斷為real source;
除了上面說到的對抗損失外,還有下面的分割損失:
- \(L_{seg}\):在標準分割網路C中的pixel-wise的交叉熵損失;
- \(L_{aux}\):D網路也會輸出一個分割結果,交叉熵損失;
- \(L_{rec}\):原始影像和重建影像之間的L1損失。
4 訓練過程
在每一個iteration中,一個隨機的三元組被輸入到模型中:\(\{X^s,Y^s,X^t\}\),然後網路按照下面的順序進行更新引數:
- 先更新引數D,更新策略如下:
- 對於source input,用\(L_{aux}\)和\(L^s_{adv,D}\);
- 對於target input,用\(L^t_{adv,D}\)
- 然後更新G,更新策略如下:
- 愚弄discriminator的兩個loss,\(L^s_{adv,G}\)和\(L^t_{adv,G}\);
- 重建損失,\(L^s_{rec}\)和\(L^t_{rec}\);
- F網路的更新策略如下:
- F網路的更新是最關鍵的!(論文中說的)
- 更新F網路是為了實現domain adaptive,$L^s_{adv,F}$是為了混淆fake source 和real target;
- 類似於G-D之間的min-max game,這裡是F和D之間的競爭,只不過前者是為了混淆fake和real,後者是為了混淆source domain和target domain;
5 D的設計動機
我們可以發現,這裡面的D其實不是傳統的GAN中的D,輸出不再是單獨的一個scalar,表示圖片是fake or real的概率
最近有一篇GAN裡面提到了,patch discriminator(這個論文恰好之前讀過),這個是讓D輸出的也是一個二位的量,每一個值表示對應patch的fake or real的概率,這個措施極大的提高了G重建的圖片的質量,這裡繼承延伸了patch discriminator的思想,輸出的圖片是一個pixel-wise的類似分割的結果,每一個畫素有四個類別:fake-src,real-src,fake-tgt,real-tgt;
GAN一般是比較難訓練的,尤其是針對大尺度的真實圖片資料,一種穩定的方法來訓練生成模型的架構是Auxiliary Classifier GAN(ACGAN)(真好,這個論文我之前也看過),簡單的說通過增加一個輔助分類損失,可以訓練一個更穩定的G,因此這也是為什麼D中還會有一個分割損失\(L_{aux}\)
6 總結
作者提高,每一個元件都提供了關鍵的資訊,不多說了,假期回實驗室我要開始用章法組合拳來解決問題了。