[原始碼解析] PyTorch 分散式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作

羅西的思考 發表於 2021-11-25
PyTorch

[原始碼解析] PyTorch 分散式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作

0x00 摘要

因為前文已經圍繞Reducer相關的各種成員變數做了相關分析,所以本文開始做動態邏輯分析,目的是:把前面幾篇文章串聯起來,為後面分析前向傳播和反向傳播設定基礎。

本系列其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀

[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎

[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

[原始碼解析] PyTorch 分散式(1)------歷史和概述

[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)

[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)

[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念

[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用

[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store

[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組

[原始碼解析] PyTorch 分散式(8) -------- DistributedDataParallel之論文篇

[原始碼解析] PyTorch 分散式(9) ----- DistributedDataParallel 之初始化

[原始碼解析] PyTorch 分散式(10)------DistributedDataParallel 之 Reducer靜態架構

0x01 引論

為了更好的分析,我們還是需要看看如何呼叫。

1.1 呼叫

Reducer 的建立程式碼如下,是在_ddp_init_helper 之中。

        # Note: reverse list of buckets because we want to approximate the
        # order in which their gradients are produced, and assume they
        # are used in the forward pass in the order they are defined.
        self.reducer = dist.Reducer(
            parameters, # parameters[0]是張量列表
            list(reversed(bucket_indices)), # 桶資訊
            self.process_group,
            expect_sparse_gradient,
            self.bucket_bytes_cap,
            self.find_unused_parameters,
            self.gradient_as_bucket_view,
            param_to_name_mapping,
        )

1.2 引數說明

呼叫的 parameters 舉例如下, parameters[0] 就是 rank 0 上模型的 parameters,可以看到其只有 [0] 元素有意義,這個 [0] 原始本身包括 20 個元素:

parameters = {list: 1} 
0 = {list: 4}           
 0 = {Parameter: 10} Parameter containing:\ntensor([[-4.0381e-02,  3.8828e-02, 1  )   
 1 = {Parameter: 10} Parameter containing:\ntensor([-0.0438, -0.2033,  0.2771,  0.0721,  ) 
 2 = {Parameter: 5} Parameter containing:\ntensor([[-0.0094, -0.1319,  0.0713,  0.3155,  )
 3 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )
 ...
 20 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )                                                   
 __len__ = {int} 20
__len__ = {int} 1

bucket_indices 舉例如下:

關於 tensor indices,就是給所有的tensor一個index,從0開始遞增,一直到 tensors.size()。假如模型的 parameters 一共有20個張量,則 tensor index 從 0 到 19,分成 6 個buckets,則在這6個buckets之中,每個 tensor index 都是唯一不重複的。

+-----------------------------------------------------------------------+
|                                                                       |
|  <tensor index 0, tensor index 1, tensor index 2, tensor index 3>     |
|                                                                       |
|                                                                       |
|  <tensor index 4, tensor index 5, tensor 6>                           |
|                                                                       |
|                                                                       |
|  ......                                                               |
|                                                                       |
|                                                                       |
|  <tensor index 16, tensor index 17, tensor index 18, tensor index 19> |
|                                                                       |
+-----------------------------------------------------------------------+

接下來,我們就看看如何進行初始化 Reducer。

0x02 Reducer 初始化

程式碼位於:torch/lib/c10d/reducer.h 和 torch/lib/c10d/reducer.cpp

2.1 建構函式

具體邏輯如下:

  • 看看本模組是不是多裝置模組,具體是: 遍歷張量,得到張量的裝置,把裝置插入到一個set結構之中,如果set內的裝置多於一個,是多裝置
  • 如果 expect_sparse_gradients沒有設定,就把expect_sparse_gradients_初始化為false。
  • 呼叫 initialize_buckets 初始化 buckets 並儘可能按照逆序將 parameters 分配到 buckets 之中,這樣按桶通訊就可以提高效率。後續在執行時候也可能再次重新初始化桶。
  • 為每個 parameter 加上 grad_accumulator,它們在 backward 時負責梯度同步。
    • 因為這些variables是autograd圖的葉子張量,所以它們的grad_fn都被設定為 gradient accumulation function。
    • Reducer儲存了指向這些functions的指標,這樣Reducer就可以知道它們在autograd傳播之中是否被使用,如果沒有使用,那麼就把這些functions的梯度張量(grad tensors)設定為規約就緒狀態。
    • 遍歷張量,為每個張量生成一個型別為VariableIndex的變數index。
    • 得到Variable::AutogradMeta的grad_accumulator_,即用於累加葉子 Variable 的梯度累加器。
    • 把reducer的autograd_hook函式新增進去每個grad_accumulator_之中,變數index是hook的引數。這個 hook 掛在 autograd graph 之上,在 backward 時負責梯度同步。grad_accumulator 執行完後,autograd_hook 就會執行。
  • gradAccToVariableMap_ 存了grad_accumulator & index 的對應關係(函式指標和引數張量的對應關係),這樣以後在 autograd graph 遍歷尋找 unused parameters 就方便了。
  • 初始化 backward_stats_。
  • 呼叫 initialize_local_used_map 初始化各種 unused map。
// The constructor takes a list of variables for every model replica.
// The bucket assignment for this reducer is specified as a list of
// buckets, each of which is specified as a list of indices into the
// variables list for **a single replica** (i.e. `variables[0]`).
Reducer::Reducer(
    std::vector<std::vector<at::Tensor>> replicas, // 張量
    std::vector<std::vector<size_t>> bucket_indices, // 桶資訊
    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
    std::vector<std::vector<bool>> expect_sparse_gradients,
    int64_t bucket_bytes_cap,
    bool find_unused_parameters,
    bool gradient_as_bucket_view,
    std::unordered_map<size_t, std::string> paramNames)
    : replicas_(std::move(replicas)),
      process_group_(std::move(process_group)),
      expect_sparse_gradients_(std::move(expect_sparse_gradients)),
      expect_autograd_hooks_(false),
      require_finalize_(false),
      next_bucket_(0),
      has_marked_unused_parameters_(false),
      find_unused_parameters_(find_unused_parameters),
      gradient_as_bucket_view_(gradient_as_bucket_view),
      local_used_maps_reduced_(false),
      num_iterations_(0),
      num_buckets_ready_(0),
      has_rebuilt_bucket_(false),
      bucket_bytes_cap_(bucket_bytes_cap),
      divFactor_(kUnsetDivFactor),
      static_graph_(false),
      comm_hook_(nullptr),
      thread_local_state_(at::ThreadLocalState()),
      ddp_debug_level_(parseDistDebugLevel()),
      param_names_(std::move(paramNames)) {

  // Check whether the module is multi_device_module
  // 看看本模組是不是多裝置模組
  {
    std::set<int> unique_devices;
    for (const auto& v : replicas_[0]) { // 遍歷張量
      auto device_idx = int(v.device().index()); // 得到張量的裝置
      if (unique_devices.find(device_idx) == unique_devices.end()) {
        unique_devices.insert(device_idx); // 把裝置插入到一個set結構之中
        if (unique_devices.size() > 1) { // 如果set內的裝置多於一個,是多裝置
          is_multi_device_module_ = true; 
          break;
        }
      }
    }
  }

  // If `expect_sparse_gradients` is not specified, initialize it such that
  // we do not expect sparse gradients for any parameter.
  if (expect_sparse_gradients_.empty()) {
    expect_sparse_gradients_ = std::vector<std::vector<bool>>(
        replicas_.size(), std::vector<bool>(replicas_[0].size(), false));
  }

  // Initialize variable bucketing.
  // This can be reinitialized later after capturing runtime information.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    initialize_buckets(std::move(bucket_indices)); //初始化桶
  }

  // All variables are expected to have their `grad_fn` set to the gradient
  // accumulation function (since they are leafs in the autograd graph).
  // We store pointers to these functions such that we can check if they are
  // used in an autograd pass. If they are not, we know their grad tensors
  // can be marked as ready for reduction.
  {
    const auto replica_count = replicas_.size();
    grad_accumulators_.resize(replica_count);
    for (size_t replica_index = 0; replica_index < replica_count; // 只有replicas_[0]有意義
         replica_index++) {
      const auto variable_count = replicas_[replica_index].size(); //張量數目
      grad_accumulators_[replica_index].resize(variable_count); // 給grad_accumulators_分配記憶體
        
      for (size_t variable_index = 0; variable_index < variable_count;
           variable_index++) { // 遍歷張量,variable_index 就是張量的index
        auto& variable = replicas_[replica_index][variable_index]; //得到具體的張量
        const auto index = VariableIndex(replica_index, variable_index); //每個張量生成一個VariableIndex

        // The gradient accumulator function is lazily initialized once.
        // Therefore we can use its presence in the autograd graph as
        // evidence that the parameter has participated in an iteration.
        auto grad_accumulator =
            torch::autograd::impl::grad_accumulator(variable); // 得到Variable::AutogradMeta的grad_accumulator_,即,用於累加葉子 Variable 的梯度累加器

#ifndef _WIN32
        using torch::distributed::autograd::ThreadLocalDistAutogradContext;
#endif
        // Hook to execute after the gradient accumulator has executed.
        hooks_.emplace_back(
            // 累加器新增hook,這個 hook 掛在 autograd graph 之上,在 backward 時負責梯度同步。
            // grad_accumulator 執行完後,autograd_hook 就會執行
            grad_accumulator->add_post_hook(
                torch::make_unique<torch::autograd::utils::LambdaPostHook>(
                    [=](const torch::autograd::variable_list& outputs,
                        const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
                      this->rpc_context_.set(
                          ThreadLocalDistAutogradContext::getContextPtr());
#endif
                      this->autograd_hook(index); // 把reducer的autograd_hook函式新增進去
                      return outputs;
                    })),
            grad_accumulator);

        // Map raw function pointer to replica index and parameter index.
        // This is used later on when the autograd graph is traversed
        // to check for parameters for which no gradient is computed, if
        // find_unused_parameters=True.
        // Note that the mapping of gradient accumulator to variable should be
        // one to one as we deduplicate shared parameters before constructing
        // Reducer.
          
        // gradAccToVariableMap_ 存了grad_accumulator & index 的對應關係(函式指標和引數張量的對應關係),這樣以後在 autograd graph 遍歷尋找 unused parameters 就方便了
        if (find_unused_parameters_) {
          gradAccToVariableMap_[grad_accumulator.get()] = index;
        }

        numGradHooksTriggeredMap_[index] = 0;

        // The gradient accumulator is stored as weak_ptr in the autograd
        // metadata of the variable, so we have to keep it alive here for
        // the raw pointer to be valid.
        TORCH_CHECK(
            grad_accumulators_[replica_index][variable_index] == nullptr,
            c10::str(
                "Reducer tried to register duplicate grad accumulator for replica ",
                replica_index,
                " variable ",
                variable_index));
        grad_accumulators_[replica_index][variable_index] =
            std::move(grad_accumulator);
      }
    }
  }

  // Initialize backward stats vector.
  {
    const auto replica_count = replicas_.size();
    backward_stats_.resize(replica_count);
    const auto variable_count = replicas_[0].size();
    std::for_each(
        backward_stats_.begin(),
        backward_stats_.end(),
        [=](std::vector<int64_t>& v) { v.resize(variable_count); });
  }

  // See Note [Skip allreducing local_used_maps_dev]
  if (find_unused_parameters_) {
    initialize_local_used_map();
  }
}

我們接下來具體分析每一個部分。

2.2 初始化桶

initialize_buckets方法用來初始化桶,具體邏輯是對於每一個桶,新增其模型副本,對於每一個模型副本,新增張量列表:

  • 用分散式上下文設定 rpc_context_。

    • 如果在DDP建構函式內呼叫initialize_bucket,則 rpc上下文指標(rpc context ptr)是否為null 無關緊要,因為grad不會發生變化。
    • 如果在訓練迴圈期間呼叫initialize_bucket,例如在rebuild_bucket 內部,因為grad可能會發生改變並指向bucket_view,那麼它需要檢查rpc context ptr是否為null。
    • 如果rpc context ptr是null,則改變 variable.grad(),否則,在rpc上下文中改變梯度。
  • 清空buckets_ 和 variable_locators_。

  • 重置variable_locators_的尺寸,這樣每個variable都有一個bucket index。

  • 利用如下得到所有桶的個數和每個桶中副本個數:bucket_count = bucket_indices.size(); replica_count = replicas_.size();

  • 從0開始遞增到 bucket_count,逐一初始化 Bucket。

    • 生成一個 Bucket bucket
    • 如果bucket_indices[bucket_index].size() == 1,說明這個桶期待一個single sparse gradient,則設定 bucket.expect_sparse_gradient = true。
    • 從0開始遞增到replica_count,逐一初始化 BucketReplica。
      • 生成一個 BucketReplica replica
      • 如果這個桶期待一個single sparse gradient,則
        • 利用bucket_indices[bucket_index].front()取出向量第一個元素,設定為 variable_index。
        • 利用 variable_index 得到副本之中對應的variable。
        • 設定副本replica的變數列表,程式碼為replica.variables = {variable},這個副本只包括一個variable。
      • 否則說明是dense gradient,則
        • 遍歷桶的variable,即利用 replicas_[replica_index][variable_index] 得到variable。
        • 設定variable的裝置和資料型別
        • 給副本設定其variables,程式碼為:replica.variables.push_back(variable)。
        • 設定replica 的一些關於variable的元資訊,這些元資訊是flat contents相關的,比如offsets儲存了各個張量在flat bucket contents中的offset。
        • 給relica.contents分配記憶體
        • 利用 initialize_bucket_views(replica, replica.contents) 初始化 cotnents 和 views。
        • 利用 bucket.replicas.push_back(std::move(replica)) 把這個 replica 加入到 bucket。
    • 遍歷桶中的variable,程式碼為 bucket_indices[bucket_index]。
      • 設定 Reducer.variable_locators_,這樣 Reducer 就知道如何在 bucket 之中確定一個varaible。bucket_indexbuckets_列表的位置,表示 buckets_ 之上的一個bucket。intra_bucket_index 是在 bucket replica 之中 vector 域的 variable index。
    • 設定桶的變數,bucket.variable_indices = std::move(bucket_indices[bucket_index]);
    • 利用 buckets_.push_back(std::move(bucket)) 把bucket這個桶加入到 Reducer之中。

具體程式碼是:

void Reducer::initialize_buckets(
    std::vector<std::vector<size_t>> bucket_indices) {
  // If initialize_buckets is called inside DDP constructor, then
  // it does not matter rpc context ptr is nullptr or not, as grad
  // will not be mutated.
  // If initialize_buckets is called during training loop, e.g, inside
  // rebuild_buckets(), since grad could be mutated and be pointed to
  // bucket_view, then it needs to check rpc context ptr is nullptr or not,
  // If rpc context ptr is nullptr, mutate variable.grad(); otherwise,
  // mutate grad in rpc context.
#ifndef _WIN32
  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
#endif

  // This shouldn't be called if we're expecting autograd hooks to fire.
  TORCH_CHECK(
      !expect_autograd_hooks_,
      "`initialize_buckets` must NOT be called during autograd execution.");

  // Clear current bucket assignment.
  buckets_.clear();
  variable_locators_.clear();

  // Ensure we have a bucket index for every variable.
  variable_locators_.resize(replicas_[0].size());

  // Iterate over buckets.
  const auto bucket_count = bucket_indices.size();
  const auto replica_count = replicas_.size();
  buckets_.reserve(bucket_count);
  // 從0開始遞增到bucket_count
  for (size_t bucket_index = 0; bucket_index < bucket_count; bucket_index++) {
    Bucket bucket; // 生成一個桶

    // TODO(@pietern): Validate indices.
    // Must be non-empty, unique, and unique across buckets.
    TORCH_CHECK(
        bucket_indices[bucket_index].size() > 0, "Empty bucket specified.");

    // Variables that expect sparse gradients must have their own bucket.
    if (bucket_indices[bucket_index].size() == 1) {
      // 說明這個桶期待一個single sparse gradient
      const auto variable_index = bucket_indices[bucket_index].front();
      bucket.expect_sparse_gradient =
          expect_sparse_gradients_[0][variable_index];
    } else {
      for (const auto variable_index : bucket_indices[bucket_index]) {
        TORCH_CHECK(
            !expect_sparse_gradients_[0][variable_index],
            "Buckets with more than one variable cannot include variables ",
            "that expect a sparse gradient.");
      }
    }

    // Iterate over model replicas. 從0開始遞增到replica_count,遍歷模型副本數目,為每一個模型副本都要做同樣設定
    for (size_t replica_index = 0; replica_index < replica_count;
         replica_index++) {
      BucketReplica replica; // 生成一個副本

      if (bucket.expect_sparse_gradient) {
        // 說明這個桶期待一個single sparse gradient
        const auto variable_index = bucket_indices[bucket_index].front(); // 得到張量的index
        const auto& variable = replicas_[replica_index][variable_index]; // 得到張量
        TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1);
        replica.variables = {variable}; // 這個副本只包括一個variable
      } else {
        at::TensorOptions options;
        // The start index of the variable in the flattened tensor.
        size_t offset = 0;

        // Reserve enough space for the per-variable fields stored in bucket
        // replica for efficiency.
        const size_t num_variables = bucket_indices[bucket_index].size();
        replica.variables.reserve(num_variables); 
        replica.offsets.reserve(num_variables);
        replica.lengths.reserve(num_variables);
        replica.sizes_vec.reserve(num_variables);

        // Iterate over bucket variables.
        for (const auto variable_index : bucket_indices[bucket_index]) { //遍歷桶中的variable
          TORCH_CHECK(
              variable_index < replicas_[replica_index].size(),
              "Out of range variable index specified.");
          const auto& variable = replicas_[replica_index][variable_index];
          if (!options.has_device()) {
            options = options.device(variable.device());
          } else {
            TORCH_CHECK(
                variable.device() == options.device(),
                "All parameters in a bucket must be ",
                "placed on the same device.");
          }
          if (!options.has_dtype()) {
            options = options.dtype(variable.dtype());
          } else {
            TORCH_CHECK(
                variable.dtype() == options.dtype(),
                "All parameters in a bucket must have the same dtype.");
          }
          
          const auto length = variable.numel();
          // 給副本設定其variables
          replica.variables.push_back(variable); // 這裡新增了一個新變數,所以最終能知道該桶中的變數數目
          // 設定replica 的一些關於variable的元資訊
          replica.offsets.push_back(offset);
          replica.lengths.push_back(length);
          replica.sizes_vec.push_back(variable.sizes());
          offset += length;
        }

        // Allocate bucket contents tensor.
        replica.contents = at::empty({static_cast<long>(offset)}, options);

        initialize_bucket_views(replica, replica.contents); // 初始化cotents和views
      }

      // Add bucket replica to enclosing bucket.
      bucket.replicas.push_back(std::move(replica)); // 桶的副本列表中新增一個新副本
    }

    // Map participating variables to this bucket.
    // This is identical across replicas so we only need to do this once.
    size_t intra_bucket_index = 0;
    for (const auto variable_index : bucket_indices[bucket_index]) { // 遍歷桶中的variable
      TORCH_CHECK(
          variable_index < variable_locators_.size(),
          "Out of range variable index specified.");
      variable_locators_[variable_index] = // 這樣 Reducer 就知道如何在 bucket 之中確定一個varaible
          VariableLocator(bucket_index, intra_bucket_index++);
    }
    bucket.variable_indices = std::move(bucket_indices[bucket_index]);

    buckets_.push_back(std::move(bucket)); // 把桶插入Reducer
  }
}

2.3 初始化檢視

initialize_bucket_views 這裡是設定 Replica 的contents 和 views。

// (see Note:  "Gradient Layout Contract" in initialize_buckets).
void Reducer::initialize_bucket_views(
    Reducer::BucketReplica& replica,
    at::Tensor& contents) {
  for (size_t i = 0; i < replica.variables.size(); i++) {
    auto& v = replica.variables[i];
    const auto offset = replica.offsets[i];
    const auto length = replica.lengths[i];
    if (v.is_non_overlapping_and_dense()) { // Dense型別的張量
      // If the param's memory is dense, match its layout, anticipating
      // the autograd engine (AccumulateGrad) will also create gradients
      // matching its layout.
      replica.bucket_views_in.push_back( // replica.bucket_views_in裡面都是檢視
          contents.as_strided(v.sizes(), v.strides(), offset));
    } else { // Sparse型別的張量
      // Fall back to a C-style contiguous view, again anticipating
      // AccumulateGrad will do the same when stashing grads for non-dense
      // params.
      replica.bucket_views_in.push_back( // replica.bucket_views_in裡面都是檢視
          contents.narrow(0, offset, length).view(v.sizes()));
    }
    // By default `bucket_views_out` and `bucket_views_in` are
    // essentially the same thing.
    replica.bucket_views_out = replica.bucket_views_in; // out也是檢視

    // If gradient_as_bucket_view_ is set as true, then there are two cases to
    // handle: initialize_bucket_views could be called inside initialize_buckets
    // when rebuild_buckets, if grad has already been defined/calculated in
    // previous iteration, old grad needs to be copied into new bucket_view and
    // let grad point to the new bucket_view, initialize_bucket_views could also
    // be called inside initialize_buckets during construction. Grads are not
    // defined during construction time, in this case, do not let grad point to
    // bucket_view, because grads should be kept as being undefined for globally
    // unused parameters.
    if (gradient_as_bucket_view_) {
      auto& bucket_view = replica.bucket_views_in.back();
      runGradCallbackForVariable(v, [&](auto& grad) {
        if (grad.defined() && !grad.is_alias_of(bucket_view)) {
          bucket_view.copy_(grad);
          grad = bucket_view; // 梯度被修改了,需要回寫
          // The grad is modefied and needs to be written back.
          return true;
        }
        // The grad is not modified and does not need to be written back.
        return false; // 不需要回寫,因為沒有被修改
      });
    }
  }
}

2.3.1 BucketReplica成員變數

我們先回憶一下BucketReplica的幾個成員變數。

  • at::Tensor contents :把桶的內容展平的結果,即Flattened (1 dimensional) 之後的結果。
  • std::vector<at::Tensor> bucket_views_in :提供了從輸入角度在 contents 之中檢視具體梯度的方法。
  • std::vector<at::Tensor> bucket_views_out :提供了從輸入角度在 contents 之中檢視具體梯度的方法。

關於 std::vector<at::Tensor> bucket_views_instd::vector<at::Tensor> bucket_views_out 的進一步說明:

  • 這兩個變數提供在 contents 之中操作具體梯度的方法,或者說,它們提供了檢視(views),該檢視可以操作contents 之中每個張量的梯度。使用者把這兩個變數作為入口點來把每個梯度的資料從 content 之中移入和移出。
  • 在 PyTorch 之中,檢視是指建立一個方便檢視的東西,檢視與原資料共享記憶體,它只是將原有的資料進行整理,直接顯示其中部分內容或者進行重排序後再顯示出來。

也需要對幾個 PyTorch 函式進行說明。

  • as_strided :依據現有tensor以及給定的步長來建立一個檢視(型別仍然為tensor),需要注意,這裡的結果是檢視,所以這個張量依然和原始張量共享記憶體。
  • narrow :返回一個新的張量,其是原來張量的縮小版,但是這個張量依然和原始張量共享記憶體。

BucketReplica 邏輯具體如下圖:

+------------------------------------------+
| BucketReplica                            |
|                                          |
|       vector<Tensor> bucket_views_in +--------------------+
|                                          |                |
|                                          |                |
|       vector<Tensor> bucket_views_out +--------------+    |
|                                          |           |    |
|                                          |           |    |
|                                          |           v    v
|                                          |     +-----+----+--------------------------+
|       Tensor contents  +---------------------> |Flattened (Tensor1, Tensor2, Tensor3)|
|                                          |     +-------------------------------------+
|                                          |
|                                          |
|       vector<Tensor> variables  +------------>  [Tensor1,Tensor2,Tensor3]
|                                          |
|                                          |
|                                          |
+------------------------------------------+

2.3.2 呼叫

如何呼叫?如果gradient_as_bucket_view_設定為true,則有兩種情況需要處理:

  • rebuild_buckets 之中可以在initialize_bucket內呼叫initialize_bucket_view,如果grad在上一次迭代中已經定義/計算過,則需要將舊的grad複製到新的bucket_view中,並讓grad指向新的bucket_view,
  • 在構造過程中,也可以在initialize_bucket中呼叫initialize_bucket_views。在構造期間不會定義梯度,在這種情況下,不要讓梯度指向bucket_view,因為對於全域性未使用的引數,梯度應保持為未定義。

2.4 初始化本地使用變數

initialize_local_used_map此處是初始化 local_used_maps_,我們回憶一下論文內容,local_used_maps_ 就是用來查詢全域性未使用引數(Globally Unused Parameters):

全域性未使用引數(Globally Unused Parameters)的梯度在向前和向後過程中應保持不變。檢測未使用的引數需要全域性資訊,因為在一個DDP過程中,一個引數可能在一次操作中不存在,但可能在另一個過程的同一次迭代中參與訓練。因此DDP在點陣圖中維護本地未使用的引數資訊,並啟動額外的AllReduce以收集全域性點陣圖。由於點陣圖比張量尺寸小得多,因此模型中的所有引數共享同一點陣圖,而不是建立每桶點陣圖(per-bucket bitmaps)。點陣圖位於CPU上,以避免為每次更新啟動專用CUDA核心。但是,某些ProcessGroup後端可能無法在CPU 張量上執行AllReduce。例如,ProcessGroupNCCL僅支援CUDA張量。此外,由於DDP應該與任何定製的ProcessGroup後端一起工作,它不能假設所有後端都支援CPU張量。為了解決這個問題,DDP在同一裝置上維護另一個點陣圖作為第一個模型引數,並呼叫非阻塞拷貝操作(non-blocking copy)將CPU點陣圖移動到裝置點陣圖以進行集合通訊

具體程式碼如下:

void Reducer::initialize_local_used_map() {
  const auto replica_count = replicas_.size();
  const auto variable_count = replicas_[0].size();
  local_used_maps_.resize(replica_count);
  local_used_maps_dev_.resize(replica_count);

  for (size_t i = 0; i < replica_count; i++) {
    at::TensorOptions options;
    options = options.dtype(at::kInt);

    // Deliberately don't pin the memory even if local_used_maps_dev_ will
    // be cuda. See Note [local_used_maps_ -> local_used_maps_dev copying]
    local_used_maps_[i] =
        at::zeros({static_cast<long>(variable_count)}, options);

    // This tensor needs to be on the same device as replica because backend
    // such as NCCL may not support CPU tensors, and hence it might not work
    // if we always put it on CPU.
    options = options.device(replicas_[i][0].device());
    local_used_maps_dev_[i] =
        at::empty({static_cast<long>(variable_count)}, options);
  }
}

初始化流程大致如下:

                                    +
                                    |
                                    |
                                    v
                  rpc_context_ = ThreadLocalDistAutogradContext
                                    +
                                    |
                                    |
                                    v
                  buckets_ & variable_locators_ (clear & resize)
                                    +
                                    |
                                    |
                                    v
+----------------------->  from 0 ~ bucket_count :  +--------------------------->
|                                                                                +
|                                                                                |
|      +-------------------------------------------------------------------+     |
|      | init Bucket          set bucket_indices                           |     |
|      |                            +                                      |     |
|      |                            |                                      |     |
|      |                            |                                      |     |
|      |                            v                                      |     |
|      |   ^ +------------> from 0 ~ replica_count : +----------------->   |     |
|      |   |                                                           |   |     |
|      |   |  +---------------------------------------------------+    |   |     |
|      |   |  | init BucketReplica                                |    |   |     |
|      |   |  |                                                   |    |   |     |
<----+ |   +--+                                                   | <--+   | <---+
       |      |    bucket.replicas.push_back(std::move(replica))  |        |
       |      |                                                   |        |
       |      +----------------------+----------------------------+        |
       |                             |                                     |
       |                             |                                     |
       |                             v                                     |
       |             buckets_.push_back(std::move(bucket))                 |
       |                             +                                     |
       +-------------------------------------------------------------------+
                                     |
                                     v

得到的 Reducer 大致如下,這裡需要注意的是 ,BucketReplica 每個桶只有一個:

            +----------------------------------------+                 +------------------+
            |tensor index 4, tensor index 5, tensor 6| <------+        | index 2, index 3 |
            +----------------------------------------+        |        +--------------+---+
                                                              |                       ^
                                                              |                       |
+---------------------------+   +---------------------------------------------------------+
| Reducer                   |   | +----------------------------------+     +------------+ |
|                           |   | |Bucket                     |      |     |Bucket    | | |
|                           |   | |                           +      |     |          | | |
| vector<Bucket> buckets_ +---> | | vector<size_t> variable_indices  |     | indices ++ | |
|                           |   | |                                  |     |            | |
|                           |   | |  vector<BucketReplica> replicas  | ... | replicas   | |
|                           |   | |                         +        |     |   +        | |
|                           |   | |                         |        |     |   |        | |
|                           |   | +----------------------------------+     +------------+ |
|                           |   |                           |                  |          |
+---------------------------+   +---------------------------------------------------------+
                                                            |                  |
                                                            |                  |
                                                            v                  v
                          +---------------------------------------+   +-------------------+
                          |  +----------------------------------+ |   | +---------------+ |
                          |  | BucketReplica                    | |   | | BucketReplica | |
                          |  |                                  | |   | |               | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_in  | |   | |   views_in    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_out | |   | |   views_out   | |
                          |  |                                  | |   | |               | |
                          |  |  Tensor contents                 | |   | |   contents    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> variables        | |   | |   variables   | |
                          |  |                     +            | |   | |      +        | |
                          |  +----------------------------------+ |   | +---------------+ |
                          +---------------------------------------+   +-------------------+
                                                   |                           |
                                                   |                           |
                                                   v                           v
                                   +---------------+------------+    +---------+----------+
                                   |Tensor 4, Tensor 5, Tensor 6|    | Tensor 2, Tensor 3 |
                                   +----------------------------+    +--------------------+

0x03 靜態圖

3.1 緣由

雖然 PyTorch 是動態圖,但是使用者可以明確地讓DDP知道訓練圖是靜態的,有如下情況時候可以設定:

  1. 已使用和未使用的引數集在整個訓練迴圈中不變,在這種情況下,使用者是否將find_unsued_parameters設定為true並不重要。

  2. 圖形的訓練方式在整個訓練迴圈過程中不會改變(意味著不存在依賴於迭代的控制流)。當圖被設定為靜態時,DDP將支援以前不支援的case,比如:

    1. 可重入的反向傳播。
    2. 多次activation checkpointing。
    3. activation checkpointing 並且find_unused_parameters = true。
    4. 並不是所有的輸出張量都用於損失計算。。
    5. 在前向函式之外有一個模型引數。
    6. 當find_unsued_parameters=true時或者存在未使用的引數,可能會提高效能,因為DDP在每個迭代之內不會搜尋網路來檢查未使用的引數。

3.2 使用

_set_static_graph 可以配置靜態圖,此API應在DistributedDataParallel構造之後,並且在訓練迴圈開始之前呼叫。並且,也應該以同樣的方式對所有的rank 進行呼叫。例如:

ddp_model = DistributedDataParallel(model)
ddp_model._set_static_graph()
for i in range(n):

_set_static_graph 程式碼為:

def _set_static_graph(self):
    """
    Users can explicitly let DDP know the trained graph is static,
    when 1) the set of used and unused parameters will not change
    during the whole training loop; in this case, it does not matter
    whether users set find_unsued_parameters = true or not.
    2) how the graph is trained will not change during the whole training
    loop (meaning there is no control flow depending on iterations).
    When graph is set to be static, DDP will support cases that can not
    be supported in the past: 1) reentrant backwards
    2) activation checkpointing multiple times 3)
    activation checkpointing with find_unused_parameters = true.
    4) not all output tensors are used in loss calculation.
    5) there is model parameter that is outside of forward function.
    6) potentially improve performance when find_unsued_parameters = true
    or there are unused parameters, as DDP will not search graph in each
    iteraton to detect unused parameters when static_graph is set to be True.

    This API should be called after DistributedDataParallel construction, and
    before training loops starts. Also it should be called in the same way for
    all ranks. For example:
        ddp_model = DistributedDataParallel(model)
        ddp_model._set_static_graph()
        for i in range(n):
            .....
    """
    self.static_graph = True
    self.reducer._set_static_graph() # 呼叫 Reducer 進行配置
    self.logger._set_static_graph()
    if self.find_unused_parameters:
        warnings.warn(
            "You passed find_unused_parameters=true to DistributedDataParallel, "
            "`_set_static_graph` will detect unused parameters automatically, so "
            "you do not need to set find_unused_parameters=true, just be sure these "
            "unused parameters will not change during training loop while calling "
            "`_set_static_graph`."
        )

3.2 Reducer

Reducer 只有在第一次迭代之後才能生成靜態圖,因為畢竟PyTorch還是動態的,無論如何也得走一步動態生成。

void Reducer::set_static_graph() {
  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_CHECK(
      num_iterations_ == 0,
      "set_static_graph() should be called before training loop starts "
      "and after DistributedDataParallel is constructed.");
  static_graph_ = true;
  // when static_graph_ is set as true, always initialize_local_used_map
  // and detect the global unused parameters in the first iteration.
  initialize_local_used_map();
}

0x04 重建桶

4.1 為何要重建

因為 PyTorch 是動態生成計算圖,所以需要相應重建桶。但是隻有設定了靜態圖 並且 第一次迭代之後才會重建,如果設定 find_unused_parameters_,就不重建。

  // Returns true if we should rebuild buckets, else false. We only rebuild
  // buckets once after the first iteration and never rebuild them if
  // find_unused_parameters_.
  inline bool should_rebuild_buckets() const {
    return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
  }

4.2 準備重建

我們首先看看重建之前的一些準備。

push_rebuilt_params 就是插入一個重建引數列表。

void Reducer::push_rebuilt_params(const VariableIndex& index) {
  rebuilt_params_.push_back(
      replicas_[index.replica_index][index.variable_index]);
  rebuilt_param_indices_.push_back(index.variable_index);
}

其次,push_rebuilt_params_for_all_indices 會遍歷每個 replica,針對 replica 之中的每個 variable 進行設定。

void Reducer::push_rebuilt_params_for_all_indices() {
  std::lock_guard<std::mutex> lock(mutex_);
  if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) {
    return;
  }
  const auto replica_count = replicas_.size();
  for (size_t replica_index = 0; replica_index < replica_count;
       ++replica_index) {
    const auto variable_count = replicas_[replica_index].size();
    for (size_t variable_index = 0; variable_index < variable_count;
         ++variable_index) {
      const auto index = VariableIndex(replica_index, variable_index);
      push_rebuilt_params(index);
    }
  }
}

4.3 重建

我們接下來看看重建機制。

DDP 根據張量在後向傳播中接收梯度的時間,使用 rebuilt_params_ 和 rebuilt_param_indices_ 來重建儲存桶。

rebuild_buckets 函式進行廣播通訊呼叫,並且可以與下一個forward()呼叫重疊,因此它可以是非同步的。

  • 在find_unused_parameters=true情況下重建bucket 就是非同步操作,因為我們可以多次重建bucket,其中子圖經過訓練,引數索引順序可能會更頻繁地更改。
  • 對於find_unused_parameters=false的情況,bucket只重建一次,效能成本可以忽略不計。如果已重建儲存桶, rebuild_buckets 則返回true。
bool Reducer::rebuild_buckets() {
  // Ensure reduction for previous backwards pass is finished. If user's model
  // has unused parameters for example, this will raise an error recommending to
  // run with find_unused_parameters=True, instead of the size mismatch
  // exception below.
  std::lock_guard<std::mutex> lock(mutex_);
  ensure_prior_reduction_finished();
  if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
    return false;
  }

  std::vector<std::vector<size_t>> rebuilt_bucket_indices;
  std::vector<size_t> bucket_size_limits;
  bucket_size_limits.push_back(kDefaultFirstBucketBytes);
  bucket_size_limits.push_back(bucket_bytes_cap_);
  rebuilt_bucket_indices = compute_bucket_assignment_by_size(
      rebuilt_params_,
      bucket_size_limits,
      expect_sparse_gradients_[0],
      rebuilt_param_indices_);

  // For rebuilt bucket indices, it needs to be synced across all ranks.
  // Broadcast the newly rebuilt bucket indices from rank 0 in default.
  // After syncing up rebuilt bucket indices, initialize buckets for reducer.
  sync_bucket_indices(rebuilt_bucket_indices);

  has_rebuilt_bucket_ = true; // 只重建一次
  rebuilt_params_.clear();
  rebuilt_param_indices_.clear();

  initialize_buckets(std::move(rebuilt_bucket_indices));
  return true;
}

4.4 何時設定重建

重建僅在以下情況進行設定:

  1. 第一次重建儲存桶

  2. static_graph_ is true 或 find_unused_parameters_ is false

  3. 此反向傳播過程需要執行allreduce。

在這裡,我們只需基於梯度到達順序將張量及其引數索引轉儲到rebuilt_params_rebuilt_param_indices_。然後在finalize_backward() 結束時,將基於rebuilt_params_rebuilt_param_indices_重建儲存桶,然後廣播和初始化儲存桶。

此外,我們只需要轉儲一個副本的張量和引數索引。

以 mark_variable_ready 為例,其中就會呼叫 push_rebuilt_params(index) 來插入列表。

void Reducer::mark_variable_ready(VariableIndex index) {
  // Rebuild bucket only if 1) it is the first time to rebuild bucket 2)
  // static_graph_ is true or find_unused_parameters_ is false,
  // 3) this backward pass needs to run allreduce.
  // Here, we just dump tensors and their parameter indices into
  // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving
  // order, and then at the end of finalize_backward(), buckets will be
  // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then
  // will be broadcasted and initialized. Also we only need to dump tensors
  // and parameter indices of one replica.
  if (should_rebuild_buckets()) {
    push_rebuilt_params(index); // 插入列表
  }

  const auto replica_index = index.replica_index;
  const auto variable_index = index.variable_index;

  if (replica_index == 0) {
    checkAndRaiseMarkedTwiceError(variable_index);
    perIterationReadyParams_.insert(variable_index);
  }
  backward_stats_[replica_index][variable_index] =
      current_time_in_nanos() - cpu_timer_.backward_compute_start_time;

  // Any time we mark a variable ready (be it in line due to unused parameters,
  // or via an autograd hook), we require a call to the finalize function. If
  // this doesn't happen before the next iteration (or call to
  // `prepare_for_backwards`), we know something is wrong.
  require_finalize_ = true;

  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  auto& replica = bucket.replicas[replica_index];

  set_divide_factor();

  if (bucket.expect_sparse_gradient) {
    mark_variable_ready_sparse(index);
  } else {
    mark_variable_ready_dense(index);
  }

  // TODO(@pietern): Make this work for both CPU/CUDA tensors.
  // When using CPU tensors we don't need to do this.
  // // Record event so that we can wait for all of them.
  // auto& event = replica.events[bucket_index.intra_bucket_index];
  // event.record();

  // Check if this was the final gradient for this bucket.
  if (--replica.pending == 0) {
    // Kick off reduction if all replicas for this bucket are ready.
    if (--bucket.pending == 0) {
      mark_bucket_ready(bucket_index.bucket_index);
    }
  }

  // Run finalizer function and kick off reduction for local_used_maps once the
  // final bucket was marked ready.
  if (next_bucket_ == buckets_.size()) {

    if (dynamic_graph_find_unused()) {
      all_reduce_local_used_map();
    }

    // The autograd engine uses the default stream when running callbacks, so we
    // pass in the current CUDA stream in case it is not the default.
    const c10::Stream currentStream = get_current_stream();
    torch::autograd::Engine::get_default_engine().queue_callback([=] {
      std::lock_guard<std::mutex> lock(this->mutex_);
      // Run callback with the current stream
      c10::OptionalStreamGuard currentStreamGuard{currentStream};
      if (should_collect_runtime_stats()) {
        record_backward_compute_end_time();
      }
      // Check that all buckets were completed and had their work kicked off.
      TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
      this->finalize_backward();
    });
  }
}

4.5 直接呼叫

_rebuild_buckets 函式也可以直接呼叫,比如如下情況,就是在整個訓練期間內在 forward 呼叫了一次。

def forward(self, *inputs, **kwargs):
    with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
        self.reducer.save_thread_local_state()
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            self.num_iterations += 1
            self.reducer.prepare_for_forward()
        if self.ddp_uneven_inputs_config.ddp_join_enabled:
            ones = torch.ones(1, device=self.device)
            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
            if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                # Active ranks schedule an allreduce with zeros, inactive
                # ranks schedule them with 1. If the result != 0 it
                # indicates at least one rank has terminated and we should
                # throw.
                zeros = torch.zeros(1, device=self.device)
                dist.all_reduce(zeros, group=self.process_group)
                should_throw_stop_iteration = zeros.item()
                if should_throw_stop_iteration:
                    raise RuntimeError(
                        "Detected at least one rank that exhausted inputs. Throwing across all ranks."
                    )
            else:
                self.reducer._set_forward_pass_work_handle(
                    work,
                    self.ddp_uneven_inputs_config.ddp_join_divide_by_initial_world_size,
                )

        # Calling _rebuild_buckets before forward compuation,
        # It may allocate new buckets before deallocating old buckets
        # inside _rebuild_buckets. To save peak memory usage,
        # call _rebuild_buckets before the peak memory usage increases
        # during forward computation.
        # This should be called only once during whole training period.
        
        # 在這裡進行直接呼叫
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): # 設定
            logging.info("Reducer buckets have been rebuilt in this iteration.")

再比如 Join 方法也可以直接呼叫進行重建。

@contextmanager
def join(
    self,
    divide_by_initial_world_size=True,
    enable=True,
    throw_on_early_termination=False,
):
  
  									# 忽略其他程式碼
    
                    else:
                        # Some DDP process still needs to be joined.
                        if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                            # Schedule allreduce telling active ranks to terminate
                            ones = torch.ones(1, device=self.device)
                            dist.all_reduce(ones, group=self.process_group)
                            # Raising StopIteration doesn't throw error in python 3.6
                            # and throws RuntimeError in 3.7+ (PEP 479), so just
                            # raise RuntimeError here.
                            raise RuntimeError(
                                f"Rank {self._distributed_rank} exhausted all inputs."
                            )
                        if is_last_joiner:
                            is_last_joiner = False
                        # It will rebuild buckets only once during training period
                        
                        # 這裡進行呼叫。
                        self.reducer._rebuild_buckets()
                        # Schedule a corresponding broadcast if we are syncing module
                        # buffers in the forward pass.
                        self._check_and_sync_module_buffers()   

既然提到了 Join,我們接下來就看看這個概念。

0x05 Join

Join 是為了解決訓練資料不均勻的問題,就是允許某些輸入較少的worker(其已經完成Join操作)可以繼續和那些尚未結束的worker繼續執行集合通訊,就是一個欺騙操作(Shadow)。

5.1 緣起

支撐DDP背後的是幾個集合通訊庫的all-reduce操作,其完成了各個worker之間的梯度同步。而當訓練資料在 ranks 之間的輸入是不均勻(uneven)的,就會導致DDP會掛起。因為集合通訊要求在程式組中的所有rank都參與,因此如果一個rank的輸入少,其他ranks會hang或者報錯(取決於後端),而且任何類在執行同步集合通訊時,在每次迭代都會遇到這個問題。

因此,DDP 給出了一個 "Join" API,Join是一個上下文管理器,在每個rank的訓練迴圈之中使用。資料量少的 rank 會提前耗盡輸入,這時它將給集合通訊一個假象,從而會構建一個虛擬(dummy)的 all-reduce,以便在資料不足時候與其他 ranks 匹配。具體如何製造這個假象是由註冊hook指定。

其大致思路如下:

                +----------------------------+
                |             Data           |
                |   +--------+   +--------+  |
                |   |        |   | Empty  |  |
                |   |        |   |        |  |
                |   +-----+--+   +--------+  |
                |         |                  |
                |         |                  |
                +----------------------------+
                          |
                          |
        +------------+    |               +------------+
        |            |    |               |            |
+---->  |    Model   |    |               |   Model    | <-----+
|       |            |    |               |            |       |
|       +------+-----+    |               +------+-----+       |
|              |          |                      |             |
|              |          |                      |             |
|              v          |                      v             |
|       +------+-----+    |             +--------+----------+  |
|       |  Forward   +<---+             | _JoinHook         |  |
|       |  (local)   |                  |                   |  |
|       +------+-----+                  |                   |  |
|              |                        |                   |  |
|              |                        |                   |  |
|              v                        | +---------------+ |  |
|       +------+-----+                  | | main_hook     | |  |
|       |  Backward  |                  | |               | |  |
|       |  (local)   |                  | |               | |  |
|       +------+-----+                  | |               | |  |
|              |                        | |               | |  |
|              |                        | |               | |  |
|              v                        | |               | |  |
|       +------+-----+                  | |               | |  |
|       | All-Reduce |     Sync grads   | |   All-Reduce  | |  |
|       |            | <--------------> | |   (Dummy)     | |  |
|       +------+-----+                  | |               | |  |
|              |                        | +---------------+ |  |
|              |                        +-------------------+  |
|              v                                 |             |
|     +--------+-------+                         |             |
|     | Update Weights |                         |             |
|     |                |                         |             |
|     +--------+-------+                         |             |
|              |                                 |             |
|              |                                 |             |
+--------------+                                 +-------------+

5.2 使用

5.2.1 DistributedDataParallel

Join 可以和 DistributedDataParallel 一起使用,比如下面的例子之中,會啟動兩個worker,分別是 rank 0 和 rank 1,rank 0 會得到5個輸入,rank 1會得到6個輸入,這就是輸入不均衡。

如果沒有使用 Join,則 rank 1 會在處理第6個輸入時候死掉掛起,因為rank 0沒有相關輸入,所以rank 1只能等待。如果使用了 Join,則不會出現這種問題,可以順利結束。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

這將產生以下輸出(其中print來自 0 級和 1 級的 ranks,可以任意排序):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

5.2.2 ZeroRedundancyOptimizer

Join上下文不僅是和一個類合作,也可以和多個類一起,比如PyTorch 的ZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

這將產生與以前相同的輸出。顯著的變化是需要另外將ZeroRedundancyOptimizer例項傳入 Join()

後續會對ZeroRedundancyOptimizer等機制也進行分析。

5.3 原理

在最新文件 https://pytorch.org/tutorials/advanced/generic_join.html 之中,PyTorch 給出了一定解釋,我們翻譯如下。

為了更好的使用,我們將介紹Join類以及支援類JoinableJoinHook

備註:這部分在 v1.10.0 版本程式碼之中。

5.3.1 Joinable

首先,與Join上下文管理器相容的類必須繼承抽象基類Joinable。特別的,Joinable必須實現:

  • join_hook(self, **kwargs) -> JoinHook

這將返回 的JoinHook例項Joinable,用來確定加入的程式應如何影響由Joinable 執行的每次迭代集體通訊。

  • join_device(self) -> torch.device

這將返回Join上下文管理器用來執行集體通訊的裝置,例如torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

這將返回Join上下文管理器用於執行集體通訊的程式組。

概括一下,JoinHook負責具體行為,join_device 和 join_process_group 負責具體集合通訊

需要注意的是,join_devicejoin_process_group是必需的屬性,他們可以確保上下文管理器能夠安排"加入"和"未加入"程式之間的集體通訊。一種用法是使用 all-reduce 計算每次迭代中"未加入"程式的數量。另一種用法是實現 throw_on_early_termination=True所需的機制,我們將在下面解釋。

DistributedDataParallelZeroRedundancyOptimizer已經繼承Joinable並實現了上面的方法,這就是為什麼我們可以在前面的例子中直接使用它們。

class DistributedDataParallel(Module, Joinable):

class ZeroRedundancyOptimizer(Optimizer, Joinable):

DDP 涉及到提供資料,所以繼承Joinable可以理解,ZeroRedundancyOptimizer 為何也需要繼承?這是因為 ZeroRedundancyOptimizer 可以和 DDP 一起合作,並且 ZeroRedundancyOptimizer 內部也有集合操作,所以需要被 Join 一起管理。

Joinable類應該確保呼叫Joinable建構函式,因為它初始化了一個JoinConfig例項,上下文管理器在內部使用JoinConfig來確保正確性。JoinConfig將在每個Joinable _join_config欄位中儲存。

5.3.2JoinHook

接下來,讓我們分解一下JoinHook類。JoinHook提供了兩個進入上下文管理器的入口點:

  • main_hook(self) -> None

當存在尚未加入(Join)的 rank 時,每個加入(Join)的 rank 都會重複呼叫此鉤子。它目的是在每次訓練迭代(例如,在一次前向傳遞,反向傳遞和優化器步驟)之中,隱藏由Joinable所執行的集體通訊,即已經Join的rank 如何與未Join的rank執行集合通訊

  • post_hook(self, is_last_joiner: bool) -> None

一旦所有 ranks 都加入,這個鉤子就會被呼叫。它傳遞了一個額外的 bool引數is_last_joiner,其表明此 rank 是否是最後加入的 rank 之一。該引數可能對同步有用。

5.3.2.1 ZeroRedundancyOptimizer

我們以 內建的 ZeroRedundancyOptimizer main hook 來給出一個鉤子的具體例子:因為加入的 rank 仍然負責更新和同步其引數分片,所以 main hook 依然執行優化器步驟。

class _ZeROJoinHook(_JoinHook):
    def __init__(self, zero):
        assert isinstance(zero, ZeroRedundancyOptimizer), \
            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " \
            "instance as the state"
        self.zero = zero
        super().__init__()

    def main_hook(self):
        """
        Performs an optimizer step, which updates the joined process's shard of
        the parameters and broadcasts those parameters.
        """
        self.zero.step()

step函式簡略如下:

def step(
    self,
    closure: Optional[Callable[[], float]] = None,
    **kwargs: Any,
) -> Optional[float]:
    _Join.notify_join_context(self) # 這裡會通知
    # Check if the model trainability has changed
    is_trainable_mask = self._get_is_trainable_mask()
    if is_trainable_mask != self._is_trainable_mask:
        self._build_param_buckets()
        self._is_trainable_mask = is_trainable_mask

    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    self._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Run the optimizer step on this shard only
    if closure is not None:
        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
    else:
        loss = self.optim.step(**kwargs)

    # Sync all of the updated parameter shards across the ranks
    self._sync_parameters()

    # Sync any updated attributes in the local optimizer to the exposed
    # `param_groups`
    self._sync_param_groups(self.optim.param_groups, self.param_groups)

    return loss

再來看看DistributedDataParallel

  • main_hook 依然會做相關的一系列操作來欺騙其他rank。
  • post-hook 會從最後加入的rank之一來廣播最終更新的模型,以確保模型在所有rank中都是相同的。
class _DDPJoinHook(_JoinHook):
    def __init__(self, ddp, divide_by_initial_world_size):
        """
        Sets config variables for internal usage.
        """
        ddp.logger._set_uneven_input_join()
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
        super().__init__()

    def main_hook(self):
        """
        Shadows the DDP collective communication operations in the forward and
        backward passes.
        """
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        ddp._check_and_sync_module_buffers()

        # Check if need to sync in the backward pass
        work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True)
        work.wait()
        should_sync_backwards = work.result()[0].item() != 0
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """
        Syncs the final model to ensure that the model is the same across all
        processes.
        """
        self.ddp._sync_final_model(is_last_joiner)

_sync_final_model 這裡會廣播最新的模型。

# When running in join model, agrees upon a common rank and broadcast model
# parameters to all other ranks.
def _sync_final_model(self, is_last_joiner):
    # Agree upon the process that will be the authoritative model copy.
    # The current rank is a candidate for being the authoritative copy if
    # is_last_joiner=True. We break ties via picking the larger rank.
    self._authoritative_rank = self._find_common_rank(
        self._distributed_rank, is_last_joiner
    )
    self._sync_params_and_buffers(authoritative_rank=self._authoritative_rank)

5.3.3 Join

最後,讓我們看看這些基礎類是如何適應Join類本身的。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我們在前面的例子中看到的,建構函式接收一個參與訓練迴圈的Joinable列表 。這些應該是在每次迭代中執行集體通訊的類。

enablebool型別,如果您知道不會有不均勻的輸入,則可以設定為 False,在這種情況下,上下文管理器變得類似於contextlib.nullcontext(). 這也可能會在參與Joinable列表之中禁用join-related計算。

throw_on_early_terminationbool型別,其可以設定為True,以便讓每個等級在檢測到不均勻輸入時引發異常。這對於不符合上下文管理器要求的情況很有用,這通常是當來自不同類的集體通訊可以任意交錯(interleaved)時,例如DistributedDataParallel與具有SyncBatchNorm層的模型一起使用時 。在這種情況下,應將此引數設定為 True以便應用程式邏輯可以捕獲異常並確定如何繼續。

  • 核心邏輯出現在該__exit__()方法中,該方法在存在未加入的 rank 時會進行迴圈呼叫每個 Joinable的主鉤子,然後一旦所有rank加入,就呼叫它們的 post 鉤子。主鉤子和後鉤子都按照Joinables 傳入的順序進行迭代。
  • 上下文管理器需要來自未加入程式的心跳。因此,每個Joinable類都應該在每次迭代的集體通訊之前呼叫Join.notify_join_context() 。上下文管理器將確保只有第一個傳入的Joinable實際傳送心跳。

5.4 例子

我們通過一個例子來具體看看。下面程式碼之中,每個rank會列印(1)在Join之前看到的所有rank的輸入數量,以及(2)所有rank的輸入總數。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    # 確定最後join的rank,由於後加入的rank可能不止一個,所以選擇rank最大的rank來同步  
    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由於rank 0看到5個輸入,rank 1看到6個,因此產生輸出:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

需要強調的一些要點:

  • Counter例項在每次迭代中執行一個all reduce操作,因此:
    • 對於已經Join的rank,其 main hook 也執行單個all reduce來對整體通訊進行矇騙操作( shadow it),注意這個 all-reduce是呼叫一個為0的tensor,所以對整體結果不影響。
    • 其他未 Join 的 rank 會以為這依然是一個正確的滿員的集合操作。
    • 這樣就處理了不均勻輸入。
  • Counter類在其 __call__()方法的開頭呼叫 Join.notify_join_context() ,因為這是每次集合操作(all-reduce)的地方,需要在這裡通知上下文管理器,本示例還沒有Join(已經結束的rank不會呼叫到這裡)。
  • 'is_last_joiner'引數用於確定post-hooks中的廣播源。
  • 我們將 sync_max_count 關鍵字引數傳遞給上下文管理器,上下文管理器會將其轉發給'Counter'的join hook。
  • post-hooks之中,會對 self.counter.max_count 進行廣播。

0xFF 參考

pytorch分散式系列3——分散式訓練時,torch.utils.data.distributed.DistributedSampler做了什麼?

pytorch分散式系列1——搞清torch.distributed.launch相關的環境變數

pytorch分散式系列2——DistributedDataParallel是如何做同步的?

pytorch(分散式)資料並行個人實踐總結——DataParallel/DistributedDataParallel

Pytorch的nn.DataParallel

https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/20

https://pytorch.org/docs/stable/distributed.html

PyTorch 原始碼解讀之分散式訓練了解一下?

實操教程|PyTorch AutoGrad C++層實現

PYTORCH 自動微分(一)

PyTorch如何加速資料並行訓練?分散式祕籍大揭祕

pytorch分散式訓練(二init_process_group)

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

https://pytorch.org/docs/master/notes/ddp.html

https://pytorch.org/tutorials/intermediate/dist_tuto.html

PyTorch 原始碼解讀之 DP & DDP:模型並行和分散式訓練解析

Pytorch模型中的parameter與buffer

【PyTorch開發者日 2020】PyTorch分散式資料並行(DDP)

[中文字幕] 深入理解 PyTorch 中的 Hook 機制

[中文字幕] 深入解讀 Pytorch AutoGrad

DISTRIBUTED TRAINING WITH UNEVEN INPUTS USING THE JOIN CONTEXT MANAGER

談談torch1.10中的ZeroRedundancyOptimizer和Join