[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組

羅西的思考發表於2021-11-20

[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組

0x00 摘要

本文是 PyTorch 分散式系列的第七篇, 介紹 DistributedDataParallel 所依賴的程式組概念。

本系列其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀

[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎

[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

[原始碼解析] PyTorch 分散式(1)------歷史和概述

[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)

[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)

[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念

[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用

[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store

0x01 回顧

1.1 基礎概念

關於分散式通訊,PyTorch 提供的幾個概念是:程式組,後端,初始化,Store。

  • 程式組 :DDP是真正的分散式訓練,可以使用多臺機器來組成一次並行運算的任務。為了能夠讓 DDP 的各個worker之間通訊,PyTorch 設定了程式組這個概念。
  • 後端 :後端這個概念是一個邏輯上的概念。本質上後端是一種IPC通訊機制。對於使用者來說,就是採用那種方式來進行集合通訊,從程式碼上看,就是走什麼流程(一系列流程),以及後端使用 ProcessGroupMPI 還是 ProcessGroupGloo .....。
  • 初始化 : 雖然有了後端和程式組的概念,但是如何讓 worker 在建立程式組之前發現彼此? 這就需要一種初始化方法來告訴大家傳遞一個資訊:如何聯絡到其它機器上的程式?
  • Store : 可以認為是分散式鍵值儲存,這個儲存在組中的程式之間共享資訊以及初始化分散式包 (通過顯式建立儲存來作為init_method的替代)。

1.2 初始化程式組

在呼叫任何 DDP 其他方法之前,需要使用torch.distributed.init_process_group()進行初始化程式組。

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

該方法會初始化預設分散式程式組和分散式包。此方法會阻塞,直到所有程式都加入,函式定義如下:

init_process_group ( backend , 
                       init_method = None , 
                       timeout = default_pg_timeout , 
                       world_size =- 1 , 
                       rank =- 1 , 
                       store = None , 
                       group_name = '' , 
                       pg_options = None )

初始化程式組有兩種主要方法:

  1. 明確指定 store,rank 和 world_size。
  2. 指定 init_method(一個 URL 字串),它指示在哪裡/如何發現對等點。

如果兩者都沒有指定,init_method則假定為“env://”。

因此大家可以看到,store 和 init_method 是互斥的

引數具體如下:

  • 後端 – 要使用的後端。有效值包括mpigloo,和nccl。該欄位應作為小寫字串(例如"gloo")給出,也可以通過Backend屬性(例如Backend.GLOO)訪問 。如果在nccl後端每臺機器上使用多個程式,則每個程式必須對其使用的每個 GPU 具有獨佔訪問許可權,因為在程式之間共享 GPU 可能會導致死鎖。
  • init_method – 指定如何初始化程式組的 URL。如果未指定init_methodstore指定,則預設為“env://” 。與 store互斥。
  • world_size – 參與作業的程式數。如果store指定,則 world_size 為必需。
  • rank – 當前程式的等級(它應該是一個介於 0 和world_size-1之間的數字)。如果store指定,則 rank 為必需。
  • store – 所有 worker 都可以訪問的鍵/值儲存,用於交換連線/地址資訊。與init_method 互斥。
  • timeout – 針對程式組執行的操作超時。預設值等於 30 分鐘。這適用於gloo後端。對於nccl,這僅在環境變數NCCL_BLOCKING_WAITNCCL_ASYNC_ERROR_HANDLING設定為 1 時 適用。
  • group_name – 組名。
  • pg_options ( Process Group Options , optional ) – 程式組選項,指定在構建特定程式組期間需要傳入哪些附加選項。

0x02 概念與設計

2.1 功能

預設情況下,集合通訊在預設組(也稱為world)上執行,並要求所有程式都進入分散式函式呼叫。但是,一些工作可以從更細粒度的通訊中受益。這就是分散式組發揮作用的地方。new_group() 函式可用於建立一個新分散式組,這個新組是所有程式的任意子集。new_group() 返回一個不透明的組控制程式碼,此控制程式碼可以作為group引數提供給所有集合函式(集合函式是分散式函式,用於在某些程式設計模式中交換資訊)。

2.2 本質

拋開概念,從程式碼看其本質。程式組就是給每一個訓練的 process 建立一個通訊thread。主執行緒(計算執行緒)在前臺進行訓練,這個通訊 thread 在後臺做通訊。我們以 ProcessGroupMPI 為例,是在通訊執行緒之中另外新增了一個 queue,做buffer 和 非同步處理。這樣,程式組中所有程式都可以組成一個集體在後臺進行集合通訊操作。

比如下面,左側worker之中就有兩個執行緒,計算執行緒負責計算梯度,然後要求通訊執行緒與其它worker進行交換梯度。

+---------------------------------------------------------------+        +--------------+
| Worker Process                                                |        | Other Worker |
|                                                               |        |              |
|  +----------------------+       +-----------------------+     |        | +----------+ |
|  | Computation thread   |       | Communication thread  |     |        | |   Comm   | |
|  |                      |       |                       |     |        | |  thread  | |
|  |                      |       |                       |     |        | |          | |
|  |     Main Thread      |       |    workerThread_      |     |        | |          | |
|  |                      |       |                       |     |        | |          | |
|  |                      |       |                       |     |        | |          | |
|  | Gradient computation |       |                       |     |        | |          | |
|  |          +           |       |                       |     |        | |          | |
|  |          |           |       |                       |   + |    +   | |          | |
|  |          |           |       |                       |  /| |    |\  | |          | |
|  |          v           | /|_|\ |                       | / +-+----+ \ | |          | |
|  |    Does All+Reduce   |/ grad\|   Does communication  |/  Gradient  \| |          | |
|  |                      |\  _  /|                       |\            /| |          | |
|  |                      | \| |/ |                       | \ +-+----+ / | |          | |
|  |                      |       |                       |  \| |    |/  | |          | |
|  |                      |       |                       |   + |    +   | |          | |
|  |                      |       |                       |     |        | |          | |
|  |                      |       |                       |     |        | |          | |
|  +----------------------+       +-----------------------+     |        | +----------+ |
|                                                               |        |              |
+---------------------------------------------------------------+        +--------------+

0x03 使用

既然知道了程式組的本質,我們接下來看看如何使用程式組。

首先,在 _ddp_init_helper 之中會生成 dist.Reducer,程式組會作為 Reducer 的引數之一傳入。

def _ddp_init_helper(self, parameters, expect_sparse_gradient, param_to_name_mapping):
    """
    Initialization helper function that does the following:
    (1) bucketing the parameters for reductions
    (2) resetting the bucketing states
    (3) registering the grad hooks
    (4) Logging constructin-time DDP logging data
    (5) passing a handle of DDP to SyncBatchNorm Layer
    """
    self.num_iterations = 0
    # The bucket size limit is specified in the constructor.
    # Additionally, we allow for a single small bucket for parameters
    # that are defined first, such that their gradients don't spill into
    # a much larger bucket, adding unnecessary latency after gradient
    # computation finishes. Experiments showed 1MB is a reasonable value.
    bucket_indices = dist._compute_bucket_assignment_by_size(
        parameters[0],
        [dist._DEFAULT_FIRST_BUCKET_BYTES, self.bucket_bytes_cap],
        expect_sparse_gradient[0],
    )

    # Note: reverse list of buckets because we want to approximate the
    # order in which their gradients are produced, and assume they
    # are used in the forward pass in the order they are defined.
    self.reducer = dist.Reducer(
        parameters,
        list(reversed(bucket_indices)),
        self.process_group, # 這裡使用了
        expect_sparse_gradient,
        self.bucket_bytes_cap,
        self.find_unused_parameters,
        self.gradient_as_bucket_view,
        param_to_name_mapping,
    )

其次,在 Reducer 構建函式之中,會把程式組配置給 Reducer 的成員變數 process_group_ 之上。

Reducer::Reducer(
    std::vector<std::vector<at::Tensor>> replicas,
    std::vector<std::vector<size_t>> bucket_indices,
    c10::intrusive_ptr<c10d::ProcessGroup> process_group, 
    std::vector<std::vector<bool>> expect_sparse_gradients,
    int64_t bucket_bytes_cap,
    bool find_unused_parameters,
    bool gradient_as_bucket_view,
    std::unordered_map<size_t, std::string> paramNames)
    : replicas_(std::move(replicas)),
      process_group_(std::move(process_group)), // 在這裡

最後,當需要對梯度做 all-reduce 時候,則會呼叫 process_group_->allreduce(tensors) 進行處理。

現在,我們就知道如何使用程式組了。

void Reducer::all_reduce_bucket(Bucket& bucket) {
  std::vector<at::Tensor> tensors;
  tensors.reserve(bucket.replicas.size());
  for (const auto& replica : bucket.replicas) {
    tensors.push_back(replica.contents);
  }

  if (comm_hook_ == nullptr) {
    bucket.work = process_group_->allreduce(tensors); // 這裡會進行呼叫
  } else {
    GradBucket grad_bucket(
        next_bucket_,
        tensors[0],
        // Since currently we do not support single-process multiple-device
        // mode, we can assume only one replica in the bucket.
        bucket.replicas[0].offsets,
        bucket.replicas[0].lengths,
        bucket.replicas[0].sizes_vec);
    bucket.future_work = comm_hook_->runHook(grad_bucket);
  }
}

0x04 構建

4.1 Python 世界

4.1.1 rendezvous

從 init_process_group 原始碼之中看,幾種構建實現在細節上有所不同,我們只是看gloo和mpi。

  • gloo利用 rendezvous 設定了master地址。

  • MPI不需要 rendezvous,而是利用mpirun啟動。

兩種方法都生成了一個 ProcessGroup 賦值給 default_pg,然後用 default_pg 設定 GroupMember.WORLD。

def _update_default_pg(pg):
    GroupMember.WORLD = group.WORLD = pg

具體 init_process_group 程式碼如下:

def init_process_group(backend,
                       init_method=None,
                       timeout=default_pg_timeout,
                       world_size=-1,
                       rank=-1,
                       store=None,
                       group_name='',
                       pg_options=None):
    """
    Initializes the default distributed process group, and this will also
    initialize the distributed package.

    There are 2 main ways to initialize a process group:
        1. Specify ``store``, ``rank``, and ``world_size`` explicitly.
        2. Specify ``init_method`` (a URL string) which indicates where/how
           to discover peers. Optionally specify ``rank`` and ``world_size``,
           or encode all required parameters in the URL and omit them.

    If neither is specified, ``init_method`` is assumed to be "env://".
    """
    
    global _pg_group_ranks
    global _backend
    global _default_pg_init_method

    if store is not None:
        assert world_size > 0, 'world_size must be positive if using store'
        assert rank >= 0, 'rank must be non-negative if using store'
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend)
    if backend == Backend.MPI:
        default_pg = _new_process_group_helper( # 生成了一個 ProcessGroup 賦值給 default_pg
            -1,
            -1,
            [],
            Backend.MPI,
            None,
            group_name=group_name,
            timeout=timeout)
        _update_default_pg(default_pg) # 用 default_pg 設定 GroupMember.WORLD
    else:
        # backward compatible API
        if store is None:
            rendezvous_iterator = rendezvous( # 先生成一個store
                init_method, rank, world_size, timeout=timeout
            )
            store, rank, world_size = next(rendezvous_iterator)
            store.set_timeout(timeout)

        default_pg = _new_process_group_helper( # 再進行構建 ProcessGroup
            world_size,
            rank,
            [],
            backend,
            store,
            pg_options=pg_options,
            group_name=group_name,
            timeout=timeout)
        _update_default_pg(default_pg) # 用 default_pg 設定 GroupMember.WORLD

    _pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())}  # type: ignore[attr-defined, index]
    _backend = _pg_map[GroupMember.WORLD][0]  # type: ignore[index]
    _default_pg_init_method = init_method

    # barrier at the end to ensure that once we return from this method, all
    # process groups including global variables are updated correctly on all
    # ranks.
    if backend == Backend.MPI:
        # MPI backend doesn't use store.
        barrier()
    else:
        # Use store based barrier here since barrier() used a bunch of
        # default devices and messes up NCCL internal state.
        _store_based_barrier(rank, store, timeout)
        # Set sequence numbers for gloo and nccl process groups.
        if get_backend(default_pg) in [Backend.GLOO, Backend.NCCL]:
            default_pg._set_sequence_number_for_group()

4.1.2 _new_process_group_helper

各種後端都會使用 _new_process_group_helper 進行具體構建,_new_process_group_helper 其實就是呼叫了不同的C++實現,比如 ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL。

def _new_process_group_helper(world_size,
                              rank,
                              group_ranks,
                              backend,
                              store,
                              pg_options=None,
                              group_name=None,
                              timeout=default_pg_timeout):
    """
    Create a new distributed process group.

    This function must be called by ALL processes in the global group, even if
    the calling process is not part of the newly created group. In that case,
    this function returns GroupMember.NON_GROUP_MEMBER.

    This function is called with ``group_ranks == []`` for the default group.
    """
    global _pg_map
    global _group_count
    global _pg_names
    if not group_name:
        group_name = str(_group_count)
        _group_count += 1

    # The list of group ranks is empty if we're creating the default group.
    is_default_group = (len(group_ranks) == 0)

    backend = Backend(backend)
    pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL]
    if backend == Backend.MPI:
        pg = ProcessGroupMPI.create(group_ranks) # 構建了 ProcessGroupMPI
        if not pg:
            return GroupMember.NON_GROUP_MEMBER
        _pg_map[pg] = (Backend.MPI, None)
        _pg_names[pg] = group_name
    else:
        # If this is a subgroup (which means group_ranks is specified),
        # we check if the current process is a member of the new group.
        if not is_default_group:
            global_rank = _get_default_group().rank()
            if global_rank not in group_ranks:
                return GroupMember.NON_GROUP_MEMBER

        # Use the group name as prefix in the default store, such that
        # a single store can be reused by multiple groups.
        prefix_store = PrefixStore(group_name, store)

        if backend == Backend.GLOO:
            pg = ProcessGroupGloo( # 構建了 ProcessGroupGloo
                prefix_store,
                rank,
                world_size,
                timeout=timeout)
            _pg_map[pg] = (Backend.GLOO, store)
            _pg_names[pg] = group_name
        elif backend == Backend.NCCL:
            if pg_options is not None:
                assert isinstance(pg_options, ProcessGroupNCCL.Options), \
                    "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
            else:
                # default pg_options for NCCL
                pg_options = ProcessGroupNCCL.Options()
                pg_options.is_high_priority_stream = False
                pg_options._timeout = timeout

            pg = ProcessGroupNCCL( # 構建了 ProcessGroupNCCL
                prefix_store,
                rank,
                world_size,
                pg_options)
            _pg_map[pg] = (Backend.NCCL, store)
            _pg_names[pg] = group_name
        else:
            pg = getattr(Backend, backend.upper())(
                prefix_store,
                rank,
                world_size,
                timeout)
            _pg_map[pg] = (backend, store)
            _pg_names[pg] = group_name

    return pg

目前流程如下:

                                  +
                                  |
                                  |
                                  v
                          init_process_group
                                  +
                                  |
                                  |
                     +------------+-------------+
                     |                          |
                     |                          |
                     v                          v
                Backend.MPI        Backend.GLOO & Backend.NCCL
                     +                          +
                     |                          |
                     |                          |
                     |                          v
                     |                  store = rendezvous()
                     |                          +
                     |                          |
                     |                          |
                     +------------+-------------+
                                  |
                                  |
                                  v

                       _new_process_group_helper
                                  +
                                  |
                                  |
                                  |
       +------------------------------------------------------+
       |                          |                           |
       |                          |                           |
       v                          v                           v

ProcessGroupMPI         ProcessGroupGloo(store)        ProcessGroupNCCL(store)

4.1.3

我們以 ProcessGroupMPI 為例來看看,可以看到ProcessGroupMPI的基類是ProcessGroup。

class ProcessGroupMPI(ProcessGroup):
    def __init__(
        self,
        rank: int,
        size: int,
        pgComm: int,
    ): ...
    @staticmethod
    def create(ranks: List[int]) -> ProcessGroupMPI: ...

ProcessGroup 定義了若干集合通訊函式,但是均未實現,不過從其註釋之中,我們可以看到派生類會有多種過載實現。

class ProcessGroup(__pybind11_builtins.pybind11_object):
    # no doc
    def allgather(self, *args, **kwargs): # real signature unknown; restored from __doc__
        """
        allgather(*args, **kwargs)
        Overloaded function.
        
        1. allgather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: List[List[at::Tensor]], input_tensors: List[at::Tensor], opts: torch._C._distributed_c10d.AllgatherOptions = <torch._C._distributed_c10d.AllgatherOptions object at 0x000001A9460233F0>) -> c10d::ProcessGroup::Work
        
        2. allgather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: List[at::Tensor], input_tensor: at::Tensor) -> c10d::ProcessGroup::Work
        """
        pass

    def allgather_coalesced(self, output_lists, *args, **kwargs): # real signature unknown; NOTE: unreliably restored from __doc__ 
        """ allgather_coalesced(self: torch._C._distributed_c10d.ProcessGroup, output_lists: List[List[at::Tensor]], input_list: List[at::Tensor], opts: torch._C._distributed_c10d.AllgatherOptions = <torch._C._distributed_c10d.AllgatherOptions object at 0x000001A946023370>) -> c10d::ProcessGroup::Work """
        pass

    def allreduce(self, *args, **kwargs): # real signature unknown; restored from __doc__
        """
        allreduce(*args, **kwargs)
        Overloaded function.
        
        1. allreduce(self: torch._C._distributed_c10d.ProcessGroup, tensors: List[at::Tensor], opts: torch._C._distributed_c10d.AllreduceOptions = <torch._C._distributed_c10d.AllreduceOptions object at 0x000001A946023570>) -> c10d::ProcessGroup::Work
        
        2. allreduce(self: torch._C._distributed_c10d.ProcessGroup, tensors: List[at::Tensor], op: torch._C._distributed_c10d.ReduceOp = <ReduceOp.SUM: 0>) -> c10d::ProcessGroup::Work
        
        3. allreduce(self: torch._C._distributed_c10d.ProcessGroup, tensor: at::Tensor, op: torch._C._distributed_c10d.ReduceOp = <ReduceOp.SUM: 0>) -> c10d::ProcessGroup::Work
        """
        pass

而無論哪個ProcessGroup的派生類,都指向了C++世界,比如在 torch/csrc/distributed/c10d/init.cpp 之中有如下程式碼:

// Define static create function instead of a constructor, because
// this function may return null. This happens if this process is not
// part of a sub group that is to be created.
processGroupMPI.def_static(
    "create",
    [](std::vector<int> ranks) {
      return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
    },
    py::call_guard<py::gil_scoped_release>());

因此可見,最後呼叫到的是 createProcessGroupMPI,於是我們直接去C++世界看看。

4.2 C++ 世界

4.2.1 ProcessGroupMPI 定義

ProcessGroupMPI 定義位於 torch/lib/c10d/ProcessGroupMPI.cpp。這裡相當於做了一個工作佇列,以及非同步操作。幾個注意點如下:

  • ProcessGroupMPI 類上的所有函式都應在組中的程式之間以相同的順序呼叫。這是我們能夠保證跨程式匹配相同呼叫的唯一方法。
  • ProcessGroupMPI 類提供的所有MPI函式都在工作執行緒上非同步排程。因此,ProcessGroupMPI 依賴於MPI實現,該實現用於提供 MPI_THREAD_SERIALIZED 的最小執行緒支援值。也就是說,程式可以是多執行緒的,多個執行緒可以進行MPI呼叫,但一次只能進行一個:MPI呼叫不是從兩個不同的執行緒同時進行的(所有MPI呼叫都是序列化的)。但是,如果使用 MPI_THREAD_SERIALIZED,ProcessGroupMPI將只支援單個程式組。換句話說,全域性建立的程式組不能超過1個。
  • 如果希望使用多個ProcessGroupMPI,它要求MPI實現的執行緒支援值為MPI\u thread\u multiple,也就是說,多個執行緒可以呼叫MPI,沒有任何限制。
  • 還要注意,ProcessGroupMPI只支援單個張量操作。換句話說,輸入張量向量的大小應始終為1。
  • 如果使用的MPI是CUDA-aware MPI,則可以支援CUDA tensor,並且ProcessGroupMPI將自動檢測此支援。
// ProcessGroupMPI implements MPI bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way that we
// can guarantee to match up the same calls across processes.
//
// All MPI functions provided by this class is asynchronously scheduled on a
// Worker thread. Therefore, ProcessGroupMPI requires the MPI implementation
// that is used to have a minimum thread support value of MPI_THREAD_SERIALIZED.
// That is, The process may be multi-threaded, and multiple threads may make
// MPI calls, but only one at a time: MPI calls are not made concurrently from
// two distinct threads (all MPI calls are serialized). However, with
// MPI_THREAD_SERIALIZED, ProcessGroupMPI will only support a singe process
// group. In other words, no more than 1 process group can be created globally.
//
// If you would like to use multiple ProcessGroupMPI, it requres your MPI
// implemenation to have a thread support value of MPI_THREAD_MULTIPLE, that is,
// multiple threads may call MPI, with no restriction.
//
// Also note that ProcessGroupMPI only supports a single Tensor operation. In
// other words, the size of the input Tensor vector should always be 1.
//
// CUDA tensor can be supported if the MPI used is CUDA-aware MPI, and
// ProcessGroupMPI will automatically detect this support.
class ProcessGroupMPI : public ProcessGroup {
 public:
  class WorkMPI : public ProcessGroup::Work {
   public:
    explicit WorkMPI(
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt)
        : ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
          outputTensors_(std::move(outputTensors)),
          future_(c10::make_intrusive<at::ivalue::Future>(
              c10::ListType::create(c10::TensorType::get()))) {}

    std::vector<at::Tensor> result() override;
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

   protected:
    friend class ProcessGroupMPI;

   private:
    void finishWorkMPI();
    void finishWorkMPIError(std::exception_ptr eptr);
    std::vector<at::Tensor> outputTensors_;
    c10::intrusive_ptr<at::ivalue::Future> future_;
  };

  class AsyncWork : public ProcessGroup::Work {
   public:
    AsyncWork(
        MPI_Request request,
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt);

    virtual ~AsyncWork();
    bool isCompleted() override;
    bool isSuccess() const override;
    int sourceRank() const override;
    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
    void abort() override;
    std::vector<at::Tensor> result() override;

   protected:
    void populateException();

   private:
    const std::vector<at::Tensor> outputTensors_;
    MPI_Request request_;
    MPI_Status status_;
  };

  // Constructor will spawn up the worker thread loop
  explicit ProcessGroupMPI(int rank, int size, MPI_Comm pgComm);
  virtual ~ProcessGroupMPI();

 protected:
  using WorkType =
      std::tuple<std::unique_ptr<WorkEntry>, c10::intrusive_ptr<WorkMPI>>;
  // Worker thread loop
  void runLoop();
  // Helper function that is called by the destructor
  void destroy();

  c10::intrusive_ptr<ProcessGroup::Work> enqueue(
      std::unique_ptr<WorkEntry> entry,
      const char* profilingTitle = nullptr,
      const c10::optional<std::vector<at::Tensor>>& inputTensors = c10::nullopt);

  bool stop_;
  std::mutex pgMutex_;
  std::thread workerThread_;
  std::deque<WorkType> queue_;
  std::condition_variable queueProduceCV_;
  std::condition_variable queueConsumeCV_;

  // Global states
  static void initMPIOnce();
  static void mpiExit();
  static std::once_flag onceFlagInitMPI;
  static std::mutex pgGlobalMutex_;
  static int mpiThreadSupport_;

  MPI_Comm pgComm_;
};

4.2.2 初始化

createProcessGroupMPI 方法完成了程式組的初始化,其主要是呼叫了 MPI 程式設計常見套路,比如initMPIOnce,MPI_Comm_create,MPI_Barrier之類。

c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
    std::vector<int> ranks) {
  // Once initialization
  initMPIOnce();
  MPI_Comm groupComm = MPI_COMM_WORLD;
  int rank = -1;
  int size = -1;

  {
    std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);

    // If no ranks are specified, assume we're creating the root group
    if (!ranks.empty()) {
      MPI_Group worldGroup;
      MPI_Group ranksGroup;
      MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
      MPI_CHECK(
          MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
      constexpr int kMaxNumRetries = 3;
      bool groupComm_updated = false;
      MPI_Barrier(MPI_COMM_WORLD);
      for (int i = 0; i < kMaxNumRetries; ++i) {
        if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)) {
          groupComm_updated = true;
          break;
        }
      }
      MPI_CHECK(groupComm_updated);
      MPI_CHECK(MPI_Group_free(&worldGroup));
      MPI_CHECK(MPI_Group_free(&ranksGroup));
    }

    // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
    if (groupComm != MPI_COMM_NULL) {
      MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
      MPI_CHECK(MPI_Comm_size(groupComm, &size));
    }
  }

  // If this process is not part of the group, we don't construct a
  // process group instance. This is in line with the semantics of the
  // other process group types.
  if (groupComm == MPI_COMM_NULL) {
    return c10::intrusive_ptr<ProcessGroupMPI>(); // 生成
  }

  return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm); // 生成
}
4.2.2.1 initMPIOnce

呼叫了 MPI_Init_thread API 初始化了 MPI 執行環境。

void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment
  std::call_once(onceFlagInitMPI, []() {
    MPI_CHECK(MPI_Init_thread(
        nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_));
    if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
      throw std::runtime_error(
          "Used MPI implementation doesn't have the "
          "minimum level of threading support: "
          "MPI_THREAD_SERIALIZED. This is required by "
          "c10d package");
    }
    if (std::atexit(ProcessGroupMPI::mpiExit)) {
      throw std::runtime_error("Fail to register the MPI exit handler");
    }
  });

4.2.2.2 ProcessGroupMPI

ProcessGroupMPI 構建方法之中 生成了 workerThread_,其執行 runLoop。

ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
    : ProcessGroup(rank, size), stop_(false), pgComm_(pgComm) {
  if (pgComm_ == MPI_COMM_NULL) {
    throw std::runtime_error("pgComm_ must not be MPI_COMM_NULL");
  }

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);
}

4.2.3 執行

4.2.3.1 執行封裝

這裡有兩個封裝,WorkEntry 封裝計算執行,WorkMPI封裝計算執行結果(因為計算是非同步的)。具體如下:

WorkEntry 是執行方法的封裝,或者說每次需要執行的集合通訊操作,都要封裝在這裡。

struct WorkEntry {
  explicit WorkEntry(
      std::vector<at::Tensor>* srcPtr,
      std::vector<at::Tensor>* dstPtr,
      std::function<void(std::unique_ptr<WorkEntry>&)> run)
      : dst(dstPtr ? *dstPtr : std::vector<at::Tensor>()),
        run(std::move(run)) {
    if (srcPtr) {
      src = *srcPtr;
    }
  }

  // Not copyable
  WorkEntry(const WorkEntry&) = delete;
  // Not copy assignable
  WorkEntry& operator=(const WorkEntry&) = delete;

  // For input and output tensors (in-place), we will always use src
  std::vector<at::Tensor> src;

  // Copy of user provided outputs.
  const std::vector<at::Tensor> dst;

  // src rank returned, for recv only
  int* srcRank = nullptr;
  std::function<void(std::unique_ptr<WorkEntry>&)> run;
};

WorkMPI 是執行結果的封裝。

class WorkMPI : public ProcessGroup::Work {
 public:
  explicit WorkMPI(
      std::vector<at::Tensor> outputTensors,
      const char* profilingTitle = nullptr,
      const c10::optional<std::vector<at::Tensor>>& inputTensors =
          c10::nullopt)
      : ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
        outputTensors_(std::move(outputTensors)),
        future_(c10::make_intrusive<at::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()))) {}

  std::vector<at::Tensor> result() override;
  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

 protected:
  friend class ProcessGroupMPI;

 private:
  void finishWorkMPI();
  void finishWorkMPIError(std::exception_ptr eptr);

  std::vector<at::Tensor> outputTensors_;
  c10::intrusive_ptr<at::ivalue::Future> future_;
};

在往工作queue插入時候,實際插入的是二元組(WorkEntry, WorkMPI),我們後續會講解如何使用。

4.2.3.2 allreduce

以allreduce 為例,看看如何處理。就是把 MPI_Allreduce 封裝到 WorkEntry 之中,然後插入到 queue。

後續 runLoop 之中就是取出 WorkEntry,然後執行 MPI_Allreduce。

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allreduce( // 封裝了此函式
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };
  auto entry = std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_reduce",
      c10::optional<std::vector<at::Tensor>>(tensors));
}
4.2.3.3 enqueue

enqueue 方法是往queue插入二元組(WorkEntry, WorkMPI),裡面的 entry->dst 就是 計算結果存放到 WorkMPI 之中。

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry,
    const char* profilingTitle,
    const c10::optional<std::vector<at::Tensor>>& inputTensors) {
  // 生成 WorkMPI,把 entry->dst 就是 計算結果存放到 WorkMPI 之中
  auto work = c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
  std::unique_lock<std::mutex> lock(pgMutex_);
  // 插入二元組
  queue_.push_back(std::make_tuple(std::move(entry), work));
  lock.unlock();
  queueProduceCV_.notify_one();
  return work;
}
4.2.3.4 runLoop

主迴圈runLoop方法就是不停的取出entry來處理。

void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());

    queue_.pop_front();

    auto& workEntry = std::get<0>(workTuple); // 進行計算
    auto& work = std::get<1>(workTuple); // 拿到WorkMPI

    lock.unlock();
    queueConsumeCV_.notify_one();

    try {
      workEntry->run(workEntry);
      work->finishWorkMPI(); // 會等待WorkMPI的計算結果
    } catch (...) {
      work->finishWorkMPIError(std::current_exception());
    }

    lock.lock();
  }
}

finishWorkMPI 會標示並且進行通知。

void ProcessGroupMPI::WorkMPI::finishWorkMPI() {
  future_->markCompleted(at::IValue(outputTensors_));
  finish();
}

基類程式碼如下:

void ProcessGroup::Work::finish(std::exception_ptr exception) {
  std::unique_lock<std::mutex> lock(mutex_);
  completed_ = true;
  exception_ = exception;
  if (recordFunctionEndCallback_) {
    recordFunctionEndCallback_();
    recordFunctionEndCallback_ = nullptr;
  }
  lock.unlock();
  cv_.notify_all();
}

具體如下圖:

                                                                        +
                                                             Worker 1   |   Worker 2
                                                                        |
                                                                        |
                                                                        |
+-----------------+           +--------------------------------------+  |   +------------------------------------+            +---------------+
| Main Thread     |           |  ProcessGroupMPI                     |  |   | ProcessGroupMPI                    |            | Main Thread   |
|                 |           |                                      |  |   |                                    |            |               |
|                 |           |                                      |  |   |                                    |            |               |
|                 |           |                                      |  |   |                                    |            |               |
|                 |           |  +--------------------------------+  |  |   |  +------------------------------+  |            |               |
|                 |           |  |  runLoop        workerThread_  |  |  |   |  | runloop    workerThread_     |  |            |               |
|                 |           |  |                                |  |  |   |  |                              |  |            |               |
|                 |           |  |                                |  |  |   |  |                              |  |            |               |
|  +---------+    |           |  |   +-------------------------+  |  |  |   |  |  +-----------------------+   |  |            |               |
|  |         |    | allreduce |  |   | queue_                  |  |  |  |   |  |  | queue_                |   |  | allreduce  |   +---------+ |
|  | Reducer | +-------------------> |                         |  |  |  |   |  |  |                       | <-------------------+ |         | |
|  |         |    |           |  |   |                         |  |  |  |   |  |  |                       |   |  |            |   | Reducer | |
|  +---------+    |           |  |   |  +-------------------+  |  |  |  |   |  |  |  +-----------------+  |   |  |            |   |         | |
|                 |           |  |   |  |WorkEntry          |  |  |  |  |   |  |  |  | WorkEntry       |  |   |  |            |   +---------+ |
|                 |           |  |   |  |                   |  |  |  |  |   |  |  |  |                 |  |   |  |            |               |
|                 |           |  |   |  |   MPI_Allreduce <-----------------------------> MPI_Allreduce|  |   |  |            |               |
|                 |           |  |   |  |                   |  |  |  |  |   |  |  |  |                 |  |   |  |            |               |
|                 |           |  |   |  +-------------------+  |  |  |  |   |  |  |  +-----------------+  |   |  |            |               |
|                 |           |  |   |                         |  |  |  |   |  |  |                       |   |  |            |               |
|                 |           |  |   |                         |  |  |  |   |  |  |                       |   |  |            |               |
|                 |           |  |   +-------------------------+  |  |  |   |  |  +-----------------------+   |  |            |               |
|                 |           |  |                                |  |  |   |  |                              |  |            |               |
|                 |           |  +--------------------------------+  |  |   |  +------------------------------+  |            |               |
|                 |           |                                      |  |   |                                    |            |               |
|                 |           |                                      |  |   |                                    |            |               |
|                 |           |                                      |  |   |                                    |            |               |
+-----------------+           +--------------------------------------+  |   +------------------------------------+            +---------------+
                                                                        |
                                                                        |
                                                                        +

手機如下:

4.4 封裝

PyTorch 對各種 process group 做了封裝,這樣使用者就可以呼叫 GroupMember.WORLD 來完成各種操作,但是使用者是無感的。

def _get_default_group():
    """
    Getting the default process group created by init_process_group
    """
    if not is_initialized():
        raise RuntimeError("Default process group has not been initialized, "
                           "please make sure to call init_process_group.")
    return GroupMember.WORLD

又比如,在 torch/distributed/distributed_c10d.py 之中如下方法可以看到 all_to_all 和 all_gather 之類的函式,其註釋有很詳細的用法(這裡因為篇幅所限略去),大家有興趣可以自行學習。

def all_to_all(output_tensor_list,
               input_tensor_list,
               group=None,
               async_op=False):
    """
    Each process scatters list of input tensors to all processes in a group and
    return gathered list of tensors in output list.

    Args:
        output_tensor_list (list[Tensor]): List of tensors to be gathered one
            per rank.
        input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.
    """
    if _rank_not_in_group(group):
        return

    opts = AllToAllOptions()
    _check_tensor_list(output_tensor_list, "output_tensor_list")
    _check_tensor_list(input_tensor_list, "input_tensor_list")

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.alltoall(output_tensor_list, input_tensor_list, opts)
    else:
        work = group.alltoall(output_tensor_list, input_tensor_list, opts)

    if async_op:
        return work
    else:
        work.wait()

all_gather 程式碼如下:

def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.

    Complex tensors are supported.

    Args:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group
    """
    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.allgather([tensor_list], [tensor])
    else:
        work = group.allgather([tensor_list], [tensor])

    if async_op:
        return work
    else:
        work.wait()

至此,程式組介紹完畢,下一篇我們分析DDP的論文,敬請期待。

0xFF 參考

pytorch(分散式)資料並行個人實踐總結——DataParallel/DistributedDataParallel

https://www.telesens.co/2019/04/04/distributed-data-parallel-training-using-pytorch-on-aws/

DISTRIBUTED TRAINING WITH UNEVEN INPUTS USING THE JOIN CONTEXT MANAGER

相關文章