[原始碼解析] 深度學習流水線並行GPipe (2) ----- 梯度累積

羅西的思考發表於2021-08-26

[原始碼解析] 深度學習流水線並行GPipe (2) ----- 梯度累積

0x00 摘要

梯度累積是一種增大訓練時 batch size的技術,在本地使用 micro-batch 多次進行正向和反向傳播積累梯度後,再進行梯度規約和優化器更新,這是用來均攤通訊成本的一種常用策略。本文通過幾個框架/庫的實現對比,讓大家對這個技術有進一步的瞭解。

本系列其他文章如下:

[原始碼解析] 深度學習流水線並行Gpipe(1)---流水線基本實現

0x01 概述

1.1 前文回顧

前文提到,目前分散式模型訓練有幾個必要並行技術:

  • 流水並行,尤其是如何自動設定流水;
  • 梯度累加(Gradient Accumulation);
  • 後向重計算;
  • 1F1B 策略(我們將採用PipeDream分析);

在前文中,我們介紹了Gpipe如何實施流水線並行技術。本文我們介紹梯度累加(Gradient Accumulation)。

0x02 基本概念

梯度累積是一種用來均攤通訊成本的一種常用策略。它在本地使用 micro-batch 多次進行正向和反向傳播積累梯度後,再進行梯度規約和優化器更新,相當於擴大了N倍的batch size。

2.1 背景知識

深度學習模型由許多相互連線的層組成,樣本在這些層中進行傳播,具體傳播包含兩個過程:前向(forward)過程與反向(backword)過程。

  • 前向過程是從輸入計算得到輸出。樣本在每一步都通過前向傳播進行傳播,在通過所有層傳播後,網路為樣本生成預測,然後計算每個樣本的損失值,損失值意味著 “對於這個樣本,本網路錯了多少?”。
  • 然後就是反向過程。神經網路在此過程中計算這些損失值相對於模型引數的梯度。可以認為著就是一個梯度累積的過程。
  • 最後,這些梯度用於計算各個模型引數的更新。

訓練中,每個樣本的大小由超引數batch size指定,此引數的大小會對最終的模型效果產生很大的影響。一定條件下,batch size設定的越大,模型就會越穩定。

2.2 產生原因

累加梯度顧名思義就是累加後的梯度值。為什麼要累加呢?因為執行記憶體不夠用。

在訓練模型時,如果一次性將所有訓練資料輸入到模型,經常會造成記憶體不足,這時候就需要把一個大 Batch 拆分成若干小批次資料(專業術語為mini-batch)。分成小批次後,帶來一個問題,那就是本來應該是所有資料全部送入後計算梯度再更新引數,現在成了每個小批次都要計算梯度更新引數,為了不這麼頻繁計算梯度,於是就引入了累加梯度。也就是說:

  • 將整個dataset分成多個batch;
  • 分別將每個batch分成多個小批次,將每個小批次餵給神經網路;
  • 每個小批次雖然計算梯度,但是在每次反向傳播後,先不進行優化器的迭代更新。
  • 經過若干個小批次後(即一個batch中的所有小批次),用每個小批次計算的梯度的累積和去進行優化器迭代更新引數、梯度清零的操作。

這樣就跟把全部資料一次性送入模型進行訓練效果一樣了。

2.3 本質

梯度累加本質上就是累加 accumulation_stepsbatch_size/accumulation_steps 的梯度, 再根據累加的梯度來更新網路引數,以達到真實梯度類似batch_size 的效果。在使用時,需要注意適當的擴大學習率。

也就是說:

  • 首先將整個dataset分成多個batch,每個 batch size = 32,且假定 accumulation steps = 8
  • 因為 batch size = 32 ,太大了,單機顯示卡無法跑,於是我們在前向傳播的時候以 batch_size = 32 / 8 = 4 來計算梯度;
  • 這樣就再分別將每個batch分成多個batch size 為 4 的小批次,將每個小批次逐一餵給神經網路;
  • 每個小批次雖然計算梯度,但是在每次反向傳播(在反向傳播的時候,會將mean_loss也除以8)後,先不進行優化器的迭代更新。
  • 經過 accumulation steps 個小批次後(即一個batch中的所有小批次),用每個小批次計算梯度的累積和去進行優化器迭代更新引數。
  • 最後進行梯度清零的操作。
  • 處理下一個batch。

這樣就跟把 32 batch size 一次性送入模型進行訓練效果一樣了。

具體如下,時間軸是由左自右:

                                     +-------------------+
                                     |    GLOBAL BATCH   +--------------------------+
                                     +-------------------+                          |
                                                                                    |
                                                                                    |
 +<---------------------------------------------------------------------------------+
 |
 |
 |    +--------------+     +--------------+     +--------------+     +--------------+
 +--> | MINI BATCH 0 +---->+ MINI BATCH 1 +---->+ MINI BATCH 2 +---->+ MINI BATCH 3 |
      +-----+--------+     +-------+------+     +------+-------+     +-------+------+
            |                      |                   |                     |
            |                      |                   |                     |
            |                      |                   |                     |
            v                      v                   v                     v
       +----+-----+          +-----+-----+       +-----+-----+          +----+-----+
       |  grad 0  |          |  grad 1   |       |  grad 2   |          |  grad 3  |
       +----+-----+          +-----+-----+       +-----+-----+          +----+-----+
            |                      |                   |                     |
            |                      |                   |                     |
            |                      |                   |                     |
            v                      v                   v                     v
     +------+----------------------+-------------------+---------------------+------+
     |                                                                              |
     |                              GLOBAL BATCHGRADIENTS                           |
     |                                                                              |
     +------------------------------------------------------------------------------+


+------------------------------------------------------------------------------------>
                                                                        Time

2.4 VS 資料並行

micro-batch 跟資料並行有高度的相似性:

  • 資料並行是空間上的,資料被拆分成多個 tensor,同時餵給多個裝置平行計算,然後將梯度累加在一起更新。
  • micro-batch 是時間上的資料並行,資料被拆分成多個 tensor,這些 tensor 按照時序依次進入同一個裝置序列計算,然後將梯度累加在一起更新。

當總的 batch size 一致,且資料並行的並行度和 micro-batch 的累加次數相等時,資料並行和 Gradient Accumulation 在數學上完全等價。

Gradient Accumulation 通過多個 micro-batch的梯度累加使得下一個 micro-batch 的前向計算不需要依賴上一個 micro-batch 的反向計算,因此可以暢通無阻的進行下去(當然在一個大 batch 的最後一次 micro-batch 還是會觸發這個依賴)。

2.5 解決問題

Gradient Accumulation 解決了很多問題:

  • 在單卡下,Gradient Accumulation 可以將一個大的 batch size 拆分成等價的多個小 micro-batch ,從而達到節省視訊記憶體的目的。
  • 在資料並行下,Gradient Accumulation 解決了反向梯度同步開銷佔比過大的問題(隨著機器數和裝置數的增加,梯度的 AllReduce 同步開銷也加大),因為梯度同步變成了一個稀疏操作,因此可以提升資料並行的加速比。
  • 在流水線並行下, Gradient Accumulation 使得不同 stage 之間可以並行執行不同的 micro-batch,通過多個 micro-batch的梯度累加使得下一個 micro-batch 的前向計算不需要依賴上一個 micro-batch 的反向計算,因此從而讓各個階段的計算不阻塞,可以暢通無阻的進行下去(當然在一個大 batch 的最後一次 micro-batch 還是會觸發這個依賴), 達到流水線的目的。

0x03 PyTorch 梯度累積

3.1 自動累積

PyTorch預設會對梯度進行累加。即,PyTorch會在每一次backward()後進行梯度計算,但是梯度不會自動歸零,如果不進行手動歸零的話,梯度會不斷累加.

至於為什麼PyTorch有這樣的特點,https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/9 這裡給出了一個解釋。我們結合其他的解釋大致得出如下:

  • 從PyTorch的設計原理上來說,在每次進行前向計算得到預測值時,會產生一個用於梯度回傳的計算圖,這張圖儲存了進行反向傳播需要的中間結果,當呼叫了.backward()後,會從記憶體中將這張圖進行釋放。

  • 利用梯度累加,可以在最多儲存一張計算圖的情況下進行多工的訓練。在多工中,對前面共享的張量進行了多次計算操作後,呼叫不同任務的backward(),那些張量的梯度會自動累加。

  • 另外一個理由就是在記憶體大小不夠的情況下疊加多個batch的grad作為一個大batch進行迭代,因為二者得到的梯度是等價的。

  • 由於PyTorch的動態圖和autograd機制,導致並沒有一個確切的點知道何時停止前向操作,因為你不知道什麼時候一個計算會結束以及什麼時候又會有一個新的開始。所以自動設定梯度為 0 比較棘手。

3.2 程式碼示例

下面給出一個傳統程式碼示例:

for i,(images,target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs,target)
  
    # 2. backward
    optimizer.zero_grad()   # reset gradient
    loss.backward()
    optimizer.step()

然後給出一個梯度累積示例:

  • 獲取loss: 輸入影像和標籤,通過計算得到預測值,計算損失函式;
  • loss.backward()反向傳播,計算當前梯度;
  • 多次迴圈步驟 1-2, 不清空梯度,使梯度累加在已有梯度上;
  • 梯度累加一定次數後,先optimizer.step()根據累積的梯度更新網路引數,然後optimizer.zero_grad()清空過往梯度,為下一波梯度累加做準備;
for i, (images, target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images) # 前向傳播
    loss = criterion(outputs, target) # 計算損失

    # 2. backward
    loss.backward() # 反向傳播,計算當前梯度
    
     # 3. update parameters of net
    if ((i+1)%accumulation)==0:
        # optimizer the net
        optimizer.step() # 更新網路引數
        optimizer.zero_grad() # reset grdient # 清空過往梯度

3.3 DistributedDataParallel 的梯度累積

DistributedDataParallel(DDP)在module級別實現資料並行性。其使用torch.distributed包communication collectives來同步梯度,引數和緩衝區。並行性在單個程式內部和跨程式均有用。

在這種情況下,雖然gradient accumulation 也一樣可以應用,但是為了提高效率,需要做相應的調整。

3.3.1 單卡模型梯度累計

我們首先回憶單卡模型,即普通情況下如何進行梯度累加。

# 單卡模式,即普通情況下的梯度累加
for data in enumerate(train_loader # 每次梯度累加迴圈
    optimizer.zero_grad()
    for _ in range(K):
        prediction = model(data / K)
        loss = loss_fn(prediction, label) / K
        loss.backward()  # 積累梯度,不應用梯度改變,執行K次
    optimizer.step()  # 應用梯度更新,更新網路引數,執行一次

在 loss.backward() 語句處,DDP會進行梯度規約 all_reduce。

因為每次梯度累加迴圈之中有K個步驟,所以有K次 all_reduce。但實際上,每次梯度累加迴圈中,optimizer.step()只有一次,這意味著我們這K次 loss.backward() 之中,其實只進行一次 all_reduce 即可,前面 K - 1 次 all_reduce 是沒有用的

3.3.2 DDP如何加速

於是我們就思考,是否可以在 loss.backward() 之中有一個開關,使得我們在前面K-1次 loss.backward() 之中只做反向傳播,不做梯度同步(累積)

DDP 已經想到了這個問題,它提供了一個暫時取消梯度同步的context函式 no_sync()。在no_sync()context之下,DDP不會進行梯度同步。但是在no_sync()上下文結束之後的第一次 forward-backward 會進行同步。

最終程式碼如下:

model = DDP(model)

for data in enumerate(train_loader # 每次梯度累加迴圈
    optimizer.zero_grad()
    
    for _ in range(K-1):# 前K-1個step 不進行梯度同步(累積梯度)。
        with model.no_sync(): # 這裡實施“不操作”
            prediction = model(data / K)
            loss = loss_fn(prediction, label) / K
            loss.backward()  # 積累梯度,不應用梯度改變
    
    prediction = model(data / K)
    loss = loss_fn(prediction, label) / K 
    loss.backward()  # 第K個step 進行梯度同步(累積梯度)
    optimizer.step() # 應用梯度更新,更新網路引數  

3.3.3 no_sync實現

no_sync 的程式碼如下:

    @contextmanager
    def no_sync(self):
        r"""
        A context manager to disable gradient synchronizations across DDP
        processes. Within this context, gradients will be accumulated on module
        variables, which will later be synchronized in the first
        forward-backward pass exiting the context.

        Example::

            >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
            >>> with ddp.no_sync():
            >>>   for input in inputs:
            >>>     ddp(input).backward()  # no synchronization, accumulate grads
            >>> ddp(another_input).backward()  # synchronize grads
        """
        old_require_backward_grad_sync = self.require_backward_grad_sync
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            self.require_backward_grad_sync = old_require_backward_grad_sync

具體如何使用?我們在 DistributedDataParallel 的 forward 方法中可以看到,只有在 require_backward_grad_sync 為 True時候,才會呼叫reducer.prepare_for_forward() 和 reducer.prepare_for_backward,才會把require_forward_param_sync 設定為 True。

   def forward(self, *inputs, **kwargs):
    
        with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
            
            self.reducer.save_thread_local_state()
            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                # True時候才會進入
                self.logger.set_runtime_stats_and_log()
                self.num_iterations += 1
                self.reducer.prepare_for_forward()
            
            # 省略部分程式碼

            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                # True時候才會進入
                self.require_forward_param_sync = True
                if self.find_unused_parameters and not self.static_graph:
                    # Do not need to populate this for static graph.
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
            else:
                self.require_forward_param_sync = False

			# 省略部分程式碼

再看看 Reducer的兩個方法。

prepare_for_forward 只是做統計工作,可以忽略。

void Reducer::prepare_for_forward() {
  std::lock_guard<std::mutex> lock(mutex_);
  num_iterations_++;
  if (should_collect_runtime_stats()) {
    record_forward_compute_start_time();
  }
}

prepare_for_backward 會做重置和預備工作,與梯度累積相關的是 expect_autograd_hooks_ = true

void Reducer::prepare_for_backward(
    const std::vector<torch::autograd::Variable>& outputs) {
  std::lock_guard<std::mutex> lock(mutex_);

  // Reset accounting.
  expect_autograd_hooks_ = true; // 這裡是關鍵
  reset_bucket_counting();

  // Reset unused parameter accounting.
  has_marked_unused_parameters_ = false;
  // Reset per iteration marked ready parameters.
  perIterationReadyParams_.clear();

  // If static graph is not set, search graph to detect unused parameters.
  // When static graph is set, unused_parameters_ will be detected and will
  // not change after 1st iteration.
  // If static_graph_ = false and find_unused_parameters_ is false,
  // we assume that autograd hooks for ALL variables will be called,
  // and we don't have to search the autograd graph for presence of these hooks.
  if (dynamic_graph_find_unused()) {
    unused_parameters_.clear();
    search_unused_parameters(outputs);
  }
}

expect_autograd_hooks_ = true 如何使用?在 Reducer::autograd_hook 之中有,如果不需要進行all-reduce操作,則直接返回。

void Reducer::autograd_hook(VariableIndex index) {
    
  std::lock_guard<std::mutex> lock(this->mutex_);

  // Carry over thread local state from main thread. This allows for
  // thread-local flags such as profiler enabled to be configure correctly.
  at::ThreadLocalStateGuard g(thread_local_state_);

  // Ignore if we don't expect to be called.
  // This may be the case if the user wants to accumulate gradients
  // for number of iterations before reducing them.
  if (!expect_autograd_hooks_) { // 如果不需要進行all-reduce操作,則直接返回。
    return;
  }

  // 省略後續程式碼

有點繞,我們梳理一下:

一個 step 有兩個操作:forward 和 backward。

  • forward 操作時候 :require_backward_grad_sync = True 意味著 forward 時候
    • 設定 require_forward_param_sync = True。
    • 會呼叫reducer.prepare_for_forward() 和 reducer.prepare_for_backward
    • reducer.prepare_for_backward 意味著會設定 expect_autograd_hooks_ = true,expect_autograd_hooks_是關鍵。
  • backward 操作時候
    • expect_autograd_hooks_ = true 意味著 backward 時候進行 進行all-reduce操作。
    • 否則直接返回,不做 all-reduce操作。

即如下圖,

  • 上半部分是 forward 的邏輯,就是 forward()函式,
  • 下半部分是 backward 邏輯,就是 Reducer::autograd_hook() 函式。
  • expect_autograd_hooks_ 是forward 和 backward 之間串聯的關鍵之處。
forward
+---------------------------------------------------------------------------------+
| forward()                                                                       |
|                                                                                 |
|                require_backward_grad_sync == True?? +---------+                 |
|                             +                                 |                 |
|                             |                                 |                 |
|                             | Yes                             |                 |
|                             |                                 | No              |
|                             v                                 |                 |
|                 reducer.prepare_for_forward                   |                 |
|                             +                                 |                 |
|                             |                                 |                 |
|                             |                                 |                 |
|                             v                                 |                 |
|                 reducer.prepare_for_backward                  |                 |
|                             +                                 |                 |
|                             |                                 |                 |
|                             |                                 |                 |
|                             v                                 v                 |
|                 expect_autograd_hooks_ = true    expect_autograd_hooks_ = false |
|                             +                                 +                 |
|                             |                                 |                 |
+---------------------------------------------------------------------------------+
                              |                                 |
+--------------------------------------------------------------------------------+
 backward                     |                                 |
                              |                                 |
 +--------------------------------------------------------------------------------+
 |                            |                                 |                 |
 | Reducer::autograd_hook()   |                                 |                 |
 |                            |                                 |                 |
 |                            |    +----------------------------+                 |
 |                            |    |                                              |
 |                            |    |                                              |
 |                            v    v                                              |
 |                 expect_autograd_hooks_ == True?? +------------+                |
 |                            +                                  |                |
 |                            | Yes                              |                |
 |                            |                                  |  No            |
 |                            v                                  v                |
 |                      Do All-Reduce                          Return             |
 |                                                                                |
 |                                                                                |
 +--------------------------------------------------------------------------------+

no_sync 操作就 意味著設定 require_backward_grad_sync = False,最終設定了 expect_autograd_hooks_ = False。這樣,backward 時候就不會進行 All-Reduce 操作

0x04 Tensorflow實現

在 pytorch 中,梯度只要不清零預設是累加的,於是很容易實現上述問題。但在Tensorflow中,卻不那麼容易。

我們從 stackoverflow 得到示例程式碼如下:

## 定義優化器
opt = tf.train.AdamOptimizer()

## 得到你模型中的所有可訓練變數
tvs = tf.trainable_variables()

# 用於記錄每個變數的累積梯度,初始化為0s
accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in tvs]
# 定義清零操作
zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_vars]

## 使用優化器的compute_gradients來計算梯度
gvs = opt.compute_gradients(rmse, tvs)

## 將當前梯度累加在之前定義的變數上
accum_ops = [accum_vars[i].assign_add(gv[0]) for i, gv in enumerate(gvs)]

## 定義訓練step,梯度下降,更新引數
train_step = opt.apply_gradients([(accum_vars[i], gv[1]) for i, gv in enumerate(gvs)])

## 訓練迴圈
while ...:
    # 使用 zero_ops 初始化
    sess.run(zero_ops)
    # 使用accum_ops對accum_vars進行'n_minibatches'次梯度累積
    for i in xrange(n_minibatches):
        sess.run(accum_ops, feed_dict=dict(X: Xs[i], y: ys[i]))
    # 使用累積的梯度進行引數更新
    sess.run(train_step)

0x05 Gpipe實現

在 GPipe 的流水並行示例中,每個“時間點” 可以在多個階段(stage)上同時做不同的micro-batch,圖中每個方塊中的標號表示了第幾個 micro-batch;同一個 micro-batch 還是序列的經過所有的 stage,在這種情況下,每個裝置的空閒時間只有 25% 左右。

具體程式碼如下:

5.1 優化器

在 lingvo/core/optimizer.py 中 GradientAggregationOptimizer 中有具體實現,關鍵程式碼為apply_gradients,邏輯為:

  • 如果 _num_micro_batches 為 1,則說明不用梯度累積,直接 apply_gradients;
  • 遍歷 grads_and_vars 列表,累積梯度;
  • accum_step 為梯度累積條件:
    • 如果達到了小批次迭代數目,則呼叫 _ApplyAndReset:
      • 呼叫 apply_gradients 應用梯度;
      • 呼叫 zero_op 清零梯度;
    • 否則就呼叫_Accum,實際上是 no_op不做操作;

具體程式碼如下:

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    if self._num_micro_batches == 1:
      return self._opt.apply_gradients(grads_and_vars, global_step)
    global_step = global_step or py_utils.GetOrCreateGlobalStepVar()
    with tf.init_scope():
      self._create_slots([v for (_, v) in grads_and_vars])

    accums = []
    variables = []

    # 遍歷,累積梯度
    for g, v in grads_and_vars:
      accum = self.get_slot(v, 'grad_accum')
      variables.append(v)
      # pytype: disable=attribute-error
      if isinstance(g, tf.IndexedSlices):
        scaled_grad = tf.IndexedSlices(
            g.values / self._num_micro_batches,
            g.indices,
            dense_shape=g.dense_shape)
      else:
        scaled_grad = g / self._num_micro_batches
      accum_tensor = accum.read_value()
      accums.append(accum.assign(accum_tensor + scaled_grad))
      # pytype: enable=attribute-error

    # 應用梯度,清零梯度
    def _ApplyAndReset():
      normalized_accums = accums
      if self._apply_crs_to_grad:
        normalized_accums = [
            tf.tpu.cross_replica_sum(accum.read_value()) for accum in accums
        ]
      apply_op = self._opt.apply_gradients(
          list(zip(normalized_accums, variables)))
      with tf.control_dependencies([apply_op]):
        zero_op = [tf.assign(accum, tf.zeros_like(accum)) for accum in accums]
      return tf.group(zero_op, tf.assign_add(global_step, 1))

    # 累積函式,其實是不做操作
    def _Accum():
      return tf.no_op()

    # 梯度累積條件,如果達到了小批次迭代數目,則應用梯度,清零梯度,否則就不做操作
    accum_step = tf.cond( 
        tf.equal(
            tf.math.floormod(self._counter + 1, self._num_micro_batches), 0),
        _ApplyAndReset,  # Apply the accumulated gradients and reset.
        _Accum)  # Accumulate gradients.

    with tf.control_dependencies([tf.group(accums)]):
      return tf.group(accum_step, tf.assign_add(self._counter, 1))

5.2 包裝器

ShardedAdam 是給 GradientAggregationOptimizer 和 ShardedAdamOptimizer 做了包裝,使用者可以直接使用。

class ShardedAdam(optimizer.Adam):
  """Adam optimizer wrapper that shards the slot variables."""

  @classmethod
  def Params(cls):
    params = super().Params()
    params.Define('num_micro_batches', 1, 'Number of accumulated batches.')
    return params

  def GetOptimizer(self, lr):
    p = self.params
    opt = ShardedAdamOptimizer(
        learning_rate=lr,
        beta1=p.beta1,
        beta2=p.beta2,
        epsilon=p.epsilon,
        name=p.name)
    if p.num_micro_batches > 1:
      tf.logging.info('Applying gradient aggregation.')
      
      opt = optimizer.GradientAggregationOptimizer( # 應用梯度累積
          opt, p.num_micro_batches, apply_crs_to_grad=True)
      self._cached_opt = opt
    return opt

5.3 應用

DenseLm12kWide41BAdam16x16 中有如何使用 ShardedAdam。

@model_registry.RegisterSingleTaskModel
class DenseLm12kWide41BAdam16x16(DenseLm128B16x16):
  """41B params LM model with 2D split and ADAM optimizer on v3-512."""

  # Each layer has 1.6875B parameters.
  SEQUENCE_LENGTH = 2048
  NUM_DEVICES_PER_SPLIT = 512
  BATCH_DIM_PER_DEVICE = 0.5  # Total batch size 256
  DEVICE_MESH_SHAPE = [16, 32]
  DEVICE_MESH = gshard_utils.GetNonPod2dMesh(DEVICE_MESH_SHAPE, [16, 16, 2])
  NUM_TRANSFORMER_LAYERS = 24
  HIDDEN_DIM = 48 * 1024
  MODEL_DIM = 12 * 1024
  NUM_HEADS = 96
  ATTENTION_KEY_VALUE_DIM = 128
  GATED_GELU = False
  POSITIONAL_EMBEDDING = True
  NUM_MICRO_BATCHES = 1

  def Task(self):
    p = super().Task()
    
    # 使用ShardedAdam
    p.train.optimizer = ShardedAdam.Params().Set(
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-6,
        num_micro_batches=self.NUM_MICRO_BATCHES)
    return p

0xFF 參考

[原創][深度][PyTorch] DDP系列第三篇:實戰與技巧

相關文章