[原始碼解析] PyTorch 分散式(17) --- 結合DDP和分散式 RPC 框架

羅西的思考發表於2021-12-16

[原始碼解析] PyTorch 分散式(17) --- 結合DDP和分散式 RPC 框架

0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分散式的基本模組,接下來我們通過幾篇文章來看看如何把這些模組應用到實踐之中,順便把PyTorch分散式邏輯整體梳理一下。本文介紹如何把DDP和RPC framework結合起來。

本文以 COMBINING DISTRIBUTED DATAPARALLEL WITH DISTRIBUTED RPC FRAMEWORK 的翻譯為基礎,加入了自己的理解。

PyTorch分散式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀

[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎

[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

[原始碼解析] PyTorch 分散式(1)------歷史和概述

[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)

[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)

[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念

[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用

[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store

[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組

[原始碼解析] PyTorch 分散式(8) -------- DistributedDataParallel之論文篇

[原始碼解析] PyTorch 分散式(9) ----- DistributedDataParallel 之初始化

[原始碼解析] PyTorch 分散式(10)------DistributedDataParallel 之 Reducer靜態架構

[原始碼解析] PyTorch 分散式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作

[原始碼解析] PyTorch 分散式(12) ----- DistributedDataParallel 之 前向傳播

[原始碼解析] PyTorch 分散式(13) ----- DistributedDataParallel 之 反向傳播

[原始碼解析] PyTorch 分散式 Autograd (1) ---- 設計

[原始碼解析] PyTorch 分散式 Autograd (2) ---- RPC基礎

[原始碼解析] PyTorch 分散式 Autograd (3) ---- 上下文相關

[原始碼解析] PyTorch 分散式 Autograd (4) ---- 如何切入引擎

[原始碼解析] PyTorch 分散式 Autograd (5) ---- 引擎(上)

[原始碼解析] PyTorch 分散式 Autograd (6) ---- 引擎(下)

[原始碼解析] PyTorch分散式優化器(1)----基石篇

[原始碼解析] PyTorch分散式優化器(2)----資料並行優化器

[原始碼解析] PyTorch分散式優化器(3)---- 模型並行

[原始碼解析] PyTorch 分散式(14) --使用 Distributed Autograd 和 Distributed Optimizer

[原始碼解析] PyTorch 分散式(15) --- 使用分散式 RPC 框架實現引數伺服器

[原始碼解析] PyTorch 分散式(16) --- 使用非同步執行實現批處理 RPC

注:本文沒有完全按照原文順序進行翻譯,而是按照自己理解的思路重新組織了文章。

0x00 綜述

本教程使用一個簡單的示例來演示如何將 DistributedDataParallel (DDP) 與分散式 RPC 框架 相結合,將分散式資料並行性與分散式模型並行性相結合,以訓練一個簡單的模型。該示例的原始碼可以在這裡找到。

前面的教程 入門分散式資料並行入門分散式RPC框架 分別描述瞭如何執行分散式資料並行和分散式模型平行訓練。儘管如此,您可能希望在多種訓練正規化中結合這兩種技術。例如:

  1. 如果我們有一個包含稀疏部分(大型嵌入表)和密集部分(FC 層)的模型,我們可能希望將嵌入表放在引數伺服器上,並使用DistributedDataParallel在多個trainer之間複製 FC 層。分散式RPC框架 就可被用於在引數伺服器上執行嵌入查詢。
  2. PipeDream論文中所述啟用混合並行性。我們可以使用分散式 RPC 框架 將模型的各個階段跨多個worker 進行流水線化,並使用DistributedDataParallel 對每個階段進行資料並行(如果需要)。

在本教程中,我們將介紹上述案例 1。我們的設定中共有 4 個 worker,如下所示:

  • 1 個Master,負責在引數伺服器上建立嵌入表(nn.EmbeddingBag)。master 還負責驅動兩個trainer上的訓練迴圈。
  • 1 個Parameter Server,它將嵌入表儲存在記憶體中,並響應來自 Master 和 Trainer 的 RPC 請求。
  • 2 個trainer,它儲存一個 FC 層 (nn.Linear),其使用DistributedDataParallel 進行資料並行。trainer還負責執行前向傳播、後向傳播和優化器步驟。

整個訓練過程執行如下:

  1. Master 建立一個RemoteModule ,在引數伺服器上儲存一個嵌入表。
  2. Master 在trainer上啟動訓練迴圈,並將遠端模組(remote module)傳播給trainer。
  3. Trainer 建立一個HybridModel,其首先使用 master 提供的遠端模組執行嵌入查詢(embedding lookup),然後執行封裝在 DDP 中的 FC 層。
  4. Trainer 執行模型的前向傳播,並使用Distributed Autograd 對損失執行後向傳播。
  5. 作為反向傳播的一部分,首先計算 FC 層的梯度,並通過 DDP 中的 allreduce 同步到所有trainer。
  6. 接下來,分散式 Autograd 將梯度傳播到引數伺服器,在那裡更新嵌入表的梯度。
  7. 最後,分散式優化器被用於更新所有引數。

注意:如果您將 DDP 和 RPC 結合使用,則應始終使用Distributed Autograd進行反向傳播。

0x01 啟動

我們看看系統如何啟動。首先,在進行訓練之前,需要設定所有worker。我們建立了 4 個程式,其中 rank 0 和 rank 1 是我們的trainer,rank 2是master,rank 3是引數伺服器。

初始化邏輯如下:

  • 我們使用 TCP init_method 在所有 4 個 worker 上初始化 RPC 框架。
  • 對於 Master,程式碼做了如下操作:
    • 完成 RPC 初始化後,master 建立一個遠端模組RemoteModule,該模組指向一個在引數伺服器上儲存的EmbeddingBag層。
    • 然後 master 遍歷每個trainer,並通過使用rpc_async呼叫_run_trainer 在每個trainer之上啟動訓練迴圈。
    • 最後,master 在退出之前等待所有訓練完成。
  • Trainers做了如下操作:
    • Trainers 首先使用 init_process_group為DDP初始化一個world_size = 2(對於兩個trainer)的ProcessGroup
    • 接下來,Trainers 使用 TCP init_method 初始化 RPC 框架。注意RPC初始化和ProcessGroup初始化的埠是不同的。這是為了避免兩個框架的初始化之間的埠衝突。
    • 初始化完成後,trainer只需等待來自 master的_run_trainer RPC。
  • 引數伺服器只是初始化 RPC 框架並等待來自trainer和master的 RPC。

具體程式碼如下:

def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
    if rank == 2: # Master程式碼
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        remote_emb_module = RemoteModule( # 指向一個在引數伺服器上儲存的EmbeddingBag層
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"},
        )

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async( # 啟動 trainer迴圈
                trainer_name, _run_trainer, args=(remote_emb_module, trainer_rank)
            )
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize process group for Distributed DataParallel on trainers.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
        )

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        # 只需等待來自 master的 _run_trainer RPC
        # Trainer just waits for RPCs from master.
    else:
        rpc.init_rpc( # 引數伺服器
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # parameter server do nothing
        pass # 啥也不幹,只是等待來自trainer和master的 RPC

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    # 2 trainers, 1 parameter server, 1 master.
    world_size = 4
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)

目前邏輯如下,我們後續會繼續擴充:

                               torch.multiprocessing.spawn
                                          +
                                          |
                                          |
              +----------------------------------------------------------------+----------------------------------+
              |                           |                                    |                                  |
              |                           |                                    |                                  |
              v                           v                                    v                                  v
+-------------+-------------+  +----------+---------------+ +------------------+------------------+ +-------------+--------+
|trainer 0         rank = 0 |  |trainer 1     rank = 1    | | master                     rank = 2 | |ps          rank = 3  |
|                           |  |                          | |                                     | |                      |
|                           |  |                          | |   rpc.init_rpc                      | |     rpc.init_rpc     |
|                           |  |                          | |                                     | |                      |
|   dist.init_process_group |  |  dist.init_process_group | |   remote_emb_module =  RemoteModule | |                      |
|                           |  |                          | |                                     | |                      |
|                           |  |                          | |                                     | |                      |
|   rpc.init_rpc            |  |  rpc.init_rpc            | |   fut = rpc.rpc_async(_run_trainer) | |                      |
|                           |  |                          | |                                     | |                      |
|                           |  |                          | |                                     | |                      |
+---------------------------+  +--------------------------+ +-------------------------------------+ +----------------------+

手機如下:

0x03 支撐系統

支撐系統主要指的就是 _RemoteModule,其作用是在異地建立一個模型,具體程式碼在:torch/distributed/nn/api/remote_module.py。

3.1 功能

RemoteModule例項只能在RPC初始化之後建立,它可以在指定的遠端節點上建立使用者指定的模組,其行為類似於常規的nn.Module方法,但不同之處是 RemoteModule 在遠端節點上執行forward方法。RemoteModule 負責autograd recording,以確保向後傳播可以將梯度傳播回相應的遠端模組。

RemoteModule 可以使用RPC framework <https://pytorch.org/docs/stable/rpc.html> 在處理器之間共享,且不會產生複製實際模組的任何開銷,這相當於使用一個~torch.distributed.rpc.RRef指向遠端模組。

3.2 使用

3.2.1 混合模型

要建立混合模型,通常應該在遠端模組之外建立本地模組,而不是作為任何遠端模組的子模組。如果遠端模組放置在cuda裝置上,那麼任何輸入CPU張量將自動移動到同一cuda裝置之上。混合模型例子如下:

            >>> class HybridModel(nn.Module):
            >>>     def __init__(self):
            >>>         nn.Module.__init__(self)
            >>>         self.remote_embedding = RemoteModule(...) # 在遠端建立嵌入層
            >>>         self.local_linear = nn.Linear(...)

3.2.2 使用

使用例子如下,需要在兩個不同程式上執行如下程式碼,例子之中,RemoteModule 建立時候,傳入了一個"worker1/cpu"引數,意思是在 worker1 的 cpu 裝置上執行這個RemoteModule。具體格式是: <workername> / <device>,其中 <device> 是torch.device型別。

    Example::
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1/cpu", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

3.3 定義

_RemoteModule定義如下,具體初始化邏輯是:

  • (1). 準備引數。
  • (2). 設定執行的遠端worker和遠端裝置。
  • (3). 如果設定了_module_interface_cls
    • (3.1) 使用 _module_interface_cls 來在遠端構建模組。_
    • (3.2) 在本地構建函式代理生成器。
    • (3.3) 等待建立完成。
    • (3.4) 在本地構建控制程式碼。
  • (4) 沒有設定_module_interface_cls。
    • (4.1) 在本地構建函式代理生成器。
    • (4.2) 在遠端建立模組。
  • (5). 在本地建立遠端函式代理。
class _RemoteModule(nn.Module):
    def __init__(
        self,
        remote_device: str,
        module_cls: nn.Module,
        args: Tuple = None,
        kwargs: Dict[str, Any] = None,
        _module_interface_cls: Any = None,
    ):
        """
        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
                E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                In addition, the device field can be optional and the default value is "cpu".

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_cls``, it has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.
        """
        super().__init__()

        # NOTE: if a new attribute is added to this class, also need to add it
        # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling.

        # Default arguments preperation.
        # 1. 準備引數
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        # 2. 設定執行的遠端worker和遠端裝置
        self.on, self.device = _parse_remote_device(remote_device)
        agent = rpc._get_current_rpc_agent()
        # If the device map of the remote worker is set,
        # then enable moving any input CPU tensors to the same cuda device.
        self.is_device_map_set = bool(
            agent._get_device_map(agent.get_worker_info(self.on))
        )
        # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``:
        # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set,
        # then any CPU tensors can still be moved to a cuda device to run forward,
        # but the output must be moved back to CPU before being sent over the wire.
        enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda"

        # 3. 如果設定了_module_interface_cls
        if _module_interface_cls is not None:
            # Users reply on this field to know if this generated RemoteModule is TorchScript-able.
            self.is_scriptable = True

            # 3.1 使用 _module_interface_cls 來在遠端構建模組
            # Instantiate template on remote side.
            fut = rpc.rpc_async(
                self.on,
                _instantiate_template,
                (_module_interface_cls, enable_moving_cpu_tensors_to_cuda),
            )

            # 3.2 在本地構建函式代理生成器
            # Instantiate template on local side.
            generated_module = (
                instantiator.instantiate_scriptable_remote_module_template(
                    _module_interface_cls, enable_moving_cpu_tensors_to_cuda
                )
            )
            self.generated_methods = generated_module._generated_methods

            # 3.3 等待建立完成
            # Create the module on the remote side.
            fut.wait()  # Ensure remote_module_cls is available on remote side.

            # 3.4 在本地構建控制程式碼
            self.module_rref = rpc.rpc_sync(
                self.on,
                _create_module_with_interface,
                (module_cls, args, kwargs, self.device, _module_interface_cls),
            )
        else: # 4 沒有設定_module_interface_cls
            self.is_scriptable = False
            # 4.1 在本地構建函式代理生成器
            self.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
            # 4.2 在遠端建立模組
            # Create the module on the remote side.
            self.module_rref = rpc.remote(
                self.on,
                _create_module,
                (module_cls, args, kwargs, self.device),
            )

        # Install generated methods.
        # 5. 在本地建立遠端函式代理
        for method in self.generated_methods:
            method_name = method.__name__
            method = torch.jit.export(method)
            setattr(self, method_name, types.MethodType(method, self))

3.4 主要函式

其主要函式如下:

  • rpc.rpc_sync 返回指向遠端模組引數的~torch.distributed.rpc.RRef列表。通常可以與~torch.distributed.optim.DistributedOptimizer結合使用。

  • get_module_rref 返回一個指向遠端模組的~torch.distributed.rpc.RRef(RRef[nn.Module])類。

def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
    """
    Returns a list of :class:`~torch.distributed.rpc.RRef` pointing to the
    remote module's parameters. This can typically be used in conjuction
    with :class:`~torch.distributed.optim.DistributedOptimizer`.

    Args:
        recurse (bool): if True, then returns parameters of the remote
            module and all submodules of the remote module. Otherwise,
            returns only parameters that are direct members of the
            remote module.

    Returns:
        A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``)
        to remote module's parameters.
    """
    return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))

def get_module_rref(self) -> rpc.RRef[nn.Module]:
    """
    Returns an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``)
    pointing to the remote module.
    """
    return self.module_rref

於是邏輯圖轉換如下,在上圖基礎之上多了一個remote_emb_module,其在ps之上建立了一個RemoteModule

                                torch.multiprocessing.spawn
                                           +
                                           |
                                           |
               +----------------------------------------------------------------+----------------------------------+
               |                           |                                    |                                  |
               |                           |                                    |                                  |
               v                           v                                    v                                  v
+--------------+-------------+ +-----------+--------------+ +-------------------+-----------------+  +-------------+--------+
|trainer 0          rank = 0 | |trainer 1     rank = 1    | | master                     rank = 2 |  |ps          rank = 3  |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |     rpc.init_rpc                    |  |     rpc.init_rpc     |
|                            | |                          | |                                     |  |                      |
|    dist.init_process_group | |  dist.init_process_group | |   remote_emb_module +----------------------> RemoteModule     |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |                                     |  |                      |
|    rpc.init_rpc            | |  rpc.init_rpc            | |   fut = rpc.rpc_async(_run_trainer) |  |                      |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |                                     |  |                      |
+----------------------------+ +--------------------------+ +-------------------------------------+  +----------------------+

手機如下:

0x04 HybridModel

在討論 Trainer 的細節之前,讓我們先介紹一下 Trainer使用的HybridModel。該模型由稀疏部分和稠密部分組成。

  • 稠密部分是一個nn.Linear,使用DistributedDataParallel在所有trainer中複製,即 在 DDP 內包裝了一個 nn.Linear層。

  • 稀疏部分是一個遠端模組 (remote_emb_module) ,它持有一個在引數伺服器上的nn.EmbeddingBag。即,此遠端模組可以獲取引數伺服器上嵌入表的遠端引用。

該模型的前向方法非常簡單。它使用 RemoteModule 在引數伺服器上執行嵌入查詢forward ,並將其輸出傳播到 FC 層,這裡的 FC 使用了DDP

class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part.
    1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
    2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
    This remote model can get a Remote Reference to the embedding table on the parameter server.
    """

    def __init__(self, remote_emb_module, device):
        super(HybridModel, self).__init__()
        self.remote_emb_module = remote_emb_module
        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        emb_lookup = self.remote_emb_module.forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))

邏輯擴充如下,兩個trainer 之上也建立了remote_emb_module,指向了ps之上的RemoteModule

                                         torch.multiprocessing.spawn
                                                    +
                                                    |
                                                    |
            +-----------------------------------------------------------------------------------+----------------------------------+
            |                                       |                                           |                                  |
            |                                       |                                           |                                  |
            v                                       v                                           v                                  v
+-----------+-------------+ +-----------------------+-------------------+ +---------------------+---------------+    +-------------+--------+
|trainer 0       rank = 0 | | trainer 1                        rank = 1 | | master                     rank = 2 |    |ps          rank = 3  |
|                         | |                                           | |                                     |    |                      |
|                         | |                                           | |   rpc.init_rpc                      |    |     rpc.init_rpc     |
| dist.init_process_group | | dist.init_process_group                   | |                                     |    |                      |
|                         | |                                           | |   remote_emb_module +------------------------> RemoteModule     |
| rpc.init_rpc            | | rpc.init_rpc                              | |                                     |    |         ^     ^      |
|                         | |                                           | |                                     |    |         |     |      |
|                         | |                                           | |   fut = rpc.rpc_async(_run_trainer) |    |         |     |      |
|                         | |                                           | |                                     |    |         |     |      |
| +---------------------+ | |            +---------------------------+  | |                                     |    |         |     |      |
| | HybridModel         | | |            |HybridModel                |  | |                                     |    |         |     |      |
| |                     | | |            |                           |  | +-------------------------------------+    +----------------------+
| |                     | | |            |                           |  |                                                      |     |
| |   fc = DDP(Linear)  | | |            |      fc = DDP(Linear())   |  |                                                      |     |
| |                     | | |            |                           |  |                                                      |     |
| |   remote_emb_module | | |            |      remote_emb_module+-------------------------------------------------------------+     |
| |             +       | | |            |                           |  |                                                            |
| +---------------------+ | |            +---------------------------+  |                                                            |
|               |         | |                                           |                                                            |
+-------------------------+ +-------------------------------------------+                                                            |
                |                                                                                                                    |
                +--------------------------------------------------------------------------------------------------------------------+

手機如下:

0x05 訓練

5.1 初始化

之前初始化時候,我們漏過了trainer的初始化,這裡我們分析一下。

我們先看看 Trainer 上的設定。

  • 首先,trainer使用遠端模組(remote module)和自己的rank 來建立上面提到的 HybridModel,遠端模組持有引數伺服器上的嵌入表。
  • 其次,我們需要得到一個RRef 列表,該列表指向我們想要使用DistributedOptimizer優化的所有引數。
    • 要從引數伺服器嵌入表之中拿到這些引數,我們可以呼叫 RemoteModule 的remote_parameters,它會遍歷嵌入表的所有引數並返回一個 RRef 列表。trainer通過 RPC 在引數伺服器上呼叫此方法來得到所需引數的 RRef 列表。
    • 由於 DistributedOptimizer 始終持有一個需要優化引數的 RRef 列表,因此我們需要為 FC 層的區域性引數建立 RRef。這是通過遍歷model.fc.parameters()來完成的,其將為每個引數建立一個 RRef 並將其附加到從remote_parameters()返回的列表中。
    • 請注意,我們不能使用model.parameters(),因為它會遞迴呼叫model.remote_emb_module.parameters(),而RemoteModule不支援這種操作。
  • 最後,我們使用所有 RRef 建立我們的 DistributedOptimizer 並定義一個 CrossEntropyLoss 函式。
def _run_trainer(remote_emb_module, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
    DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear) and distributed autograd ensures gradients updates are
    propagated to the parameter server.
    """

    # Setup the model.
    model = HybridModel(remote_emb_module, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters for embedding table.
    model_parameter_rrefs = model.remote_emb_module.remote_parameters()

    # model.fc.parameters() only includes local parameters.
    # NOTE: Cannot call model.parameters() here,
    # because this will call remote_emb_module.parameters(),
    # which supports remote_parameters() but not parameters().
    for param in model.fc.parameters(): 
        model_parameter_rrefs.append(RRef(param)) # 這裡新增了需要分散式優化的 DDP 的引數

    # Setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs, # dense引數和sparse引數一起分散式優化
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

我們邏輯擴充如下,這裡省略了 trainer 0 指向 引數伺服器的箭頭,與上圖相比,增加了 DistributedOptimizer。

                                            torch.multiprocessing.spawn
                                                       +
                                                       |
                                                       |
               +-----------------------------------------------------------------------------------+----------------------------------+
               |                                       |                                           |                                  |
               |                                       |                                           |                                  |
               v                                       v                                           v                                  v
+--------------+-------------+ +-----------------------+-------------------+ +---------------------+---------------+  +---------------+-------------+
|trainer 0          rank = 0 | | trainer 1                        rank = 1 | | master                     rank = 2 |  |  ps                rank = 3 |
|                            | |                                           | |                                     |  |                             |
|                            | |                                           | |                                     |  |      rpc.init_rpc           |
| dist.init_process_group    | | dist.init_process_group                   | |   rpc.init_rpc                      |  |                             |
|                            | |                                           | |                                     |  |    +----------------------+ |
| rpc.init_rpc               | | rpc.init_rpc                              | |                            1        |  |    | RemoteModule         | |
|                            | |                                           | |   remote_emb_module +---------------------> |                      | |
| +------------------------+ | | +---------------------------------------+ | |                                     |  |    |                      | |
| | _run_trainer           | | | | _run_trainer                          | | |                                     |  |    |  remote_parameters() | |
| |                        | | | |                                       | | |   fut = rpc.rpc_async(_run_trainer) |  |    |                      | |
| |                        | | | |   output = model(indices, offsets)    | | |                                     |  |    |                      | |
| |                        | | | |   dist_autograd.backward              | | |                                     |  |    +------+--------+------+ |
| |                        | | | |   opt.step                            | | |                                     |  |           ^        ^        |
| |                        | | | |                                       | | |                                     |  |           |        |        |
| | +-------------------+  | | | |                                       | | +-------------------------------------+  +-----------------------------+
| | | HybridModel       |  | | | |  +-----------------------------+      | |                                                      |        |
| | |                   |  | | | |  | HybridModel                 |      | |                                                      |        |
| | | fc = DDP(Linear)  |  | | | |  |                             |      | |                                                      |        |
| | | remote_emb_module |  | | | |  |  fc = DDP(Linear().cuda()   |      | |                                                      |        |
| | |                   |  | | | |  |  remote_emb_module+------------------------------------------------------------------------->        |
| | +-------------------+  | | | |  |                             |      | |                             2                                 |
| |                        | | | |  +-----------------------------+      | |                                                               |
| | +--------------------+ | | | |  +-----------------------------+      | |                                                               |
| | |DistributedOptimizer| | | | |  |DistributedOptimizer         |      | |                                                               |
| | +--------------------+ | | | |  |                             +------------------------------------------------------------------------>
| |                        | | | |  +-----------------------------+      | |                              3
| +------------------------+ | | +---------------------------------------+ |
+----------------------------+ +-------------------------------------------+


手機如下:

5.2 訓練迴圈

現在我們介紹在每個trainer上執行的主訓練迴圈。這裡 get_next_batch只是一個輔助函式,用於生成隨機輸入和訓練目標。我們為多個epoch和每個batch執行該訓練迴圈:

  1. 為Distributed Autograd.設定Distributed Autograd Context
  2. 執行模型的前向傳播並拿到其輸出。
  3. 使用損失函式根據我們的輸出和target來計算損失。
  4. 使用 Distributed Autograd 對損失執行分散式反向傳播。
  5. 最後,執行分散式優化器step 來優化所有引數。
    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # Train for 100 epochs
    for epoch in range(100):
        # create distributed autograd context
        for indices, offsets, target in get_next_batch(rank):
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Tun distributed optimizer
                opt.step(context_id)

                # Not necessary to zero grads as each iteration creates a different
                # distributed autograd context which hosts different grads
        print("Training done for epoch {}".format(epoch))

因為篇幅所限,我們只是把上面的trainer再細化如下圖:

  1. 初始化時候,呼叫 dist.init_process_group 來初始化 DistributedDataParallel,呼叫 rpc.init_rpc 來初始化 RPC。
  2. HybridModel 之中,fc 是DistributedDataParallel方式,remote_emb_module 是引數伺服器上的 RemoteModule。
  3. DistributedOptimizer 之中,對於 HybridModel 的 fc 和 remote_emb_module 都會進行分散式優化。
  4. _run_trainer 之中,使用 model(indices, offsets) 進行前向傳播,其中會呼叫到 HybridModel.forward。
  5. HybridModel.forward 之中則對embedding 和 fc 進行操作。
    1. embedding 是利用RPC 和 引數伺服器。
    2. fc 是利用 DistributedDataParallel。
    3. 將嵌入表放在引數伺服器上,並使用DistributedDataParallel 在多個trainer之間複製 FC 層。

這些序號與下圖中數字對應。

+---------------------------------------------------------------------+
| trainer 1                                                 rank = 1  |
|                +-----------------------------------+                |
|                |    dist.init_process_group      1 |                |
|                |                                   |                |
|                |    rpc.init_rpc                   |                |
|                |                                   |                |
|                +-----------------------------------+                |
| +-----------------------------------------------------------------+ |
| | _run_trainer                                                    | |
| |                                                                 | |
| |     output = model(indices, offsets)                            | |
| |     dist_autograd.backward      +                               | |
| |     opt.step                    |                               | |
| |  +-----------------------------------------------------------+  | |
| |  | HybridModel                  |                          2 |  | |
| |  |                              |                            |  | |
| |  |    fc = DDP(Linear().cuda()  |                            |  | |
| |  |                              |4                           |  | |
| |  |    remote_emb_module         |                            |  | |
| |  |                              |                            |  | |
| |  |                              v                            |  | |
| |  |   +--------------------------+--------------------------+ |  | |
| |  |   |forward                                              | |  | |
| |  |   |  emb_lookup = remote_emb_module.forward()           | |  | |
| |  |   |                  +                                  | |  | |
| |  |   |                  |  5                               | |  | |
| |  |   |                  |                                  | |  | |
| |  |   |                  v                                  | |  | |
| |  |   |  fc(emb_lookup.cuda(device)                         | |  | |
| |  |   |                                                     | |  | |
| |  |   +-----------------------------------------------------+ |  | |
| |  +-----------------------------------------------------------+  | |
| |  +-----------------------------------------------------------+  | |
| |  | DistributedOptimizer                                    3 |  | |
| |  |                                                           |  | |
| |  |         HybridModel.remote_emb_module.remote_parameters() |  | |
| |  |                                                           |  | |
| |  |         HybridModel.fc.parameters()                       |  | |
| |  |                                                           |  | |
| |  +-----------------------------------------------------------+  | |
| +-----------------------------------------------------------------+ |
+---------------------------------------------------------------------+

手機如下:

注,可以在此處找到整個示例的原始碼。

0x06 比對

我們已經看了三篇PyTorch官方樣例,裡面對引數伺服器的實現各有不同。對於本文來說,又加入了一個master作為協調者來統一各個worker。

總的來說,在PyTorch 之中,因為有了 RPC 機制,所以PyTorch 的引數伺服器實現比 ps-lite, paracel 更佳靈活機動:

  • 首先引數伺服器目前可以放在 GPU 之中。
  • 其次,可以在引數伺服器只放置引數,也可以執行優化程式碼,甚至可以在引數服務之上啟動控制trainer。
  • 具體優化器根據實際需要,可以是普通優化器,也可以是DistributedOptimizer。
  • 訓練程式碼從使用者編寫角度看則完全是執行在本地。

0xFF 參考

COMBINING DISTRIBUTED DATAPARALLEL WITH DISTRIBUTED RPC FRAMEWORK

相關文章