論文資訊
論文標題:Unsupervised Domain Adaptation for COVID-19 Information Service with Contrastive Adversarial Domain Mixup
論文作者:Huimin Zeng, Zhenrui Yue, Ziyi Kou, Lanyu Shang, Yang Zhang, Dong Wang
論文來源:aRxiv 2022
論文地址:download
論文程式碼:download
1 Introduction
2 Problem Statement
Regarding misinformation detection, we aim at training a model $f$ , which takes an input text $\boldsymbol{x}$ (a COVID-19 claim or a piece of news) to predict whether the information contained in $\boldsymbol{x}$ is valid or not (i.e., a binary classification task). Moreover, in our domain adaptation problem, we use $\mathcal{P}$ to denote source domain data distribution and $\mathcal{Q}$ for the target domain data distribution. Each data point ($\boldsymbol{x}$, $y$) contains an input segment of COVID-19 claim or news ($\boldsymbol{x}$) and a label $y \in\{0,1\}$ ( $y=1$ for true information and $y=0$ for false information). To differentiate the notations of the data sampled from the source distribution $\mathcal{P}$ and the target distribution $\mathcal{Q}$ , we further introduce two definitions of the domain data:
-
- Source domain: The subscript $s$ is used to denote the source domain data: $\mathcal{X}_{s}=\left\{\left(\boldsymbol{x}_{s}, y_{s}\right) \mid\left(\boldsymbol{x}_{s}, y_{s}\right) \sim \mathcal{P}\right\}$ .
- Target domain: Similarly, the subscript t is used to denote the target domain data: $\mathcal{X}_{t}=\left\{\boldsymbol{x}_{t} \mid \boldsymbol{x}_{t} \sim \mathcal{P}\right\}$ . Note that in our unsupervised setting, the ground truth labels of target domain data $y_{t}$ are not used during training.
Our goal is to adapt a classifier $f$ trained on $\mathcal{P}$ to $\mathcal{Q}$ . For a given target domain input $\boldsymbol{x}_{t}$ , a well-adapted model aims at making predictions as correctly as possible.
3 Method
整體框架:
3.1 Domain Discriminator
第一步是訓練一個域鑑別器 $f_{D}$ 來分類輸入資料是屬於源域還是屬於目標域。該域鑑別器與 COVID 模型共享相同的 BERT Encoder $f_{e}$,並具有不同的二進位制分類模組 $f_{D}$。域鑑別器以 BERT Encoder 中的標記 [CLS] 表示作為輸入,以預測輸入資料的域,如所示:
$\hat{y}=f_{D}(\boldsymbol{z}) \quad\quad(1)$
其中,$z$ 是 token [CLS] 的表示。
對於 $f_{D}$ 的訓練,明確地將源域資料的域標籤 $y_{D}$ 定義為 $y_{D}=0$,將目標域資料的域標籤定義為 $y_{D}=1$。因此,對域鑑別器的訓練可以表述為:
$\underset{f_{D}}{\text{min}} \;\; \mathbb{E}_{\left(\boldsymbol{x}, y_{D}\right) \sim \mathcal{X}^{\prime}}\left[l\left(f_{D}\left(f_{e}(\boldsymbol{x})\right), y_{D}\right)\right] \quad\quad(2)$
其中,$\mathcal{X}^{\prime}$ 表示帶有域標籤的源域和目標域訓練資料的合併資料集。
3.2 Adversarial Domain Mixup
在訓練了域鑑別器後,我們提出直接干擾來自源域和目標域的輸入資料的潛在表示到域鑑別器的決策邊界,如 Figure 1b 所示。為此,來自兩個域的擾動表示(即域對抗表示)可以變得更接近,表明域間隙減小。在此,從兩個域生成的域對抗性表示在模型的潛在特徵空間中形成了一個平滑的中間域混合。在數學上,透過求解一個最佳化問題,可以找到干擾訓練樣本 $ \boldsymbol{x}$ 的潛在表示 $ \boldsymbol{z}$ 的最優擾動 $\delta^{*}$:
$\begin{array}{r}\mathcal{A}\left(f_{e}, f_{D}, \boldsymbol{x}, y_{D}, \epsilon\right)=\underset{\boldsymbol{\delta}}{\text{max}} \left[l\left(f_{D}(\boldsymbol{z}+\boldsymbol{\delta}), y_{D}\right)\right] \\\text { s.t. } \quad\|\boldsymbol{\delta}\| \leq \epsilon, \quad \boldsymbol{z}=f_{e}(\boldsymbol{x})\end{array}\quad\quad(3)$
注意,在上面的方程中,我們引入了一個超引數 $\epsilon$ 來約束擾動 $\delta$ 的範數,從而避免了無窮大解。最後,將 $\text{Eq.3}$ 應用於合併訓練集 $\mathcal{X}^{\prime}$ 中的所有訓練樣本,得到對抗域混合 $\mathcal{Z}^{\prime}$:
$\begin{aligned}\mathcal{Z}^{\prime} & =\left\{\boldsymbol{z}^{\prime} \mid \boldsymbol{z}^{\prime}=\boldsymbol{z}+\mathcal{A}\left(f_{e}, f_{D}, \boldsymbol{x}, y_{D}, \epsilon\right),\left(\boldsymbol{x}, y_{D}\right) \in \mathcal{X}^{\prime}\right\} \\& :=\mathcal{Z}_{s}^{\prime} \cup \mathcal{Z}_{t}^{\prime}\end{aligned}\quad\quad(4)$
其中,$\mathcal{Z}_{s}^{\prime}$ 是擾動的源特性,$\mathcal{Z}_{t}^{\prime}$ 是受干擾的目標特徵。我們使用投影梯度下降(PGD)來近似 $\text{Eq.3}$ 的解,如在[7],[8]。
3.3 Contrastive Domain Adaptation
接下來,受[6]的啟發,我們提出了 $\mathcal{Z}_{a d v}$ 的雙重對比自適應損失,以進一步將源資料域的知識適應到目標資料域。首先,我們減少了類內表示之間的域差異。也就是說,如果一個表示從源資料域的標籤是真(或假)和一個表示從目標資料域的偽標籤是真(或假),那麼這兩個表示被視為類內表示,我們減少域之間的差異。其次,如 Figure 1c 所示,真實資訊和虛假資訊的表示之間的差異將被擴大。
為了計算我們提出的對比自適應損失,我們建議使用徑向基函式(RBF)來度量標記類之間的差異。在[11]中,RBF 被證明是量化深度神經網路中不確定性的有效工具。由於我們的偽標記過程是為了自動過濾出目標域資料的低置信度標籤,因此使用RBF來衡量標記類之間的差異可以有效地提高偽標籤的質量,最終有助於模型的域適應。
在形式上,使用 RBF 核心的定義:$k\left(z_{1}, z_{2}\right)=\exp \left[-\frac{\left\|\boldsymbol{z}_{1}-\boldsymbol{z}_{2}\right\|^{2}}{2 \sigma^{2}}\right]$
我們定義了錯誤資訊檢測任務的類感知損失如下:
$\begin{aligned}\mathcal{L}_{\text {con }}\left(\mathcal{Z}^{\prime}\right) =&-\sum_{i=1}^{\left|\mathcal{Z}_{s}^{\prime}\right|} \sum_{j=1}^{\left|\mathcal{Z}_{t}^{\prime}\right|} \frac{\mathbb{1}\left(y_{s}^{(i)}=0, \hat{y}_{t}^{(j)}=0\right) k\left(\boldsymbol{z}_{s}^{(i)}, \boldsymbol{z}_{t}^{(j)}\right)}{\sum_{l=1}^{\left|\mathcal{Z}_{s}^{\prime}\right|} \sum_{m=1}^{\left|\mathcal{Z}_{t}^{\prime}\right|} \mathbb{1}\left(y_{s}^{(l)}=0, \hat{y}_{t}^{(m)}=0\right)} \\& -\sum_{i=1}^{\left|\mathcal{Z}_{s}^{\prime}\right|} \sum_{j=1}^{\left|\mathcal{Z}_{t}^{\prime}\right|} \frac{\mathbb{1}\left(y_{s}^{(i)}=1, \hat{y}_{t}^{(j)}=1\right) k\left(\boldsymbol{z}_{s}^{(i)}, \boldsymbol{z}_{t}^{(j)}\right)}{\sum_{l=1}^{\left|\mathcal{Z}_{s}^{\prime}\right|} \sum_{m=1}^{\left|\mathcal{Z}_{t}^{\prime}\right|} \mathbb{1}\left(y_{s}^{(l)}=1, \hat{y}_{t}^{(m)}=1\right)} \\& +\sum_{i=1}^{\left|\mathcal{Z}_{s}^{\prime}\right|} \sum_{j=1}^{\left|\mathcal{Z}_{s}^{\prime}\right|} \frac{\mathbb{1}\left(y_{s}^{(i)}=1, y_{s}^{(j)}=0\right) k\left(\boldsymbol{z}_{s}^{(i)}, \boldsymbol{z}_{s}^{(j)}\right)}{\sum_{l=1}^{\left|\mathcal{Z}_{s}^{\prime}\right|} \sum_{m=1}^{\left|\mathcal{Z}_{s}^{\prime}\right|} \mathbb{1}\left(y_{s}^{(l)}=1, y_{s}^{(m)}=0\right)} \\& +\sum_{i=1}^{\left|\mathcal{Z}_{t}^{\prime}\right|} \sum_{j=1}^{\left|\mathcal{Z}_{t}^{\prime}\right|} \frac{\mathbb{1}\left(\hat{y}_{t}^{(i)}=1, \hat{y}_{t}^{(j)}=0\right) k\left(\boldsymbol{z}_{t}^{(i)}, \boldsymbol{z}_{t}^{(j)}\right)}{\sum_{l=1}^{\left|\mathcal{Z}_{t}^{\prime}\right|} \sum_{m=1}^{\left|\mathcal{Z}_{t}^{\prime}\right|} \mathbb{1}\left(\hat{y}_{t}^{(l)}=1, \hat{y}_{t}^{(m)}=0\right)}\end{aligned}\quad\quad(5)$
其中,$\hat{y}_{t}$ 為目標域樣本的偽標籤,$z$ 表示標記 CLS 的表示。
3.4 Overall Contrastive Adaptation Loss
現在,我們將任務分類問題的交叉熵損失和上述對比自適應損失合併為 COVID 模型的單一最佳化目標:
$\mathcal{L}_{\text {all }}=\mathcal{L}_{c e}(\boldsymbol{\mathcal { X }})+\lambda \mathcal{L}_{\text {con }}\left(\mathcal{Z}^{\prime}\right) \quad\quad(6)$
其中,$\mathcal{L}_{c e}$ 代表交叉熵損失函式。
4 Experiment
在我們的實驗中,我們使用了三個 source misinformation datasets :GossipCop , LIAR and PHEME,兩個 COVID misinformation datasets:Constraint and ANTiVax。
Results: