分散式多工學習論文閱讀(四):去偏lasso實現高效通訊

orion發表於2021-11-10

1.難點-如何實現高效的通訊

我們考慮下列的多工優化問題:

\[ \underset{\textbf{W}}{\min} \sum_{t=1}^{T} [\frac{1}{m_t}\sum_{i=1}^{m_t}L(y_{ti}, \langle \bm{w}_t, \bm{x}_{ti} \rangle)]+\lambda \text{pen}(\textbf{W}) \tag{1} \]

這裡\(\text{pen}(\mathbf{W})\)是一個用於增強group sparse的正則項(參見聯合特徵學習(joint feature learning),常為\(l_1/l_2\)\(l_1/l_{\infin}\)範數的組合,用於只保留對所有任務有用的特徵)。比如。在group lasso penalty[1][2] 中使用 \(\text{pen}(\mathbf{W}) =\sum_{t=1}^T||\bm{w}_t||_2 = \sum_{t=1}^T (\sum_{j=1}^d{w}_{jt}^2 )^{1/2}\)(這裡\(d\)為特徵維度,\(T\)為任務數,\(\bm{w}_t\)\(\mathbf{W}\)的第\(t\)列); \(\text{iCAP}\)使用\(\text{pen}(\mathbf{W}) = ||\mathbf{W}||_{\infin, 1} = \sum_{j=1}^d||\bm{w}^j||_{\infin}= \sum_{j=1}^d\underset{1\leqslant t \leqslant T}{\text{max}}|w_{jt}|\) [3][4](這裡\(\bm{w}^j\)是指\(\mathbf{W}\)的第\(j\)行。注意區分這個和矩陣的\(\infin\)範數,求和與求最大的順序是不一樣的!這裡相當於求向量的無窮範數之和),等等。
在分散式的環境中,我們可以按照文章《分散式多工學習論文閱讀(二)同步和非同步優化演算法》(連結:https://www.cnblogs.com/orion-orion/p/15487700.html)提到的基於近端梯度的同步/非同步優化演算法來優化問題\((1)\),但是正如我們在該篇部落格中所說的,這種方法需要多輪的通訊,時間開銷較大。這樣,如何實現機器間的有效通訊是我們必須要想辦法解決該問題。

現在的熱點解決方案是採用去中心化(decentralize)的思想,即使任務節點繞過主節點,直接利用相鄰任務節點的資訊,這樣可以大大降低通訊量[5][6][7]。這種方法我們未來會著重介紹,此時按下不表。

當然,讀者可能會思考,我們可以不可以直接每個任務各自優化各的\(l_1\)正則目標函式,即每個任務直接採用近端梯度法求解下列的local lasso問題:

\[ \hat{\textbf{w}}_t = \underset{\textbf{w}_t}{\text{argmin}}\frac{1}{m_t}\sum_{i=1}^{m_t}L(y_{ti}, \langle \textbf{w}_t, \textbf{x}_{ti} \rangle)+\lambda_t ||\textbf{w}_t||_{1} \]

很遺憾,這種方法雖然做到了不同任務優化的解耦,但本質上變成了單任務學習,沒有充分利用好多工之間的聯絡(任務之間的練習須依靠group sparse正則項\(\text{pen}(\textbf{W})\)來實現)。那麼,有沒有即能夠減少通訊次數,又能夠儲存group regularization的基本作用呢?(暫時不考慮任務節點相互通訊的去中心化的方法)

2. 基於去偏lasso模型的分散式演算法

論文《distributed multitask learning》[8]提出的演算法介於傳統的分散式近端梯度法和local lasso之間,其計算只需要一輪通訊,但仍然保證了使用group regularization所帶來的統計學效益。 該論文提出的演算法描述如下:

去偏lasso演算法

這裡我們特別說明一下第4行的操作,\(m_t^{-1}\mathbf{X}_t^T(\bm{y}_t - \mathbf{X}_t\hat{\bm{w}}_t)\)
是損失函式的次梯度;矩陣\(\textbf{M}_t\in \mathbb{R}^{d \times d}\)是Hessian矩陣的近似逆,\(m_t\)是任務\(t\)對應的樣本個數(事實上原論文假定\(m_1=m_2=...=m_T\));節點\(t\)對應的訓練資料是\((\mathbf{X}_t, \bm{y}_t)\)


這種求去偏lasso估計量的方法由最近關於高維統計[9][10][11]的文章提出,這些論文都企圖去除引入演算法第3行所示的\(l_1\)正則項所導致的偏差(bias),具體方法是運用\(l_1\)正則損失函式關於\(\bm{w}_t\)的次梯度來構造得到引數成分的無偏估計量\(\hat{\bm{w}}^u_t\)。下面我們會參照去偏估計器的取樣分佈,但我們的最終目標不同。[9][10][11]這三篇論文構造矩陣\(\mathbf{M}\)的方法不同,本篇論文主要參照論文[11]的方法,複合假設。每個機器使用矩陣\(\mathbf{M}_t=(\hat{\bm{m}}_{tj})_{j=1}^d\),它的行是:

\[\begin{aligned} & \hat{\bm{m}}_{tj} = \underset{\bm{m}_j \in \mathbb{R}^p}{\text{argmin}} \quad \bm{m}_j^T\hat{\mathbf{\Sigma}_t}\bm{m}_j \\ & \text{s.t.} \quad ||\hat{\mathbf{\Sigma}}_t\bm{m}_j - \bm{e}_j ||_{\infin} \leqslant u. \end{aligned} \]

這裡\(\bm{e}_j\)是第\(j\)個元素為1其他元素為0的(標準基)向量,\(\hat{\Sigma}_t={m_t}^{-1} \mathbf{X}_t^T\mathbf{X}_t\)


當每個任務節點得到去偏估計量\(\hat{\bm{w}}_t^u\)後,就會將其送往主節點。在主節點那邊,待從所有任務節點收到\(\{\hat{\bm{w}}_t^u\}_{t=1}^T\)後,就來到了第\(12\)行的操作。第\(12\)行的操作在master節點的操作充分利用了不同任務引數之間的共享稀疏性,即主節點將接收到的估計量拼接成矩陣\(\hat{\textbf{W}}^u=(\hat{\bm{w}}_1^u, \hat{\bm{w}}_2^u,..., \hat{\bm{w}}_T^u)\),然後再執行hard thresholding以過得\(\mathbf{S}\)的估計量:

\[\hat{S}(\Lambda)=\{j \text{ }| \text{ } ||\hat{\textbf{W}}_j^u||_2 > \Lambda \} \]

參考文獻

  • [1] Yuan M, Lin Y. Model selection and estimation in regression with grouped variables[J]. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 2006, 68(1): 49-67.
  • [2] Friedman J, Hastie T, Tibshirani R. A note on the group lasso and a sparse group lasso[J]. arXiv preprint arXiv:1001.0736, 2010.
  • [3] Zhao P, Rocha G, Yu B. The composite absolute penalties family for grouped and hierarchical variable selection[J]. The Annals of Statistics, 2009, 37(6A): 3468-3497.
  • [4] Liu H, Palatucci M, Zhang J. Blockwise coordinate descent procedures for the multi-task lasso, with applications to neural semantic basis discovery[C]//Proceedings of the 26th Annual International Conference on Machine Learning. 2009: 649-656.
  • [5] Zhang C, Zhao P, Hao S, et al. Distributed multi-task classification: A decentralized online learning approach[J]. Machine Learning, 2018, 107(4): 727-747.
  • [6] Yang P, Li P. Distributed primal-dual optimization for online multi-task learning[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2020, 34(04): 6631-6638.
  • [7] Li J, Abbas W, Koutsoukos X. Byzantine Resilient Distributed Multi-Task Learning[J]. arXiv preprint arXiv:2010.13032, 2020.
  • [8] Wang J, Kolar M, Srerbo N. Distributed multi-task learning[C]//Artificial intelligence and statistics. PMLR, 2016: 751-760.
  • [9] Zhang C H, Zhang S S. Confidence intervals for low dimensional parameters in high dimensional linear models[J]. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 2014, 76(1): 217-242.
  • [10] Van de Geer S, Bühlmann P, Ritov Y, et al. On asymptotically optimal confidence regions and tests for high-dimensional models[J]. The Annals of Statistics, 2014, 42(3): 1166-1202.
  • [11] Javanmard A, Montanari A. Confidence intervals and hypothesis testing for high-dimensional regression[J]. The Journal of Machine Learning Research, 2014, 15(1): 2869-2909.
  • [12] 楊強等. 遷移學習[M].機械工業出版社, 2020.

相關文章