1、並行與分散式多工學習(Multi-task Learning, MTL)簡介
我們在上一篇文章《並行多工學習論文閱讀(一)多工學習速覽》(連結:https://www.cnblogs.com/lonelyprince7/p/15481054.html)中提到,實現多工學習的一種典型的方法為增加一個正則項[1][2][3]:
其中\(g(\textbf{W})\)編碼了任務的相關性(多工學習的假定)並結合了\(T\)個任務;\(\lambda\)是一個正則化引數,用於控制有多少知識在任務間共享。在許多論文中,都假設了損失函式\(f(\textbf{W})\)是凸的,且是\(L\text{-Lipschitz}\)可導的(對\(L>0\)),然而正則項\(g(\textbf{W})\)雖然常常不滿足凸性(比如採用矩陣的核範數),但是我們認為其實接近凸的,可以採用近端梯度演算法(proximal gradient methods)[4]來求解。
不過任務數量很大時,多工學習的計算複雜度很高,此時需要用多CPU/多GPU對學習演算法進行加速,儘量使\(T\)個任務的梯度的計算分攤到\(T\)個不同的工作節點(worker)上。但實際上由於正則項的存在和損失函式的複雜性,想做到這個需要我們仔細地設計並行多工學習演算法,在保證演算法加速的同時而儘量不影響優化演算法最終的收斂。
2、MTL的同步(synchronized)分散式優化演算法
我們將會從MTL的單機優化方法開始,逐步說明分散式優化的必要性並介紹它的一種主要實現手段——同步分散式優化演算法。
我們先來看單機優化,由於\(g(\mathbf{W})\)正則項的不光滑性,MTL的目標函式常採用基於近端梯度的一階優化方法進行求解,包括FISTA[5], SpaRSA[6]以及最近提出的二階優化方法PNOPT[7]。下面我們簡要回顧一下在這些方法中涉及到的兩個關鍵計算步驟:
(1) 梯度計算(gradient computing) 設第\(k\)迭代步的引數矩陣為\(\mathbf{W}^{(k)}\),目標函式光滑部分\(f(\mathbf{W}^{(k)})\)的梯度由每個任務的損失函式單獨計算梯度後拼接而得:
(2) 近端對映(proximal mapping) 在梯度更新後,我們會計算
此處\(\eta\)是迭代步長。不過請注意,此處的\(\hat{\mathbf{W}}\)並不是我們下一步的搜尋點,下一步的搜尋點會經過近端對映\(\text{Prox}(\hat{\mathbf{W}}; \eta, \lambda, g)\)獲得,該近端對映等價於求解下列優化問題:
這樣,我們就得到了下一步的搜尋點\(\mathbf{W}^{(k+1)}\)。
當資料量很大時,資料經常會分片儲存在不同的計算機甚至是不同的計算中心。比如,學習任務經常會涉及到不同的樣本集(用於學習不同的任務):\(\mathcal{D}_1,..., \mathcal{D}_T\),這些樣本集進場會儲存在不同的地方。比如如果我想用不同醫院的病例樣本集進行多工學習,那麼不同醫院的資料肯定各自儲存在不同地方。不管是出於網路傳輸頻寬考慮還是資料隱私考慮,想要將所有場所的資料集中在一起然後跑最優化演算法顯然不太現實(即使資料已經脫敏,大規模轉移病人資料仍然是個很有爭議的問題)。所以,對MTL設計分散式優化演算法就顯得非常重要了,分散式優化演算法旨在儘量將耗時的計算放在各分界點本地進行,然後再通過網路傳輸到中心節點。
不失一般性,我們假設我們的資料集分散儲存在一個用星形網路連線的計算機叢集中。每個單獨的計算機系統我們稱之為節點(node),工作節點(worker)或主體(agent)。第\(t\)個節點對於任務\(t\)的資料\(\mathcal{D}_t\)擁有完全訪問權,並能夠進行數值計算(比如計算第\(t\)個任務的梯度\(\nabla \mathcal{l_t}(\bm{w}_t\))。我們假定有一箇中心節點(central server)能夠收集所有任務節點(task agents)的資料,並進行近端對映操作。
我們接下來看如何並行。因為\(T\)個任務的獨立性,可以讓第\(t\)個任務節點儲存\(\bm{w}_t^{(k)}\),然後負責計算梯度\(\nabla \mathcal{l_t}(\bm{w}_t^{(k)})\),這樣就很容易地並行化了。然後我們收集每個任務的梯度向量\(\nabla \mathcal{l_t}(\bm{w}_t^{(k)})\)到中心節點並拼接得到\(\nabla f(\mathbf{W}^{(k)})\),然後計算\(\hat{\mathbf{W}}\),最後經過近端對映操作得到\(\mathbf{W}^{(k+1)}\)。然後再將\(\mathbf{W}^{(k+1)}\)拆分為\(\bm{w}_1^{(k+1)},...,\bm{w}_T^{(k+1)}\)分別傳送到\(T\)個任務節點,進行下一輪的迭代。整個並行演算法如下圖所示:
因為必須要所有任務節點的梯度計算並收集完畢後,主節點才能進行下一步操作,所以上面這種方法被稱為同步的(synchronized)。同步方法的最大弊端就是如果有一個或多個任務節點網路傳輸頻寬過高,或者直接down掉,其他任務節點都會停下來等待(因為拿不到下一輪的資料)。因為多數一階優化演算法都需要經過很多輪迭代才能夠收斂到一個特定的精度,在同步優化演算法中的等待會造成不能容忍的演算法執行時間和運算資源的極大浪費。
2、MTL的非同步(asynchronized)分散式優化演算法
上面提到的同步優化演算法可能讓一些讀者想到MapReduce計算架構,這種架構很少用於迭代演算法。比如我們在深度學習的訓練中多采用引數伺服器(Parameter Server)架構,這是一種非同步優化的架構。在多工學習的領域,也有學者提出了非同步優化演算法,接下來我們以《Asynchronous Multi-Task Learning》[8](IM Baytas等,2016)這篇論文為例,來介紹MTL的非同步優化演算法。
在本篇論文的非同步優化演算法中,中心節點只要收到了來自一個任務節點的已經算好的梯度,就會馬上對模型的引數矩陣\(\mathbf{W}\)進行更新,而不用等待其他任務節點完成計算。中心節點和任務節點都會在記憶體中維護一份\(\mathbf{W}\)的拷貝,任務節點之間的拷貝可能會各不相同。AMTL(Asynchronized Multi-task learning)的收斂率分析可以參照另外兩篇介紹ARock[9]計算框架介紹Tmac[10]計算框架的論文,這兩篇論文采用Krasnosel’skii-Mann
(KM) 迭代方法來解決非同步並行座標更新(asynchronous parallel coordinate update)問題。我們稱一個任務節點被啟用,當它進行(梯度)計算並與中心節點通訊以進行更新。《Asynchronous Multi-Task Learning》這篇論文也提出了一個非同步並行的框架,該框架基於以下關於啟用率(activation rate)的假設:
假設1: 所有任務節點服從獨立的泊松過程並且有相同的啟用率。
該假設可以得到一個有用的結論,如果不同的任務節點的啟用率不同,我們理論上可以調整迭代步長\(\eta\)來調整迭代結果:如果任務節點的啟用率很大,那麼該任務節點被啟用的可能性就會很大,從而我們應該降低
該任務節點對應的迭代步長\(\eta\)(注意:因為是非同步演算法,每個任務節點都有其對應的迭代步長)。該論文提出了的一個動態迭代步長策略,具體細節在此略過。
論文使用了一種前向-後向運算元分裂方法[11][12]來求解目標函式\((1)\)。求解目標函式\((1)\)等價於找到最優解\(\mathbf{W}^*\)使\(0 \in \nabla f(\mathbf{W}^*) + \lambda \partial g(\mathbf{W^*})\),這裡\(\partial g(\mathbf{W}^*)\)指非光滑函式\(g(\cdot)\)在\(\mathbf{W}^*\)點的次梯度集合(拼成一個矩陣)。我們有以下等價關係:
因此前向後向迭代如下:
該迭代對\(\eta \in (0, 2/L)\)會收斂到解。前面我們提到\(\nabla f(\mathbf{W})\)是可分的,比如可以寫成\(\nabla f(\mathbf{W}) = (\nabla l_1(\bm{w}_1),..., \nabla l_T(\bm{w}_T))\),且此處的前向運算元\(\mathbf{I}- \eta \nabla f\)也是可分的。不過這裡的後向運算元\((\mathbf{I} + \eta \lambda \partial g)^{-1}\)是不可分的,將導致後面無法並行。因此我們不能直接在前向-後向迭代上應用座標更新法。不過,如果我們轉換前向和後向的順序,我們可以得到下列的後向-前向迭代:
這裡我們使用一個輔助矩陣\(\mathbf{V} \in \mathbb{R}^{d \times T}\)來在更新中替代\(\mathbf{W}\),這是因為前向-後向迭代和後向-前向迭代中的更新變數是不一樣的。因此,由\(\mathbf{V}^*\)得到\(\mathbf{W}^*\)需要一個額外的後向迭代步驟。之後我們就可以在後向-前向迭代的基礎上,按照論文[9]來對任務塊(代指所有和任務有關的變數)進行座標更新了。更新步驟如下所示:
這裡\(\bm{v}_t \in \mathbb{R}^d\)是任務\(t\)中\(\bm{w}_t\)相應的輔助變數。注意想要更新一個任務塊\(\bm{v}_t\)只需要在一個任務塊上的一個完整的後向步驟和前向步驟。後面我們會給出整體的AMTL演算法。該演算法的迭代步遵循的為在論文[9]中討論的KM迭代。KM迭代提供了一個解決不動點迭代問題的通用框架。(此處的不動點迭代即用運算元做優化的思想,將優化問題等價為找到nonexpansive operator的不動點,可參見運算元優化相關書籍[13]和知乎文章[14])。迭代式的推導過程具體可參照論文[9]。
機器學習中一般根據子問題的複雜性來選擇前向-後向迭代和後向-前向的迭代。如果資料集\((\bm{x}_t, y_t)\)很大,此時後向步驟相比前向步驟更容易計算,我們會使用後向-前向迭代來進行座標更新。具體到MTL的應用中,後向迭代步驟由式\((4)\)的近端對映給出,而這通常有解析解(比如跡範數的奇異值的軟閾值)。另一方面,式\((2)\)的梯度計算是典型的耗時瓶頸(尤其是資料集很大時),因此後向-前向迭代為分散式MTL提供了一個更為高效的優化框架。最後,我們注意到後向前向迭代當\(\eta \in (0, 2/L)\)是一個non-expansive operator(因為前向和後向步驟都是non-expansive的)。
整個非同步多工學習框架如下圖所示:
在本篇AMTL論文中,任務節點並不共享記憶體(注意,在文獻[9]中,各任務節點可訪問一個共享記憶體),任務節點間不能通訊,但它們都各自與主節點連線並能與之通訊。任務節點和主節點之間的通訊只有向量\(\bm{v}_t\),相較各任務節點上儲存的本地資料\(D_t\)很小。每個任務節點負責計算前向步驟;而主節點負責計算後向步驟,一旦更新的梯度從任務節點傳來,就進行近端對映(近端對映也能夠在多輪梯度更新後才進行,取決於梯度更新的速度)。因為每個任務節點只需要和該任務節點相關的任務塊,故本篇論文進一步減少了任務節點和主節點之間的通訊代價。
上面這幅圖進一步描述了AMTL中的非同步更新機制,包括中心節點和任務節點分別執行後向和前向步驟的順序。在\(t_1\)時刻,任務節點2從中心節點接收了已完成近端對映的引數\(\text{Prox}_{\eta \lambda g}(\hat{\bm{v}}^k_2)\),之後在任務節點2上的前向(梯度)計算步驟就會馬上啟動。在虛擬碼步驟\((\mathbf{III}.4)\)所示的任務梯度下降更新完成後,任務2的引數\(\mathbb{v}_2\)會被送回中心節點。當中心節點收到引數後,它會開始對整個(即包括所有任務的)引數矩陣進行近端對映。
然而,這個演算法卻會有潛在的不一致性(inconsistency)問題。如圖所示,當\(t_2\)與\(t_3\)之間,即任務節點2在執行計算時,任務節點4已經將其計算好的引數傳送到中心節點並觸發了近端對映。因此,中心節點的引數矩陣在任務節點2計算梯度時,就因響應任務節點4而更新。之後當任務節點2將算好的引數送到中心節點時,近端對映只能在不一致的資料上進行計算(資料由來自任務節點2的引數和之前已更新的引數混合而成)。 同樣,任務節點4在\(t_3\)時刻收到引數並完成計算後,此時中心節點的引數已經被更新(因為任務節點2在\(t_4和t_5\)之間已觸發近端對映),後面也會產生同樣的問題。
為什麼會有這種不一致性呢?這是因為AMTL中的資料讀取是沒有加記憶體鎖的。因此,對於非同步座標更新模式,從中心節點讀引數向量時會有不一致性的問題。這種由於後向迭代步驟產生的不一致性已經被論文考慮在了收斂率分析中,具體細節大家可以參見論文。
最後,這篇論文在異構網路環境下的版本程式碼已開源在Github上(連結: https://github.com/illidanlab/AMTL ),大家可前往學習。
參考文獻
- [1] Evgeniou T, Pontil M. Regularized multi--task learning[C]//Proceedings of the tenth ACM SIGKDD international conference on Knowledge discovery and data mining. 2004: 109-117.
- [2] Zhou J, Chen J, Ye J. Malsar: Multi-task learning via structural regularization[J]. Arizona State University, 2011, 21.
- [3] Zhou J, Chen J, Ye J. Clustered multi-task learning via alternating structure optimization[J]. Advances in neural information processing systems, 2011, 2011: 702.
- [4] Ji S, Ye J. An accelerated gradient method for trace norm minimization[C]//Proceedings of the 26th annual international conference on machine learning. 2009: 457-464.
- [5] A. Beck and M. Teboulle, “A fast iterative shrinkage-thresholding algorithm for linear inverse problems,” SIAM Journal on Imaging Sciences, vol. 2, no. 1, pp. 183–202, 2009.
- [6] S. J. Wright, R. D. Nowak, and M. A. Figueiredo, “Sparse reconstruction by separable approximation,” IEEE Transactions on Signal Processing, vol. 57, no. 7, pp. 2479–2493, 2009.
- [7] J. D. Lee, Y. Sun, and M. A. Saunders, “Proximal newton-type methods for minimizing composite functions,” SIAM Journal on Optimization, vol. 24, no. 3, pp. 1420–1443, 2014.
- [8] Baytas I M, Yan M, Jain A K, et al. Asynchronous multi-task learning[C]//2016 IEEE 16th International Conference on Data Mining (ICDM). IEEE, 2016: 11-20.
- [9] Z. Peng, Y. Xu, M. Yan, and W. Yin, “ARock: An algorithmic framework for asynchronous parallel coordinate updates,” SIAM Journal on Scientific Computing, vol. 38, no. 5, pp. A2851–A2879, 2016.
- [10] B. Edmunds, Z. Peng, and W. Yin, “Tmac: A toolbox of modern asyncparallel, coordinate, splitting, and stochastic methods,” UCLA CAM Report 16-38, 2016.
- [11] P. L. Combettes and V. R. Wajs, “Signal recovery by proximal forwardbackward splitting,” Multiscale Modeling & Simulation, vol. 4, no. 4, pp. 1168–1200, 2005.
- [12] Z. Peng, T. Wu, Y. Xu, M. Yan, and W. Yin, “Coordinate-friendly structures, algorithms and applications,” Annals of Mathematical Sciences and Applications, vol. 1, pp. 57–119, 2016.
- [13] Bauschke H H, Combettes P L. Convex analysis and monotone operator theory in Hilbert spaces[M]. New York: Springer, 2011.
- [14] https://zhuanlan.zhihu.com/p/150605754
- [15] 楊強等. 遷移學習[M].機械工業出版社, 2020.