清華開源遷移學習演算法庫

AIBigbull2050發表於2020-08-09

作者:清華大資料軟體團隊機器學習組

本文長度為 1700字,建議閱讀 6分鐘

本文為你介紹 Trans-Learn 演算法庫。

Trans-Learn是基於PyTorch實現的一個高效、簡潔的遷移學習演算法庫,目前釋出了第一個子庫——深度域自適應演算法庫(DALIB),支援的演算法包括:

  • Domain Adversarial Neural Networks (DANN)
  • Deep Adaptation Network (DAN)
  • Joint Adaptation Networks (JAN)
  • Conditional Adversarial Domain Adaptation (CDAN)
  • Maximum Classifier Discrepancy (MCD)
  • Margin Disparity Discrepancy (MDD)

專案地址:

域自適應背景介紹

目前深度學習模型在部分計算機視覺、自然語言處理任務中已經超過了人類的表現,但是它們的成功依賴於大規模的資料標註。但是 實際場景中,標註資料往往是稀缺的。解決標註資料稀缺問題的一個方法是透過計算機模擬生成訓練資料,例如用計算機圖形學的技術合成訓練資料。

清華開源遷移學習演算法庫

圖表 1 VisDA2017競賽任務

但是由於訓練資料和測試資料不再服從獨立同分布,訓練得到的深度網路的準確率大打折扣。為了解決上述資料漂移造成的問題,域自適應(Domain Adaptation) 的概念被提出。 域自適應的目標是將模型在源域(Source) 學到的知識遷移到目標域(Target)。例如計算機模擬生成訓練資料的例子中,合成資料是源域,真實場景的資料是目標域。

域自適應有效地緩解了深度學習對於人工標註資料的依賴,受到了學術界和工業界廣泛的關注。目前已經被引入到圖片分類、影像分割(Segmentation)、目標檢測(Object Detection)、機器翻譯(Machine Translation) 等眾多工上。吳恩達曾說過:“ 在監督學習之後,遷移學習將引領下一波機器學習技術商業化浪潮。”隨著產品級的機器學習應用進入資料稀缺的領域,監督學習得到的尖端模型效能大打折扣,域自適應變得至關重要。

研究現狀

深度域自適應方法主要包括以下兩大類:

1.  矩匹配。透過最小化分佈差異來對齊不同域的特徵分佈。例如深度適配網路DAN,聯合適配網路JAN。

2.  對抗訓練。域對抗網路DANN是最早的工作,它引入一個領域判別器,鼓勵特徵提取器學到領域無關的特徵。 在DANN的基礎上,衍生出了一系列方法,例如條件域對抗網路CDAN,間隔差異散度MDD等。

清華開源遷移學習演算法庫

圖表 2 DANN網路架構圖

清華開源遷移學習演算法庫

圖表 3 MDD網路架構圖

上述方法在實驗資料上體現了良好的效能。然而目前學術界域自適應方法的開源實現中存在下述問題:

  • 複用性差。域自適應方法和模型架構、資料集耦合在一起,不利於域自適應方法在新的模型、資料集上覆用。
  • 穩定性差。部分對抗訓練方法隨著訓練進行,準確率會大幅度下降。

DALIB設計的初衷就是讓使用者透過少數幾行程式碼,就可以將域自適應演算法用在實際專案中,而無需考慮域自適應模組的實現細節。

易用性

DALIB將現有域自適應訓練程式碼中的域自適應損失函式分離出來,按照PyTorch交叉熵損失函式的形式進行封裝,方便使用者的使用。域自適應損失函式也和模型架構進行了解耦,因此不依賴於具體的分類任務,所以演算法庫很容易擴充套件到圖片分類以外的分類任務。

如下,使用兩行程式碼即可定義一個與任務無關的域對抗損失函式。

清華開源遷移學習演算法庫

不同域自適應損失函式中有一些公用的模組,例如所有演算法中都用到的分類器模組,對抗訓練中用到的梯度翻轉模組、域判別器模組,核方法中的核函式模組等。這些公用模組和提供的域自適應損失函式是分離的。因此,在DALIB中,使用者可以像搭積木一樣,重新定製自己需要的域自適應損失函式。

例如,核方法中,使用者可以自己定義不同引數的高斯核或者其他核函式,然後傳入到多核最大均值差異(MK-MMD)的計算中。

清華開源遷移學習演算法庫

目前,所有的模組和損失函式均已提供詳細的API說明文件。

穩定性

域自適應演算法研究領域往往關注方法的創新程度或者理論層面的價值,而忽視了工程實現中的穩定性和可復現性。在復現現有的演算法的過程中,出現了部分演算法準確率不穩定的問題。透過對數值方面的改進,這些問題都已經得到解決。(具體實現就不在此處展開了。)

此外,DALIB幾乎在所有任務上,準確率都比原論文匯報準確率高,部分資料集上甚至能高14%。下圖分別是Office-31和VisDA-2017上的測試結果。

清華開源遷移學習演算法庫

圖表 4 Office-31上不同演算法的準確率

清華開源遷移學習演算法庫

圖表 5 VisDA2017上不同演算法的準確率

演算法庫提供了各個演算法在Office-31、Office-Home和VisDA-2017上的測試結果,以及所有的測試指令碼。我們認為開源該演算法庫對於這個領域未來的研究工作是具有巨大價值的。

未來的工作

域自適應演算法子庫DALIB下一個版本會支援域自適應演算法的不同設定,包括部分域自適應任務(Partial Domain Adaptation)、開放集域自適應任務(Open-set Domain Adaptation)、通用域自適應任務(Universal Domain Adaptation)等。

遷移學習演算法庫Trans-Learn目前還處於初期開發階段,難免有不完善的地方,歡迎其他研究者提意見。同時遷移學習這個方向也還在不斷髮展,今後會不斷跟進新工作中比較好的演算法。

當前版本由龍明盛老師課題組的江俊廣和付博同學開發,如果有任何意見和建議,歡迎聯絡JiangJunguang1123@outlook.com

fb1121@vip.qq.com

編輯:於騰凱

校對:林亦霖

—完—







來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946223/viewspace-2710293/,如需轉載,請註明出處,否則將追究法律責任。

相關文章