分散式多工學習論文閱讀(三):運用代理損失進行任務分解

orion發表於2021-11-05

1 代理損失函式——一種並行化拆解技巧

我們在本系列第一篇文章《分散式多工學習論文閱讀(一)多工學習速覽》(連結:https://www.cnblogs.com/orion-orion/p/15481054.html)中提到,實現多工學習的一種傳統(非神經網路)的方法為增加一個正則項[1][2][3]

\[\begin{aligned} \underset{\textbf{W}}{\min} & \sum_{t=1}^{T} [\frac{1}{m_t}\sum_{i=1}^{m_t}L(y_i^t, f(\bm{x}_i^t; \bm{w}_t))]+\lambda g(\textbf{W})\\ & = \sum_{t=1}^{T} \mathcal{l}_t(\bm{w}_t)+\lambda g(\textbf{W})\\ & = f(\textbf{W}) + \lambda g(\textbf{W}) \end{aligned} \tag{1} \]

目標函式中的\(f(\mathbf{W})\)很容易並行化拆解,但是一般\(g(\mathbf{W})\)就很難並行化了,那麼如何解決這個問題呢?答案是運用一個可以分解的代理損失函式來替換掉原始的目標函式。我們接下來就以論文《Parallel Multi-Task Learning》[4](zhang 2015c等人)為例來介紹該思想。該論文MTLR模型[5](zhang 2015a)的基礎上利用FISTA演算法設計代理損失函式,該代理函式可以依據學習任務進行分解,從而平行計算。

2 基於正則化的多工學習(MTLR)演算法回顧

給定\(T\)個任務\({\{\mathcal{T}_i\}}_{i=1}^T\),每個任務都有一個訓練集\(\mathcal{D}_t = {\{(\bm{x}_{i}^t, y_{i}^t)}_{i=1}^{m_t}\}\)。我們現在考慮以下形式的目標函式:

\[\begin{aligned} \underset{\textbf{W}, b}{\min} & \sum_{t=1}^{T} [\frac{1}{m_t}\sum_{i=1}^{m_t}L(y_i^t, \langle \bm{w}_t, \phi(\bm{x}_{i}^t)\rangle+b) ]+ \frac{\lambda}{2}\text{tr}(\textbf{W}\Omega\mathbf{W}^T)\\ \end{aligned} \tag{2} \]

這裡的\(\phi(\cdot)\)是一個和核函式\(k(\cdot, \cdot)\)相關的特徵對映,這裡\(\phi(\bm{x}_1)^T\phi(\bm{x}_2)=k(\bm{x}_1, \bm{x}_2)\)\(L(\cdot, \cdot)\)是損失函式(比如對於分類問題的\(\text{hinge loss}\) 和對於迴歸問題的 \(\epsilon \text{-insentive loss}\)。式\((2)\)的第一項是所有任務的經驗損失函式,第二項基於\(\mathbf{W}和\Omega\)來建模任務間的關係。根據論文[5]\(\Omega\)是一個正定(Positive definite, PD)矩陣,它用來描述任務兩兩之間關係的精度矩陣(協方差矩陣\(\Sigma\)的逆)。如果損失函式是凸的且\(\Omega\)正定,那麼目標函式\((2)\)關於\(\mathbf{W}和\bm{b}\)是聯合凸(jointly convex)的。為了體現目標函式\((2)\)和單任務核方法的關係,我們這裡只考慮\(\Omega\)是對角矩陣的情況。在此情況下,任務兩兩之間沒有關係,問題\((2)\)也退化為了多個單任務模型(每個模型對應一個任務)。因此,問題\((2)\)可以被視為單任務正則化模型的多工擴充套件。在問題\((2)\)中,\(\frac{\lambda}{2}\text{tr}(\textbf{W}\Omega\mathbf{W}^T)\)不影響我們的並行演算法設計,這是非常好的。而問題\((2)\)總是能夠加速問題的學習,當使用特定的優化程式如論文[5]和論文[6]一樣,根據過去的研究這些方法有很快的收斂率,不管正則項是什麼。
在問題\((2)\)中有許多損失函式可供使用,比如\(\text{hinge loss}\)\(\epsilon-\text{insensitive loss}\)\(\text{square loss}\),下面我們主要就採用這三種損失函式,後面我們會分別給出問題\((2)\)關於這三個損失函式的對偶形式。

2. 並行多工學習演算法

2.1 FISTA迭代演算法

下面我們就給出當使用不同的損失函式時問題\((2)\)的並行求解演算法。因為我們的求解演算法是基於FISTA迭代的,我們先來看FISTA迭代演算法。
FISTA迭代演算法[7]是一個加速梯度下降方法,用於求解一個類似於下面這種形式的複合凸目標函式(compositely convex objective function):

\[\underset{\bm{\theta} \in \mathcal{C}_{\bm{\theta}}}{\text{min}}\quad F(\bm{\theta}) = f(\bm{\theta}) + g(\bm{\theta}) \tag{3} \]

這裡\(\bm{\theta}\)是指模型的引數集合,\(f(\bm{\theta})\)是凸的且它的梯度有\(\text{Lipschitz}\)連續性,凸函式\(g(\bm{\theta})\)有著簡單的且易分解(並行)的結構,\(\mathcal{C}_{\bm{\theta}}\)是指\(\bm{\theta}\)的定義域。FISTA演算法最新構建代理損失函式\(Q_l(\bm{\theta}, \hat{\bm{\theta}})\)如下:

\[Q_{\mathcal{L}}(\bm{\theta}, \hat{\bm{\theta}})=g(\bm{\theta})+f(\hat{\bm{\theta}})+(\bm{\theta}-\hat{\bm{\theta}})^T\nabla_{\bm{\theta}}f(\hat{\bm{\theta}}) + \frac{\mathcal{L}}{2}||\bm{\theta}-\hat{\bm{\theta}}||_2^2 \tag{4} \]

這裡\(\nabla_{\bm{\theta}}f(\hat{\bm{\theta}})\)表示\(f(\bm{\theta})\)\(\bm{\theta}=\hat{\bm{\theta}}\)點的梯度,\(\mathcal{L}\)\(f(\cdot)\)梯度的\(\text{Lipschitz}\)常量,接著我們優化關於\(\bm{\theta}\)的函式\(Q_{\mathcal{L}}(\bm{\theta}, \hat{\bm{\theta}})\),約束為\(\bm{\theta} \in \mathcal{C}_{\bm{\theta}}\)。函式\(Q_{\mathcal{L}}(\bm{\theta}, \hat{\bm{\theta}})\)關於\(\bm{\theta}\)的優化器由\(q_{\mathcal{L}}(\hat{\bm{\theta}})\)表示。

FISTA演算法虛擬碼如下圖所示:

FISTA演算法虛擬碼

可以看到第\(17\)步和\(18\)步在\(\bm{\theta}\)能夠被劃分為許多部分的情況下可以輕易並行。但目前的問題是如何並行化演算法步驟\(11\)\(13\)

2.1 將目標函式轉換為對偶問題

當使用\(\text{hinge},\epsilon\text{-intensive}\)\(\text{squre}\)損失函式時,我們需要用\(\text{FISTA}\)演算法優化其對偶問題。下面我們分別說明得到這三個損失函式對應目標函式的對偶問題,後面我們會在此基礎上進行並行化。

2.1.1 Hinge Loss

(1)轉為對偶形式 我們將Hinge Loss函式\(L_h(y^{'},y)=\text{max}(1-y^{'}y, 0)\)代入式(2)的優化問題,並將無約束優化轉為有約束優化可得到:

\[\begin{aligned} & \underset{\mathbf{W}, \bm{b}, \bm{\eta}}{\text{min}}\quad \frac{\lambda}{2}\text{tr}(\textbf{W} \mathbf{\Omega} \textbf{W}^T) + \sum_{t=1}^T\frac{1}{m_t}\sum_{i=1}^{m_t}\eta_{i}^t \\ & \text{s.t.} \quad y_{i}^t(\langle \bm{w}_t, \phi(\bm{x}_{i}^t)\rangle+b_t) \geqslant 1- \eta_{i}^t, \quad \eta_{i}^t \geqslant 0 \end{aligned} \tag{5} \]

這裡\(\bm{\eta}=(\eta_{1}^t, ..., \eta_{m_T}^t)^T\)。引入非負的Lagrange乘子\(\{\alpha_{i}^t\}\)\(\{\beta_{i}^t \}\),我們可以得到問題\((5)\)的對偶形式如下:

\[\begin{aligned} & \underset{\bm{\alpha}}{\text{min}} \quad \frac{1}{2\lambda}\bm{\alpha}^T\mathbf{P}\bm{\alpha} - \sum_{t=1}^T\sum_{i=1}^{m_t}\alpha_{i}^t \\ & \text{s.t.} \quad \sum_{i=1}^{m_t}\alpha_{i}^t y_{i}^t = 0 \\ & \quad \quad \space \space t=1, 2,...,T, \quad i = 1, 2, ..., m_t, \quad 0 \leqslant \alpha_{i}^t \leqslant \frac{1}{m_t} \end{aligned} \tag{6} \]

這裡我們說明一下矩陣\(\mathbf{P}\)的含義,設\(\sigma_{ij}\)是任務關係協方差矩陣\(\Sigma\)的第\((i, j)\)個元素,\(\mathbf{K}\)是一個\(n \times n\)的矩陣,它的第\((I_{a}^b, I_{c}^d)\)個元素是\(\sigma_{ac}k(\bm{x}_{a}^b, \bm{x}_{c}^d)\),這裡\(I_{i}^t =i+\sum_{k=1}^{t-1}m_k\)計算在所有任務的訓練資料中的\(\bm{x}_{i}^t\)的下標。\(\odot\)指逐元素乘積操作,這裡有\(\mathbf{P}=\mathbf{K} \odot (\bm{y}\bm{y}^T)\)。這裡我們定義函式\(k_{MT}(\cdot, \cdot)\)\(k_{MT}(\bm{x}_{i}^t, \bm{x}_{s}^r) = \sigma_{r}^tk(\bm{x}_{i}^t, \bm{x}_{s}^r)\)
用來構造矩陣\(\mathbf{K}\)。很容易證明這是一個核函式。所以我們稱\(k_{MT}(\cdot, \cdot)\)是一個多工核函式,將\(\mathbf{K}\)稱為多工核矩陣。

2.1.2 \(\epsilon\) - Insensitive Loss

接下來我們討論將\(\epsilon-\) insensitive loss函式
\(L_{\epsilon}(y,y^{'}) = \left\{ \begin{aligned} 0 \quad \text{若} |y - y^{'}| \leqslant \epsilon \\ |y - y^{'}| - \epsilon \quad \text{其他} \end{aligned} \right .\)
代入問題\((2)\)進行優化。我們再引入一些鬆弛變數,問題\((2)\)可被轉化為:

\[\begin{aligned} & \underset{\mathbf{W}, \bm{b}, \bm{\eta}, \bm{\tau}}{\text{min}}\quad \sum_{t=1}^{T} \frac{1}{m_t} \sum_{i=1}^{m_t}(\eta_i^t+\tau_i^t) + \frac{\lambda}{2}\ \text{tr}(\mathbf{W}\mathbf{\Omega}\mathbf{W}^T) \\ & \text{s.t.} \quad n_i^t \geqslant 0, \bm{w}_t^T\phi(\bm{x}_i^t) + b_t - y_i^t \leqslant \epsilon + \eta_i^t \\ & \quad \quad \space \space \tau_i^t \geqslant 0, y_i^t - \bm{w}_t^T\phi(\bm{x}_i^t) - b_t \leqslant \epsilon + \tau_i^t \end{aligned} \tag{7} \]

這裡\(\bm{\eta} = (\eta_1^1,..., \eta_{m_T}^T)^T\)\(\bm{\tau} = (\tau_1^1,..., \tau_{m_T}^T)^T\)
我們接下來引入Lagrange乘子\(\bm{ \alpha} = (\alpha_1^1,...,\alpha_{m_T}^T)^T\)\(\bm{ \beta} = (\beta_1^1,...,\beta_{m_T}^T)^T\),進一步得到問題\((7)\)的對偶問題:

\[\begin{aligned} & \underset{\bm{\alpha}, \bm{\beta}}{\text{min}}\quad \frac{1}{2\lambda}(\bm{\alpha}- \bm{\beta})^T\mathbf{K}(\bm{\alpha} - \bm{\beta}) + \epsilon (\bm{\alpha}+\bm{\beta})^T \mathbf{1} + \bm{y}^T(\bm{\alpha} - \bm{\beta}) \\ & \text{s.t.} \quad \sum_{i=1}^{m_t}(\alpha_i^t - \beta_i^t) = 0 \\ & \quad \quad \quad t=1, 2,...,T, \quad i = 1, 2, ..., m_t, \quad 0 \leqslant \alpha_{i}^t ,\beta_{i}^t \leqslant \frac{1}{m_t} \end{aligned} \tag{8} \]

這裡\(\mathbf{1}\)表示一個元素全為1的合適大小的向量或者矩陣,\(\mathbf{K}\)表示由等式\((8)\)的多工核函式\(k_{MT}(\cdot, \cdot)\)構成的矩陣。這裡\(\bm{y}=(y_1^1, ...,y_{m_T}^T)^T\)

2.1.3 Square loss:

我們將square loss代入問題\((2)\),得到以下優化問題:

\[\begin{aligned} & \underset{\mathbf{W}, \bm{b}, {\eta_{ij}}}{\text{min}}\quad \sum_{t=1}^{T} \frac{1}{m_t} \sum_{i=1}^{m_t}(\eta_i^t)^2 + \frac{\lambda}{2}\ \text{tr}(\mathbf{W}\mathbf{\Omega}\mathbf{W}^T) \\ & \text{s.t.} \quad n_i^t = y_i^t - \bm{w}_t^T\phi(\bm{x}_i^t) - b_t \end{aligned} \tag{9} \]

引入Lagrange乘子\(\{\alpha_{ij}\}\),我們就可以得到問題\((9)\)的對偶形式:

\[\begin{aligned} & \underset{\bm{\alpha}}{\text{min}} \quad \frac{1}{2\lambda}\bm{\alpha}^T\mathbf{Q}\bm{\alpha} - \sum_{t=1}^T\sum_{i=1}^{m_t}\alpha_{i}^t y_{i}^t \\ & \text{s.t.} \quad \sum_{i=1}^{m_t}\alpha_{i}^t = 0 \\ & \quad \quad \space t=1, 2,...,T, \quad i = 1, 2, ..., m_t, \quad 0 \leqslant \alpha_{i}^t \leqslant \frac{1}{m_t} \end{aligned} \tag{10} \]

這裡\(\bm{\alpha} = (\alpha_1^1,...,\alpha_{m_T}^T)^T\)。這裡\(\mathbf{Q} = \mathbf{K} + \frac{\lambda}{2}\mathbf{\Lambda}\)\(\mathbf{\Lambda}\)是一個對角矩陣,相應的資料點屬於第\(k\)個任務時其對角元素為\(m_k\)

注意,後面我們會發現三個損失函式對應的對偶形式都有著相似的形式而且和單任務對偶形式的主要不同點都在於線性不等式約束。也就是說,在單任務對偶形式中,只有一個涉及Lagrange乘子的線性不等式約束;但是在多工環境下,有\(m\)個線性不等式約束,每個不等式都由一個任務的Lagrange乘子組成。有趣的是,這種差別決定了我們後面設計的並行演算法。

2.2 將對偶問題的求解並行化:

接下來我們需要展示應用FISTA演算法並行化求解\((6)\),其他損失函式同理。我們定義\(\bm{\theta}=\bm{\alpha}\)\(\bm{\phi} = \hat{\bm{\alpha}}\)\(f(\bm{\alpha}) = \frac{1}{2\lambda}\bm{\alpha}^T\bm{P}\bm{\alpha}\)\(g(\bm{\alpha}) = \sum_{t=1}^T\sum_{i=1}^{m_t}\alpha_{i}^t\),定義域\(\mathcal{C}_{\alpha} = \{\alpha | \sum_{i=1}^{m_t}\alpha_{i}^t y_{i}^t = 0(t=1,2,..,T, i=1,2,...,m_t, 0 \leqslant \alpha_i^t\leqslant \frac{1}{m_t})\}\)。下面我們來看如何並行化演算法步驟\(11\)\(13\)

\(f(\bm{\alpha})\)關於\(\bm{\alpha}\)的二階導數\(\nabla^2 f(\bm{\alpha})\)是我們這裡的\(\frac{1}{\lambda}\bm{P}\)。我們用\(||\cdot||\)表示矩陣的\(l_2\)範數,易得\(||\mathbf{P}||_2 \mathbf{I}_n-\mathbf{P}\)是一個半正定矩陣。所以\(f(\bm{\alpha})\)的最小\(\text{Lipschitz}\)常量是\(\frac{1}{\lambda}||\textbf{P}||_2\)(\(\mathcal{L} \geqslant {\frac{1}{\lambda}||\textbf{P}||_2}\))。當\(n\)非常大時,計算\(||\textbf{P}||_2\)非常耗時,我們下面會展示如何並行地計算它。

給定\(\mathcal{L}\),我們能夠優化關於\(\bm{\alpha}\)的函式\(Q_{\mathcal{L}}(\bm{\alpha}, \hat{\bm{\alpha}})\),這也是步驟11或13要求解的(並行地)。特別地,步驟11或13要求解的優化問題可以被描述為:

\[\begin{aligned} & \underset{\bm{\alpha}}{\text{min}}\quad \frac{\mathcal{L}}{2}||\bm{\alpha} - \hat{\bm{\alpha}}||^2_2 + \frac{1}{\lambda}\bm{\alpha}^T\textbf{P}\hat{\bm{\alpha}} - \sum_{t=1}^T\sum_{i=1}^{m_t}\alpha_{i}^t \\ & \text{s.t.} \quad \sum_{i=1}^{m_t}\alpha_{i}^t y_{i}^t = 0 \\ & \quad \quad \space \space t=1, 2,...,T, \quad i = 1, 2, ..., m_t, \quad 0 \leqslant \alpha_{i}^t \leqslant \frac{1}{m_t} \end{aligned} \tag{11} \]

該問題可以被分解為\(T\)個獨立的子問題,第\(t\)個子問題為:

\[\begin{aligned} & \underset{\bm{\alpha}_t}{\text{min}}\quad \sum_{i=1}^{m_t}\left( \frac{\mathcal{L}}{2}(\alpha_i^t)^2 - \alpha_i^t \alpha_j^t \right) \\ & \text{s.t.} \quad \sum_{i=1}^{m_t}\alpha_{i}^t y_{i}^t = c_t \\ & \quad \quad \space \space i = 1, 2, ..., m_t, \quad \rho \leqslant \alpha_{i}^t \leqslant d_t \end{aligned} \tag{12} \]

這裡\(\alpha_t = (\alpha_1^t, ..., \alpha_{m_t}^t )^T\)\(a_i^t = \mathcal{L}\hat{\bm{\alpha}}_j^t+1-\frac{1}{\lambda} \hat{p}_i^t\)\(\hat{p}_i^t\)\(\text{P}\hat{\bm{\alpha}}\)中與\(\bm{x}_i^t\)對應的元素。\(c_t=0,\rho = 0, d_t = \frac{1}{m_t}\)

問題\((12)\)是一個二次規劃(quadratic programming, QP)問題,我們能夠不用任何QP求解器,在\(O(m_t)\)的時間內用拉格朗日乘子法求解。正如問題\((11)\)所示,FISTA演算法的每一輪迭代都需要計算\(\mathbf{P}\hat{\bm{\alpha}}\)以決定\(Q_{\mathcal{L}}(\bm{\alpha}, \hat{\bm{\alpha}})\)。如果我們直接解任務\((12)\)\(\alpha\)會完全和之前的估計不同,且計算\(\mathbf{P}\hat{\bm{\alpha}}\)會花費\(O(n^2)\),當\(n\)很大時計算量太大。所以這裡我們希望採取SMO演算法的思想,只更新部分的\(\alpha\)元素,這樣計算\(\mathbf{P}\hat{\bm{\alpha}}\)的時間複雜度減少到\(O(n)\)。(因為我們只需要關心變化的元素)

參考文獻

  • [1] Evgeniou T, Pontil M. Regularized multi--task learning[C]//Proceedings of the tenth ACM SIGKDD international conference on Knowledge discovery and data mining. 2004: 109-117.
  • [2] Zhou J, Chen J, Ye J. Malsar: Multi-task learning via structural regularization[J]. Arizona State University, 2011, 21.
  • [3] Zhou J, Chen J, Ye J. Clustered multi-task learning via alternating structure optimization[J]. Advances in neural information processing systems, 2011, 2011: 702.
  • [4] Zhang Y. Parallel multi-task learning[C]//2015 IEEE International Conference on Data Mining. IEEE, 2015: 629-638.
  • [5] Zhang Y, Yeung D Y. A convex formulation for learning task relationships in multi-task learning[J]. arXiv preprint arXiv:1203.3536, 2012.
  • [6] Zhang Y, Yeung D Y. A regularization approach to learning task relationships in multitask learning[J]. ACM Transactions on Knowledge Discovery from Data (TKDD), 2014, 8(3): 1-31.
  • [7] A. Beck and M. Teboulle, “A fast iterative shrinkagethresholding algorithm for linear inverse problems,” SIAM Journal on Imaging Sciences, 2009
  • [8] 楊強等. 遷移學習[M].機械工業出版社, 2020.

相關文章