遷移學習(IIMT)——《Improve Unsupervised Domain Adaptation with Mixup Training》

加微信X466550探討發表於2023-01-10

論文資訊

論文標題:Improve Unsupervised Domain Adaptation with Mixup Training
論文作者:Shen Yan, Huan Song, Nanxiang Li, Lincan Zou, Liu Ren
論文來源:arxiv 2020
論文地址:download 
論文程式碼:download
引用次數:93

1 Introduction

  現有方法分別對源域和目標域施加約束,忽略了它們之間的重要相互作用。本文使用 mixup 來加強訓練約束來直接解決目標域的泛化效能。

  當前工作假設:當在表示級處理域差異時,訓練後的源分類器能夠在目標域上自動取得良好的效能。然而,當前研究表明,在兩個域上都表現良好的分類器可能不存在 [6,7],所以僅依賴源分類器可能導致目標域的顯著錯誤分類。現有最先進的方法在對抗學習過程中尋求額外的訓練約束,不過他們都是在所選擇的域獨立地使用訓練約束,而不是聯合約束。這使得這兩個域之間的重要相互作用尚未被探索,並可能會顯著限制訓練約束的潛力。

  本文透過簡單的 $\text{mixup training}$,證明了引入該訓練約束可以顯著提高模型適應效能。

  $\text{Mixup}$:給定一對樣本 $\left(x_{i}, y_{i}\right)$、$\left(x_{j}, y_{j}\right)$ ,生成的增強表示為:

    $\begin{array}{c}x^{\prime}=\lambda x_{i}+(1-\lambda) x_{j} \\y^{\prime}=\lambda y_{i}+(1-\lambda) y_{j}\end{array} $

  其中,$\lambda \in[0,1]$。

  透過使用 $\left(x^{\prime}, y^{\prime}\right)$ 訓練,鼓勵了模型的線性行為,其中原始資料中的線性插值導致預測的線性插值。

  受半監督學習[9] 的啟發,本文透過在目標資料上推斷標籤來實現跨域的 $\text{mixup}$。透過這種方式,與只使用源標籤來訓練分類器不同,本文還可以使用域之間的插值(虛擬)標籤來提供額外的監督。隨著 $\text{mixup}$ 訓練和領域對抗性訓練的進展,該模型推斷出虛擬標籤。該過程對於直接提高目標域分類器的泛化具有關重要。此外,為了在非常大的域差異下有效地加強線性約束,本文開發了一個特徵級一致性正則化器,以更好地促進 $\text{mixup}$ 訓練。除了域間約束外,$\text{mixup}$ 也可以在每個域內應用。域間和域內混合訓練構成了所提出的 IIMT 框架,用於加強多方面約束以提高目標域效能。

2 Problem Statement

  The overview of IIMT framework is shown in $\text{Figure 1}$. We denote the labeled source domain as set  $\left\{\left(x_{i}^{s}, y_{i}^{s}\right)\right\}_{i=1}^{m_{s}} \sim \mathcal{S}$  and unlabeled target domain as set  $\left\{x_{i}^{t}\right\}_{i=1}^{m_{t}} \sim \mathcal{T}$ . Here  $y_{i}$  denotes one-hot labels. The overall classification model is denoted as  $h_{\theta}: \mathcal{S} \mapsto \mathcal{C}$  with the parameterization by  $\theta$ . Following prominent approaches in UDA [6, 7], we consider the classification model as the composite of an embedding encoder  $f_{\theta}$  and an embedding classifier  $g_{\theta}: h=f \circ g$ . Note that encoder is shared by the two domains. The core component in our framework is mixup, imposed both across domains (Inter-domain in $\text{Figure 1}$) and within each domain (Intra-domain (source) and Intradomain (target) in $\text{Figure 1}$. All mixup training losses and the domain adversarial loss are trained end-to-end.

    

3 Method

3.1 Inter-domain Mixup Training

  本文框架中的關鍵元件:源域和目標域之間的 $\text{mixup}$ 訓練。在 $h$ 的訓練中,$\text{mixup}$ 提供了插值標籤來強制分類器跨域的線性預測行為。與單獨使用源標籤訓練相比,它們導致了一種簡單的歸納偏差,但本文可以直接提高分類器對目標域的泛化能力。

  $\text{mixup}$ 訓練需要樣本標籤來進行插值,本文利用推斷出的標籤作為對目標域的弱監督。類似的想法在半監督學習設定[10,9]中被證明在開發相關的未標記資料方面是非常有效的。

  首先,對目標域每個資料樣本執行 $K$ 個任務相關的隨機增強,以獲得轉換後的樣本 $\left\{\hat{x}_{i, k}\right\}_{k=1}^{K}$。然後,計算目標域的虛擬標籤:$\bar{q}_{i}=\frac{1}{K} \sum\limits _{k=1}^{K} h_{\theta}\left(\hat{x}_{i, k}\right)$,歸一化為 $q_{i}=\bar{q}_{i}^{\frac{1}{T}} / \sum\limits _{c} \bar{q}_{i, c}^{\frac{1}{T}}$,使用較小的 $T<1$ 產生更清晰的預測分佈。

  

  給定一對源樣本和目標樣本 $\left(x_{i}^{s}, x_{i}^{t}\right)$,標籤級 $\text{mixup}$ 以加強各域之間的線性一致性:

    $\begin{array}{l}x_{i}^{s t}=\lambda^{\prime} x_{i}^{s}+\left(1-\lambda^{\prime}\right) x_{i}^{t} \quad\quad(1) \\q_{i}^{s t}=\lambda^{\prime} y_{i}^{s}+\left(1-\lambda^{\prime}\right) q_{i}^{t} \quad\quad(2) \\\mathcal{L}_{q}=\frac{1}{B} \sum_{i} H\left(q_{i}^{s t}, h_{\theta}\left(x_{i}^{s t}\right)\right)\quad\quad(3) \end{array}$

  其中,$\text{B}$ 代表 $\text{batch size}$ ,$\text{H}$ 為交叉熵損失,$\text{mixup}$ 加權引數根據:$\lambda \sim \operatorname{Beta}(\alpha, \alpha)$ 和 $\lambda^{\prime}=\max (\lambda, 1-\lambda)$ 選擇。

  當設定 $\alpha$ 接近於 $1$ 時,從範圍 $[0,1]$ 中選擇 $\lambda$ 為中間值的機率更大,使得兩個域之間的插值水平更高。請注意,$\lambda^{\prime}$ 始終超過 $0.5$,以確保源域占主導地位。同樣地,也可產生目標域主導的 $\text{mixup}$,只需要透過在 $\text{Eq.1}$ 中切換 $x^{s}$ 和 $x^{t}$ 的係數,對應地形成 $\left(x^{t s}, q^{t s}\right)$。使用目標域主導的 $\left(x^{t s}, q^{t s}\right)$,採用均方誤差(MSE)損失,因為它更能容忍目標域中的虛假虛擬標籤。

3.1.1  Consistency Regularizer

  在域差異非常大的情況下,域間 $\text{mixup}$ 所施加的線性約束可能效果較差。具體來說,當異構的原始輸入在 $\text{Eq.1}$ 中被插值時,迫使模型 $h$ 產生相應的插值預測變得更加困難。同時,對於特徵級域混淆的域對抗損失的聯合訓練會增加訓練難度。

  因此,本文為潛在特徵設計一個一致性正則化器,以更好地促進域間 $\text{mixup}$ 訓練:

    $\begin{aligned}z_{i}^{s t} & =\lambda^{\prime} f_{\theta}\left(x_{i}^{s}\right)+\left(1-\lambda^{\prime}\right) f_{\theta}\left(x_{i}^{t}\right) \quad\quad(4)    \\\mathcal{L}_{z} & =\frac{1}{B} \sum\limits _{i}\left\|z_{i}^{s t}-f_{\theta}\left(x_{i}^{s t}\right)\right\|_{2}^{2}\quad\quad(5)\end{aligned}$

  即:透過兩個向量之間的 $\text{MSE}$ 損失,使混合特徵更接近於混合輸入的特徵。這個正則化器的作用:當 $\text{Eq.5}$ 強制 $z_{i}^{s t}$, $f_{\theta}\left(x_{i}^{s t}\right)$ 透過淺分類器 $g$,模型預測的線性更容易滿足。

    

3.1.2 Domain Adversarial Training

  最後一個組成部分是使用標準的域對抗性訓練來減少域的差異。本文的實現限制在更基本的 DANN 框架[1]上,以試圖集中於評估混合線性約束。在DANN中,一個域鑑別器和共享嵌入編碼器(生成器)在對抗性目標下進行訓練,使編碼器學習生成域不變特徵。本文的源和目標樣本 $\text{mixup}$ 的域對抗性損失:

    $\mathcal{L}_{d}=\frac{1}{B} \sum_{i} \ln D\left(f_{\theta}\left(x_{i}^{s t}\right)\right)+\ln \left(1-D\left(f_{\theta}\left(x_{i}^{s t}\right)\right)\right)\quad\quad(6)$

3.2 Intra-domain Mixup Training

  給定源標籤和目標虛擬標籤,$\text{mixup}$ 訓練也可以在每個域內執行。由於在同一域內的樣本遵循相似的分佈,因此不需要應用特徵級的線性關係。因此,只對這兩個領域使用標籤級 $\text{mixup}$ 訓練,並定義它們相應的損失:

    $\begin{array}{l}x_{i}^{s^{\prime}}=\lambda^{\prime} x_{i}^{s}+\left(1-\lambda^{\prime}\right) x_{j}^{s} \\y_{i}^{s^{\prime}}=\lambda^{\prime} y_{i}^{s}+\left(1-\lambda^{\prime}\right) y_{j}^{s} \\\mathcal{L}_{s}=\frac{1}{B} \sum\limits _{i} H\left(y_{i}^{s^{\prime}}, h_{\theta}\left(x_{i}^{s^{\prime}}\right)\right)\end{array}\quad\quad(7)$

    $\begin{array}{l}x_{i}^{t^{\prime}}=\lambda^{\prime} x_{i}^{t}+\left(1-\lambda^{\prime}\right) x_{j}^{t} \\q_{i}^{t^{\prime}}=\lambda^{\prime} q_{i}^{t}+\left(1-\lambda^{\prime}\right) q_{j}^{t} \\\mathcal{L}_{t}=\frac{1}{B} \sum\limits _{i}\left\|q_{i}^{t^{\prime}}-h_{\theta}\left(x_{i}^{t^{\prime}}\right)\right\|_{2}^{2}\end{array}\quad\quad(8)$

  雖然域內混合作為一種資料增強策略是直觀的,但它對 UDA 特別有用。正如在[6]中所討論的,沒有區域性約束的條件熵的最小化會導致資料樣本附近的預測突變。在[6]中,利用虛擬對抗訓練[10]來增強鄰域的預測平滑性。不同的是,我們發現域內混合訓練能夠實現相同的目標。

3.3 Training Objective

  訓練目標:

    $\mathcal{L}=w_{q} \mathcal{L}_{q}+w_{d} \mathcal{L}_{d}+w_{z} \mathcal{L}_{z}+w_{s} \mathcal{L}_{s}+w_{t} \mathcal{L}_{t}\quad\quad(9)$

  由於 $\mathcal{L}_{t}$ 只涉及虛擬標籤,因此很容易受到目標域的不確定性的影響。本文為訓練中的 $w_{t}$ 設定了一個線性時間表,從 $0$ 到一個預定義的最大值。從初始實驗中,觀察到該演算法對其他加權引數具有良好的魯棒性。因此,只搜尋 $w_{t}$,而簡單地將所有其他權重固定為 $1$。

4 Experiment

  For image classification experiments, we evaluate on MNIST, MNIST-M, Street View House Numbers (SVHN), Synthetic Digits (SYN DIGITS), CIFAR-10 and STL-10.

  A → B to denote the domain adaptation task with source domain A and target domain B.

  前三:手寫陣列識別;後二:目標檢測:

  

  消融實驗:

  

 

 


Note

條件熵:條件熵 $H(Y|X)$ 表示在已知隨機變數 $X$ 的條件下隨機變數 $Y$ 的不確定性。

    $\begin{aligned}H(Y \mid X) & =\sum\limits_{x \in X} p(x) H(Y \mid X=x) \\& =-\sum\limits_{x \in X} p(x) \sum\limits_{y \in Y} p(y \mid x) \log p(y \mid x) \\& =-\sum\limits_{x \in X} \sum\limits_{y \in Y} p(x, y) \log p(y \mid x)\end{aligned}$

 

DANN

import torch
from torch.autograd import Function
import torch.nn as nn
import torch.nn.functional as F

class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        print("forward===========================")
        print("xx = ",x)
        ctx.alpha = alpha
        ctx.feature = x
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        print("backward===========================")
        print("grad_output = ",grad_output)
        output = grad_output.neg() * ctx.alpha
        return output, None

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.featurizer = nn.Linear(4,3)

        self.classifier = nn.Linear(3,2)
        self.discriminator = nn.Linear(3,2)
        self.alpha = 1

    def forward(self,x,disc_labels,label):
        # 特徵提取
        z = self.featurizer(x)
        print("z = ",z)

        disc_input = z
        disc_input = ReverseLayerF.apply(disc_input, self.alpha)
        disc_out = self.discriminator(disc_input)
        disc_loss = F.cross_entropy(disc_out, disc_labels)

        all_preds = self.classifier(z)
        classifier_loss = F.cross_entropy(all_preds,label)
        loss = classifier_loss + disc_loss
        loss.backward()
        return

x = torch.tensor([[ 1.1118,  1.8797, -0.9592, -0.6786],
        [ 0.4843,  0.4395, -0.2360, -0.6523],
        [ 0.7377,  1.4712, -2.3062, -0.9620],
        [-0.7800,  1.8482,  0.0786,  0.0179]], requires_grad=True)
disc_labels = torch.LongTensor([0,0,1,1])
label =  torch.LongTensor([0,0,1,1])

print("x = ",x)
print("disc_labels = ",disc_labels)
print("label = ",label)
print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
model = Net()
model(x,disc_labels,label)

GAN

# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  #torch.Size([64, 1])
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)   #torch.Size([64, 1])

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))     #torch.Size([64, 1, 28, 28])

        # Train Generator   ========================

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))    #torch.Size([64, 100])

        # Generate a batch of images
        gen_imgs = generator(z)        #torch.Size([64, 1, 28, 28])

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        #Train Discriminator     ========================
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

相關文章