Memory-Efficient Adaptive Optimization

馒头and花卷發表於2024-09-10

目錄
  • 符號說明
  • SM3
    • 區間的劃分
  • 程式碼

Anil R., Gupta V., Koren T., Singer Y. Memory-efficient adaptive optimization. NeurIPS, 2019.

本文提出了一種 memory-efficient 的最佳化器: SM3.

符號說明

  • \(t = 1,\ldots, T\), optimization rounds;
  • \(w_t \in \mathbb{R}^d\), paramter vector;
  • \(\ell_t\), convex loss function;
  • \(g_t = \nabla \ell_t (w_t)\), 梯度;
  • \(w^* \in \mathbb{R}^d\), optimal paramter.

SM3

  • 自適應的最佳化器 (AdaGrad) 形式如下:

    \[\gamma_t (i) = \sum_{s=1}^t g_s^2 (i), \quad \forall i \in [d], \]

    然後每一步按照如下的規則更新:

    \[w_{t+1}(i) = w_t(i) - \eta \frac{g_t(i)}{\sqrt{\gamma_t(i)}}, \quad \forall i \in [d]. \]

  • \(\gamma_t\) 的存在意味著我們需要 \(\mathcal{O}(d)\) 的額外儲存. 作者提出的 SM3 將這個額外的儲存消耗降低為 \(\mathcal{O}(k)\).

  • 首先, 透過某種方式確定 \(k\) 個非空子集:

    \[\{S_r\}_{r=1}^k, \quad \bigcup_{r=1}^k S_r = [d]. \]

  • 然後按照如下的方式更新:

  • 可以注意到, \(S_r\) 的存在相當於指定 \(i \in S_r\) 的引數共享一個自適應的學習率. 特別的, 由於 \(S_r\) 不一定是互斥的, 所以每一次我們從中挑選一個最好的. 作者證明了這個方法的收斂性.

  • 進一步的, 作者提出了一個更加的穩定的版本, 具有更好一點的 bound:

區間的劃分

  • 現在的問題是, 如何確定 \(\{S_r\}_{r=1}^k\), 作者給出的建議是, 對一個 \(m \times n\) 的權重, 可以分別按行共享和按列共享, 從而需要 \(m + n\) 個快取.

程式碼

[PyTorch-SM3]

相關文章