[Source Code Analysis] PyTorch Distributed Elastic Training (6) --- Monitoring / Fault Tolerance

Posted by 羅西的思考 on 2021-12-29

0x00 Abstract

In this series on PyTorch elastic training we have so far covered the Agent and rendezvous separately, but some parts, such as monitoring, have not been examined in depth. This article ties them together and walks through the overall logic of elastic training.

The articles in the elastic training series are:

[Source Code Analysis] PyTorch Distributed Elastic Training (1) --- Overall Ideas

[Source Code Analysis] PyTorch Distributed Elastic Training (2) --- Startup & Single-Node Flow

[Source Code Analysis] PyTorch Distributed Elastic Training (3) --- Agent

[Source Code Analysis] PyTorch Distributed Elastic Training (4) --- Rendezvous Architecture and Logic

[Source Code Analysis] PyTorch Distributed Elastic Training (5) --- Rendezvous Engine

0x01 Overall Logic

We will look at the system logic from several angles, roughly top-down, going from the whole to the parts.

1.1 The Node Cluster View

We start from the node cluster view, which can be thought of as a top-down bird's-eye view of the elastic system. From this perspective, each node runs one Agent; the Agent contains a rendezvous component responsible for distributed negotiation, and the Agent is also responsible for starting and monitoring the workers.

1.2 Overall Agent Logic

We then go inside the agent. From the previous articles, the overall logic so far is as follows.

  • 1) Call _initialize_workers to start the worker processes, i.e. launch multiple processes that run the user program in parallel for training.
    • 2) Call _rendezvous, which internally:
      • calls next_rendezvous to handle membership changes,
      • calls _assign_worker_ranks to establish ranks for the workers.
    • 3) Call _start_workers to start the workers.
  • 4) Call _monitor_workers to monitor the results of these processes.

1.3 The Monitoring View

The very core of elastic training is monitoring and dynamic handling, so we dive into the monitoring module. From the monitoring perspective, the agent's main loop _invoke_run works as follows:

  • Call _initialize_workers to start the workers.
    • Call _rendezvous, which internally:
      • calls next_rendezvous to handle membership changes,
      • calls _assign_worker_ranks to establish ranks for the workers.
    • Call _start_workers to start the workers.
  • The program then enters a while loop and periodically polls the user processes through _monitor_workers, acting on the result.
  • If a worker process errors out or is unhealthy, we enter the branch elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:.
    • First call _restart_workers, which starts a new rendezvous and restarts the worker processes.
    • If the maximum number of restarts has been exceeded, shut the job down.
  • If the processes are running normally, we enter the branch state == WorkerState.HEALTHY.
    • If this is a scale-up, i.e. there are new nodes waiting, restart all workers.

The code is as follows:

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        self._initialize_workers(self._worker_group) # start the workers
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            # monitor periodically
            time.sleep(monitor_interval)
            # monitor how the user processes are running
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state # state of the processes
            self._worker_group.state = state

            if state == WorkerState.SUCCEEDED:
                # the processes finished normally
                self._exit_barrier() # once one succeeds, everything wraps up
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # the processes errored out
                if self._remaining_restarts > 0: # retry
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group) # restart
                else:
                    self._stop_workers(self._worker_group) # max retries reached, stop the workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # the processes are running normally
                # node membership has changed, e.g. a scale-up
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # if new nodes are waiting, restart all workers
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

Refining further, a rough sketch is as follows:

  _initialize_workers  <---------------------------------+                 Node 1    +   Node 2                  _initialize_workers
           +                                             |                           |                                   +
           |                                             |                           |                                   |
           |                                             |  +-----------------+      |      +-----------------+          |
           v                                             |  |RendezvousHandler|    sync     |RendezvousHandler|          v
      _rendezvous +---------------------------------------->+                 | <----+----> |                 +<---+ _rendezvous
           +                          next_rendezvous    |  |                 |      |      |                 |          +
           |                                             |  |                 |      |      |                 |          |
    _assign_worker_ranks                                 |  |                 |  heartbeat  |                 |          |
           |                                             |  |                 | <----+----> |                 |
           v                                             |  +-----------------+      |      +-----------------+          v
     _start_workers                                      |                           |                              _start_workers
           +                                             |                           |                                   +
           |                                             |                           |                                   |
           |                                             |                           |                                   |
           v                                             |                           |                                   v
     +-----+-------------------------------------------------------+                 |                          +--------+---------+
     |                                                   |         |                 |                          |                  |
     |state = _monitor_workers                           |         |                 |                          |                  |
     |   +                                               |         |                 |                          |                  |
     |   |                                               |         |                 |                          |                  |
     |   | UNHEALTHY,FAILED   1. Process fail            |         |                 |                          |                  |
+--> |   +-----------------> _restart_workers +--+       |         +-->              |                          |                  |
|    |   |                                       |       +         |  |              |                          |                  |
|    |   |                                       +--> _stop_workers|  |              |                          |  LOOP Every 30S  |
|    |   | HEALTHY            2. Node change     |                 |  |              |                          |                  |
|    |   +-----------------> _restart_workers +--+                 |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    |   | SUCCEEDED                                               |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    |   | 3. exit                                                 |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    +-------------------------------------------------------------+  |              |                          |                  |
|        |                                                            |              |                          |                  |
<---------------------------------------------------------------------+              |                          +--------+---------+
         |        LOOP  Every 30S                                                    |                                   |
         |                                                                           |                                   |
         v                                                                           |                                   v
       _exit_barrier                                                                 +                             _exit_barrier

Alternatively, see the figure from https://zhuanlan.zhihu.com/p/408382623.

0x02 Multiprocessing

The monitoring mechanism watches multiple running training workers, which involves starting and monitoring multiple processes, so we need to cover multiprocessing first. The entry point is how the worker processes are started.

2.1 Starting workers

_start_workers calls start_processes to start the worker processes; the default _start_method is "spawn". In other words, it launches multiple processes that run the user program in parallel, and the results of these processes are then monitored. Among the start_processes arguments, entrypoint and args are the user command and its arguments; entrypoint can be either a function or a string.

_start_workers then saves the result of start_processes (the context of the launched processes) in _pcontext, which is used for all subsequent control; for example, terminating the workers is done simply by calling the close method of _pcontext.

    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": str(1),
            }
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # scaling events do not count towards restarts (gets same attempt #)
        # remove existing log dir if this restart is due to a scaling event
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        assert spec.entrypoint is not None
        self._pcontext = start_processes( # save the result of starting the worker processes into _pcontext
            name=spec.role,
            entrypoint=spec.entrypoint, # entry point of the training code
            args=args, # the important part here is the local rank
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()

2.1.1 start_processes

Note that this start_processes lives in torch/distributed/elastic/multiprocessing/api.py and is different from the mp start_processes used later. start_processes extracts the local ranks from args and then operates per local_rank, for example creating a log file for each process. The significance is that each worker process is tied to a local_rank: one local_rank corresponds to one worker process.

def start_processes(
    name: str,
    entrypoint: Union[Callable, str],
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    log_dir: str,
    start_method: str = "spawn",
    redirects: Union[Std, Dict[int, Std]] = Std.NONE,
    tee: Union[Std, Dict[int, Std]] = Std.NONE,
) -> PContext:
    """
    Starts ``n`` copies of ``entrypoint`` processes with the provided options.
    ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary).
    The number of copies is determined by the number of entries for ``args`` and
    ``envs`` arguments, which need to have the same key set.

    ``args`` and ``env`` parameters are the arguments and environment variables
    to pass down to the entrypoint mapped by the replica index (local rank).
    All local ranks must be accounted for.
    That is, the keyset should be ``{0,1,...,(nprocs-1)}``.

    Args:
        name: a human readable short name that describes what the processes are
              (used as header when tee'ing stdout/stderr outputs)
        entrypoint: either a ``Callable`` (function) or ``cmd`` (binary)
        args: arguments to each replica
        envs: env vars to each replica
        log_dir: directory used to write log files
        nprocs: number of copies to create (one on each process)
        start_method: multiprocessing start method (spawn, fork, forkserver)
                      ignored for binaries
        redirects: which std streams to redirect to a log file
        tees: which std streams to redirect + print to console

    """

    # listdir raises FileNotFound or NotADirectoryError so no need to check manually
    if os.listdir(log_dir):
        raise RuntimeError(
            f"log_dir: {log_dir} is not empty, please provide an empty log_dir"
        )

    nprocs = len(args)
    _validate_full_rank(args, nprocs, "args")
    _validate_full_rank(envs, nprocs, "envs")

    # create subdirs for each local rank in the logs_dir
    redirs = to_map(redirects, nprocs)
    ts = to_map(tee, nprocs)

    # to tee stdout/stderr we first redirect into a file
    # then tail -f stdout.log/stderr.log so add tee settings to redirects
    for local_rank, tee_std in ts.items():
        redirect_std = redirs[local_rank]
        redirs[local_rank] = redirect_std | tee_std

    stdouts = {local_rank: "" for local_rank in range(nprocs)}
    stderrs = {local_rank: "" for local_rank in range(nprocs)}
    tee_stdouts: Dict[int, str] = {}
    tee_stderrs: Dict[int, str] = {}
    error_files = {}

    # local_rank is used heavily here
    for local_rank in range(nprocs):
        clogdir = os.path.join(log_dir, str(local_rank))
        os.mkdir(clogdir)

        rd = redirs[local_rank]
        if (rd & Std.OUT) == Std.OUT:
            stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
        if (rd & Std.ERR) == Std.ERR:
            stderrs[local_rank] = os.path.join(clogdir, "stderr.log")

        t = ts[local_rank]
        if t & Std.OUT == Std.OUT:
            tee_stdouts[local_rank] = stdouts[local_rank]
        if t & Std.ERR == Std.ERR:
            tee_stderrs[local_rank] = stderrs[local_rank]

        error_file = os.path.join(clogdir, "error.json")
        error_files[local_rank] = error_file
        envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file

    context: PContext
    if isinstance(entrypoint, str):
        context = SubprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
        )
    else:
        context = MultiprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
            start_method=start_method,
        )

    try:
        context.start()
        return context
    except Exception:
        context.close()
        raise

2.1.2 RunResult

The outcome of the worker processes is represented by RunResult, the result returned by the workers. Run results follow an "all-or-nothing" policy: the run is successful if and only if all local workers managed by this agent complete successfully.

As mentioned above, each worker process is tied to a local_rank, which makes sense: if there are 5 GPUs, then of course 5 worker processes are started for training, and these 5 worker processes correspond to local ranks 0 through 4.

However, the RunResult docstring states: if the result is successful (i.e. is_failed() = False), the return_values field contains the outputs (return values) of the worker processes managed by this agent, keyed by their GLOBAL ranks. That is, result.return_values[0] is the return value of global rank 0. Therefore, _monitor_workers performs a mapping from local rank to global rank, which we will cover later.

@dataclass
class RunResult:
    """
    Results returned by the worker executions. Run results follow an "all-or-nothing" policy
    where the run is successful if and only if ALL local workers managed by this agent
    complete successfully.

    If the result is successful (e.g. ``is_failed() = False``) then the ``return_values``
    field contains the outputs (return values) of the workers managed by THIS agent mapped
    by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of
    global rank 0.

    .. note:: ``return_values`` are only meaningful for when the worker entrypoint
              is a function. Workers specified as a binary entrypoint do not canonically
              have a return value and the ``return_values`` field is meaningless and
              may be empty.

    If ``is_failed()`` returns ``True`` then the ``failures`` field contains the
    failure information, again, mapped by the GLOBAL rank of the worker that failed.

    The keys in ``return_values`` and ``failures`` are mutually exclusive, that is,
    a worker's final state can only be one of: succeeded, failed. Workers intentionally
    terminated by the agent according to the agent's restart policy, are not represented
    in either ``return_values`` nor ``failures``.
    """

    state: WorkerState
    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return self.state == WorkerState.FAILED
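
To make the local-rank to global-rank translation concrete, here is a minimal, self-contained sketch (with hypothetical ranks and return values) of the kind of re-keying that _monitor_workers performs before building a RunResult:

# A minimal sketch with hypothetical values: the agent on the second node of a
# 2-node x 2-GPU job re-keys results from local ranks to global ranks.
local_to_global = {0: 2, 1: 3}                           # local rank -> global rank on this node
local_return_values = {0: "loss=0.12", 1: "loss=0.15"}   # keyed by local rank

global_return_values = {
    local_to_global[local_rank]: value
    for local_rank, value in local_return_values.items()
}
print(global_return_values)                              # {2: 'loss=0.12', 3: 'loss=0.15'}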

2.1.3 How TE uses multiprocessing

TE uses the torch.mp and subprocess packages for multiprocessing. When launching the processes, the result is stored in _pcontext, which is an instance of type PContext.

    self._pcontext = start_processes( # save the result of starting the worker processes into _pcontext
        name=spec.role,
        entrypoint=spec.entrypoint,
        args=args,
        envs=envs,
        log_dir=attempt_log_dir,
        start_method=self._start_method,
        redirects=spec.redirects,
        tee=spec.tee,
    )

Here, start_processes and PContext come from:

from torch.distributed.elastic.multiprocessing import start_processes, PContext

When monitoring, _monitor_workers uses _pcontext. Based on the process results, it returns WorkerState.FAILED, WorkerState.HEALTHY, or WorkerState.SUCCEEDED to the upper layer.

@prof
def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
    role = worker_group.spec.role
    worker_pids = {w.id for w in worker_group.workers}
    assert self._pcontext is not None
    pc_pids = set(self._pcontext.pids().values())
    
    result = self._pcontext.wait(0) # poll the run results
    if result:
        if result.is_failed():
            # map local rank failure to global rank
            worker_failures = {}
            for local_rank, failure in result.failures.items():
                worker = worker_group.workers[local_rank]
                worker_failures[worker.global_rank] = failure
            return RunResult(
                state=WorkerState.FAILED, # the processes failed, return WorkerState.FAILED
                failures=worker_failures,
            )
        else:
            # copy ret_val_queue into a map with a global ranks
            workers_ret_vals = {}
            for local_rank, ret_val in result.return_values.items():
                worker = worker_group.workers[local_rank]
                workers_ret_vals[worker.global_rank] = ret_val
            return RunResult(
                state=WorkerState.SUCCEEDED,
                return_values=workers_ret_vals,
            )
    else:
        return RunResult(state=WorkerState.HEALTHY)

Clearly PContext is the key, so let us look at this class.

2.2 PContext

PContext is an abstract class that essentially just holds some basic configuration.

class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes
    that are launched via different mechanisms. The name ``PContext``
    is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively) this is b/c
                 tee is implemented as a redirect + tail -f <stdout/stderr.log>
    """
    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
    ):
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for
        nprocs = len(args)
        _validate_full_rank(stdouts, nprocs, "stdouts")
        _validate_full_rank(stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = stdouts
        self.stderrs = stderrs
        self.error_files = error_files
        self.nprocs = nprocs

        self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout)
        self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr)    

However, it has two crucial derived classes: MultiprocessContext and SubprocessContext. As mentioned earlier, among the start_processes arguments, entrypoint and args are the user command and its arguments, and entrypoint can be a function or a string. If entrypoint is a function, MultiprocessContext is used; if it is a string, SubprocessContext is used.

def start_processes(
    name: str,
    entrypoint: Union[Callable, str],
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    log_dir: str,
    start_method: str = "spawn",
    redirects: Union[Std, Dict[int, Std]] = Std.NONE,
    tee: Union[Std, Dict[int, Std]] = Std.NONE,
) -> PContext:
  
    context: PContext
    if isinstance(entrypoint, str): # if it is a string
        context = SubprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
        )
    else:
        context = MultiprocessContext( # a function entrypoint comes here
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
            start_method=start_method,
        )

    try:
        context.start() # this is where start is invoked
        return context
    except Exception:
        context.close()
        raise  

Concretely, the two derived classes are built on different mechanisms.

  • MultiprocessContext uses torch.multiprocessing.start_processes to launch the processes.
  • SubprocessContext uses subprocess.Popen to launch the processes.

In the following we analyze only MultiprocessContext.
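
As a hedged, minimal sketch of this dispatch (the trainer function, its arguments, and the log directory below are made up for illustration), a Callable entrypoint yields a MultiprocessContext, while a string entrypoint such as a binary path would yield a SubprocessContext:

# Hedged sketch: launching two copies of a function entrypoint via the elastic
# start_processes. A str entrypoint (e.g. "/bin/echo") would produce a
# SubprocessContext instead of a MultiprocessContext.
import tempfile
from torch.distributed.elastic.multiprocessing import start_processes

def trainer(msg: str) -> str:
    # stand-in for the user's training function
    return f"worker got: {msg}"

if __name__ == "__main__":
    ctx = start_processes(
        name="demo",
        entrypoint=trainer,                    # Callable -> MultiprocessContext
        args={0: ("hello",), 1: ("world",)},   # keyed by local rank 0..nprocs-1
        envs={0: {}, 1: {}},
        log_dir=tempfile.mkdtemp(),            # must be an existing, empty directory
    )
    result = ctx.wait()                        # block until all copies finish
    print(result.return_values)                # {0: 'worker got: hello', 1: 'worker got: world'}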

2.3 MultiprocessContext

MultiprocessContext is defined as follows. The most interesting part is the member variable _pc, which is in fact a ProcessContext.

import torch.multiprocessing as mp

class MultiprocessContext(PContext):
    """
    ``PContext`` holding worker processes invoked as a function.
    """

    def __init__(
        self,
        name: str,
        entrypoint: Callable,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
        start_method: str,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            stdouts,
            stderrs,
            tee_stdouts,
            tee_stderrs,
            error_files,
        )

        self.start_method = start_method
        # each ret_val queue will always contain a single element.
        self._ret_vals = {
            local_rank: mp.get_context(self.start_method).SimpleQueue()
            for local_rank in range(self.nprocs)
        }

        # see comments in ``join()`` for what this is
        self._return_values: Dict[int, Any] = {}
        self._pc: Optional[mp.ProcessContext] = None # this is the key
        self._worker_finished_event = mp.get_context(self.start_method).Event()

2.3.1 start

MultiprocessContext's _start calls mp.start_processes and then saves the result.

import torch.multiprocessing as mp

    def _start(self):
        if self._pc:
            raise ValueError(
                "The process context already initialized."
                " Most likely the start method got called twice."
            )
        self._pc = mp.start_processes( # this returns an mp.ProcessContext
            fn=_wrap,
            args=(
                self.entrypoint,
                self.args,
                self.envs,
                self.stdouts,
                self.stderrs,
                self._ret_vals,
                self._worker_finished_event,
            ),
            nprocs=self.nprocs,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

2.3.2 wait

The wait method lives in the base class PContext(abc.ABC). It simply loops, calling _poll periodically to check.

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Waits for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).
        """
        if timeout == 0:
            return self._poll()
        if timeout < 0:
            timeout = sys.maxsize

        expiry = time.time() + timeout
        while time.time() < expiry: # loop periodically
            pr = self._poll() # use _poll to check
            if pr:
                return pr
            time.sleep(period)

        return None

2.3.3 _poll

The _poll function does the actual checking by calling torch.mp.ProcessContext.join. torch.mp.ProcessContext throws an exception if some or all of the worker processes failed. With a negative timeout it checks the worker status and returns immediately. Because a synchronize.Event is used to wait for all processes to finish, join will never return success here.

PyTorch uses multiprocessing queues to carry the worker return values back to the parent process, so the final result contains the outcome of each process.

def _poll(self) -> Optional[RunProcsResult]:

    try:
        # torch.mp.ProcessContext Throws an Exception if some/all of
        # worker processes failed
        # timeout < 0 checks worker status and return immediately
        # Join will never return success since we use synchronize.Event to wait
        # for all processes to finish.
        self._pc.join(-1)

        # IMPORTANT: we use multiprocessing.Queue to carry worker return values
        # back to the parent, the worker process will wait before terminating
        # until all the buffered items are fed by the feeder thread to the underlying
        # pipe. Hence to prevent deadlocks on large return values,
        # we opportunistically try queue.get on each join call
        # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
        
        for local_rank in range(0, self.nprocs): # iterate over the processes under this context
            return_queue = self._ret_vals[local_rank]
            if not return_queue.empty():
                # save the return values temporarily into a member var
                self._return_values[local_rank] = return_queue.get() # get the process's return value

        if self._is_done():
            # we should ALWAYS have ALL the return values when all the processes are done
            self._worker_finished_event.set()
            # Wait until all processes are finished. At this point workers finished executing user function
            self._pc.join()
            self.close()
            return RunProcsResult(
                return_values=self._return_values, # return the processes' results
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )
        else:
            return None
          
    except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
        failed_local_rank = e.error_index

        # entrypoint for MultiprocessContext will always be a Callable
        fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
        failed_proc = self._pc.processes[failed_local_rank]
        error_filepath = self.error_files[failed_local_rank]

        self.close()
        return RunProcsResult( # return the process results
            failures={
                failed_local_rank: ProcessFailure(
                    local_rank=failed_local_rank,
                    pid=e.pid,
                    exitcode=failed_proc.exitcode,
                    error_file=error_filepath,
                )
            },
            stdouts=self.stdouts,
            stderrs=self.stderrs,
        )
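
To illustrate the queue mechanism in isolation, here is a small standalone sketch (names are illustrative only) of how a parent process collects a child's "return value" through a queue, which is what MultiprocessContext does with one SimpleQueue per local rank:

# Standalone sketch of the return-value plumbing described above.
import torch.multiprocessing as mp

def _worker(ret_queue):
    # stand-in for the wrapped user function: push the "return value" to the parent
    ret_queue.put("done: loss=0.01")

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    ret_queue = ctx.SimpleQueue()
    proc = ctx.Process(target=_worker, args=(ret_queue,))
    proc.start()
    proc.join()
    print(ret_queue.get())   # 'done: loss=0.01'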

2.4 ProcessContext

As we have seen, the key variable of MultiprocessContext is _pc: Optional[mp.ProcessContext]. This member is built via start_processes, so we need to look at torch.mp.ProcessContext.

2.4.1 start_processes

This start_processes is in torch/multiprocessing/spawn.py and returns a ProcessContext. Note that from this point on, the training processes run their own training code as if there were no agent, because the agent has already done the work of torch.distributed.launch.

def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
    mp = multiprocessing.get_context(start_method)
    error_queues = []
    processes = []
    for i in range(nprocs):
        error_queue = mp.SimpleQueue()
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, error_queue), # the training process starts running the training code
            daemon=daemon,
        )
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    context = ProcessContext(processes, error_queues)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass

2.4.2 ProcessContext

torch.mp.ProcessContext is the class that ultimately does the work. In fact we do not care about its internal implementation or how it gets started, because via start_processes it is effectively already running; we can treat it as a functional black box. What we really care about is how torch.mp.ProcessContext is used for monitoring.

From its comments we know that torch.mp.ProcessContext throws an exception if some or all of the worker processes fail. With a negative timeout it checks the worker status and returns immediately. Because a synchronize.Event is used to wait for all processes to finish, join will never return success here.

# torch.mp.ProcessContext Throws an Exception if some/all of
# worker processes failed
# timeout < 0 checks worker status and return immediately
# Join will never return success since we use synchronize.Event to wait
# for all processes to finish.

2.5 Summary

The relationships so far are as follows:

  • At construction time, LocalElasticAgent creates a MultiprocessContext, and MultiprocessContext in turn creates a ProcessContext.
  • LocalElasticAgent._pcontext holds the MultiprocessContext, and MultiprocessContext._pc holds the ProcessContext.
  • When monitoring, LocalElasticAgent._monitor_workers calls MultiprocessContext.wait, MultiprocessContext calls ProcessContext.join, and ProcessContext.join actually checks the state of the running processes, which completes the overall monitoring logic.
  • When a child process changes state or the timeout expires, ProcessContext.join returns the process results, MultiprocessContext.wait forwards them, and _monitor_workers converts them into WorkerState.SUCCEEDED or WorkerState.FAILED.

As shown below:

+--------------------------------------------------------------------------------------+   +------------------------------------+   +----------------+
| LocalElasticAgent                                                                    |   | MultiprocessContext                |   | ProcessContext |
|                                                                                      |   |                                    |   |                |
|                                                                                      |   |                                    |   |                |
|  +----------------------------------------+       MultiprocessContext _pcontext      |   |       ProcessContext _pc           |   |                |
|  | _invoke_run                            |                                          |   |                                    |   |                |
|  |                                        |                                          |   |                                    |   |                |
|  |   _initialize_workers  +-------------------->  _pcontext = start_processes  +-------------->  start():                     |   |                |
|  |                                        |                                          |   |         _pc = mp.start_processes +----------->          |
|  |                                        |                                          |   |                                    |   |                |
|  |   while True:                          |      +--------------------------------+  |   |                                    |   |                |
|  |       _monitor_workers(_worker_group)+------> | _monitor_workers               |  |   |                                    |   |                |
|  |                                        |      |                                |  |   |                                    |   |                |
|  |                                        |      |             _pcontext.wait +--------------->  wait +---> poll:             |   |                |
|  |                                        |      |                                |  |   |                    _pc.join  +--------------->          |
|  +----------------------------------------+      +--------------------------------+  |   |                                    |   |                |
|                                                                                      |   |                                    |   |                |
+--------------------------------------------------------------------------------------+   +------------------------------------+   +----------------+

0x03 The Monitoring Mechanism

As we can see from the _monitor_workers code above, _monitor_workers converts the outcome of the child processes into a concrete WorkerState. Once the agent gets the monitoring result from _monitor_workers, it handles it according to the situation.

            # monitor how the user processes are running
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state # state of the processes
            self._worker_group.state = state

            if state == WorkerState.SUCCEEDED:
                # the processes finished normally
                self._exit_barrier() # once one succeeds, everything wraps up
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # the processes errored out
                if self._remaining_restarts > 0: # retry
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group) # restart
                else:
                    self._stop_workers(self._worker_group) # max retries reached, stop the workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # the processes are running normally
                # node membership has changed, e.g. a scale-up
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # if new nodes are waiting, restart all workers
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

3.1 Monitoring

_pcontext.wait(0) is called here to get the current state of the worker child processes, and depending on the returned result, a different WorkerState is returned to the caller. As mentioned earlier, RunResult should be keyed by global rank, so _monitor_workers performs a mapping from local rank to global rank.

Why use the global rank to identify process state? Because nodes need to communicate with each other, and for that the global rank is required.

    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers} # get the pids of all workers of this agent
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            return RunResult(state=WorkerState.UNKNOWN)

        result = self._pcontext.wait(0) # poll the run results
        if result:
            if result.is_failed(): # if the processes failed
                # map local rank failure to global rank
                worker_failures = {}
                # the returned result contains the outcome of every process
                for local_rank, failure in result.failures.items(): # local_rank is the process index
                    worker = worker_group.workers[local_rank] # get the corresponding worker
                    worker_failures[worker.global_rank] = failure # key the failure by its global_rank
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures, # return the run results
                )
            else:
                # copy ret_val_queue into a map with a global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals, # return the run results
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)

3.2 Handling

Different returned states lead to different handling:

  • If WorkerState.SUCCEEDED, training has finished and the function returns normally.
  • If WorkerState.HEALTHY, training is running normally; in this case the agent checks whether new nodes have joined, which we will detail later.
  • If WorkerState.UNHEALTHY or WorkerState.FAILED, something went wrong with training. There are two cases:
    • a process error, in which case TE retries;
    • a node exit, which we analyze below, but whose handling flow is the same as a process error.

Next we analyze how training completion and process errors are handled.

0x04 Training Completion

        if state == WorkerState.SUCCEEDED:
            # the processes finished normally
            self._exit_barrier() # once one succeeds, everything wraps up
            return run_result

This is the handling when training ends normally; the special part is the use of _exit_barrier.

4.1 Finishing Together

Torchelastic currently supports DDP-style applications, i.e. TE expects all workers to finish at roughly the same time. In practice it is nearly impossible to guarantee that all DDP workers finish simultaneously, so TE provides a finalization barrier that enforces a waiting timeout (5 minutes) on worker finalization. In other words, once one worker finishes training, TE (torchelastic) expects all of the user's workers to finish within a 5-minute margin.

def _exit_barrier(self):
    """
    Wait for ``exit_barrier_timeout`` seconds for all agents to finish
    executing their local workers (either successfully or not). This
    acts as a safety guard against user scripts that terminate at different
    times. This barrier keeps the agent process alive until all workers finish.
    """

    start = time.time()
    try:
        store_util.barrier(
            self._store,
            self._worker_group.group_rank,
            self._worker_group.group_world_size,
            key_prefix=_TERMINAL_STATE_SYNC_ID,
            barrier_timeout=self._exit_barrier_timeout,
        )
    except Exception:
        log.exception(
            f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
        )

The default value of exit_barrier_timeout is 300 seconds, i.e. 5 minutes.

exit_barrier_timeout: float = 300,

4.2 Synchronization

In torch/distributed/elastic/utils/store.py, barrier calls synchronize to perform the synchronization.

def barrier(
    store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
) -> None:
    """
    A global lock between agents.

    Note: Since the data is not removed from the store, the barrier can be used
        once per unique ``key_prefix``.
    """
    data = f"{rank}".encode(encoding="UTF-8")
    synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)

synchronize in turn synchronizes through the store.

def get_all(store, prefix: str, size: int):
    r"""
    Given a store and a prefix, the method goes through the array of keys
    of the following format: ``{prefix}{idx}``, where idx is in a range
    from 0 to size, and tries to retrieve the data.

    Usage

    ::

     values = get_all(store, 'torchelastic/data', 3)
     value1 = values[0] # retrieves the data for key torchelastic/data0
     value2 = values[1] # retrieves the data for key torchelastic/data1
     value3 = values[2] # retrieves the data for key torchelastic/data2

    """
    data_arr = []
    for idx in range(size):
        data = store.get(f"{prefix}{idx}")
        data_arr.append(data)
    return data_arr

def synchronize(
    store,
    data: bytes,
    rank: int,
    world_size: int,
    key_prefix: str,
    barrier_timeout: float = 300,
) -> List[bytes]:
    """
    Synchronizes ``world_size`` agents between each other using the underlying c10d store.
    The ``data`` will be available on each of the agents.

    Note: The data on the path is not deleted, as a result there can be stale data if
        you use the same key_prefix twice.
    """
    store.set_timeout(timedelta(seconds=barrier_timeout))
    store.set(f"{key_prefix}{rank}", data)
    agent_data = get_all(store, key_prefix, world_size)
    return agent_data
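
As a hedged usage sketch (the host, port, and key prefix below are hypothetical), two agents could block on this barrier through a shared c10d TCPStore like this, each running the same code with its own rank:

# Hedged sketch: synchronizing two agents through a shared c10d TCPStore.
import os
from datetime import timedelta

from torch.distributed import TCPStore
from torch.distributed.elastic.utils.store import barrier

rank = int(os.environ.get("GROUP_RANK", "0"))   # this agent's group rank (0 or 1)
store = TCPStore("127.0.0.1", 29500, 2, rank == 0, timedelta(seconds=60))
barrier(store, rank, world_size=2, key_prefix="demo/terminal_state")
print(f"agent {rank} passed the barrier")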

0x05 Error Handling

5.1 Error Types

Each host in a distributed PyTorch job runs one TorchElastic agent and multiple workers (as child processes of the agent). Since the workers are user-provided (the PyTorch script/job), TorchElastic can propagate errors through the agent up past the trainer to the scheduler, ultimately notifying the end user of the job status and applying retry policies.

TE groups errors into the following categories.

+----------------+----------------+--------------------------------------------------------------+
| Category       | Sub-Category   |  Description                                                 |
+================+================+==============================================================+
| User Error     | Input Error    | invalid inputs to TorchElastic APIs (e.g. min > max nodes)   |
|                +----------------+--------------------------------------------------------------+
|                | Worker Failure | any failures on the worker child process                     |
+----------------+----------------+--------------------------------------------------------------+
| Platform Error |      n/a       | failures caused by the agent                                 |
+----------------+----------------+--------------------------------------------------------------+
| Infra Error    |      n/a       | failures outside the domain of the agent and workers         |
|                |                | (e.g. host failures)                                         |
+----------------+----------------+--------------------------------------------------------------+

5.2 Error Handling Modes

The corresponding error handling modes are as follows, listed from the smallest to the largest failure scope:

  • User Error, which breaks down into:
    • Input Error: e.g. invalid inputs, which the program can simply catch.
    • Worker Failure:
      • Worker failures are special because the exception/failure originates in a process different from the agent, so the error needs to be propagated across processes (the agent cannot simply try-catch an exception raised in a worker process).
        • The TorchElastic agent starts workers with torch.distributed.elastic.multiprocessing.start_processes, which has a simple built-in file-based inter-process error propagation.
        • Any function or binary entrypoint decorated with record writes uncaught exceptions (with trace information) to the file named by the environment variable TORCHELASTIC_ERROR_FILE. The parent process (e.g. the agent) sets this environment variable for every child process it launches, then aggregates the error files of all children and propagates the one with the smallest timestamp (i.e. the first error); see the sketch after this list.
      • The documentation states: for a training job with "n" workers, if "k<=n" workers fail, all workers are stopped and restarted, up to "max_restarts" times. In other words, if one worker fails and the maximum number of restarts has not yet been reached, TE starts a new rendezvous round and restarts all workers; because it is a new rendezvous, the other TE agents also restart their workers.
      • A single worker's failure can bring down the whole cluster: if a single worker keeps failing, it drives the TE agent's max_restarts counter down to zero. The agent then finishes its work and shuts down the rendezvous. If there are any other workers on different agents, they are also terminated.
  • Platform Error (i.e. an agent failure)
    • All errors other than worker failures are raised normally from the agent process, crashing the agent implicitly or explicitly, so the standard exception-handling strategies provided by the language (Python) apply.
    • An agent failure can also cause the local worker group to fail. How this is handled is up to the job manager, e.g. failing the whole job (gang semantics) or attempting to replace the node. Both behaviors are supported by the agent.
  • Infra Error (i.e. a node failure): handled the same way as an agent failure.
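
Below is a minimal sketch of the record-based error propagation mentioned above (the main function and its error are made up): an uncaught exception in the decorated entrypoint is written to the file named by TORCHELASTIC_ERROR_FILE, which the agent sets for every child it launches.

# Hedged sketch of the file-based error propagation: `record` writes any uncaught
# exception (with traceback) to the file named by TORCHELASTIC_ERROR_FILE.
from torch.distributed.elastic.multiprocessing.errors import record

@record
def main():
    # hypothetical training entrypoint
    raise RuntimeError("boom")   # would be recorded and propagated to the agent

if __name__ == "__main__":
    main()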

Next, let us look specifically at how a "Worker Failure" is handled.

5.3 Handling Mechanism

The concrete error handling mechanism is as follows: if the maximum number of retries has not yet been reached, the agent tries to restart the workers; if it has been reached, the workers are stopped.

        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            # the processes errored out
            if self._remaining_restarts > 0: # retry
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group) # restart
            else:
                self._stop_workers(self._worker_group) # max retries reached, stop the workers
                self._worker_group.state = WorkerState.FAILED
                self._exit_barrier()
                return run_result

5.3.1 Restart

_restart_workers stops all workers and then starts a new round of rendezvous.

@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)

5.3.2 Stop

Stopping the workers simply closes the context.

def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()
        
@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()

In MultiprocessContext, the close method terminates all child processes and then waits for them all to stop.

    def _close(self) -> None:
        if self._pc:
            for proc in self._pc.processes:
                proc.terminate()
                proc.join()

5.4 Restart of the Other Agents

From the source comments we know that a new rendezvous round causes the other agents to restart their workers as well.

When a worker fails, TE will check the number of restarts available; if there is more than 0 restarts, TE will start a new rendezvous round and restart the worker process. A new rendezvous round will cause other TE agents to terminate their workers.

How is this achieved? The steps are as follows:

  1. Agent 0 (the faulty agent) discovers the failure through monitoring.
  2. Agent 0 calls _restart_workers to restart its workers.
  3. Agent 0 calls next_rendezvous to initiate a new rendezvous round.
  4. Before doing anything (e.g. before keep-alive operations), Agent 0 calls sync to fetch the cluster information from the KV store, ensuring it sees the latest cluster state.
  5. Agent 0 adds itself to its local waiting_list.
  6. Agent 0 also calls mark_dirty, meaning "my state has been updated and needs to be written to the KV store".
  7. Agent 0 calls sync to write its waiting_list to the KV store.
  8. Agent 1 (another, normally working agent) also calls sync before doing anything (e.g. before keep-alive operations) to fetch the latest information from the KV store.
  9. Agent 1 uses this information to update its own state, so its local waiting_list is updated.
  10. Agent 1's train loop, after its 30-second monitoring interval, sees the system as normal, i.e. in the HEALTHY state.
  11. Agent 1 therefore calls num_nodes_waiting() to check the waiting_list.
  12. Agent 1 gets the size of its local waiting list.
  13. If the waiting list is not empty, it also calls _restart_workers.
  14. That in turn eventually calls next_rendezvous.

In detail:

 Agent 0                                      Agent 1
+---------------------------+                 +--------------------------------------------+
|    _invoke_run            |                 |                       _invoke_run          |
|          +                |                 |                           +                |
|          |                |                 |                           |                |
|          | 1              |                 |                           |                |
|          v                |                 |                           |                |
| Worker Process Error      |                 |                           |                |
|          +                |                 |                           |                |
|          |                |                 |                           | 10             |
|          | 2              |                 |                           v                |
|          v                |                 |                        HEALTHY             |
|  _restart_workers         |                 |                           +                |
|          +                |                 |                           | 11             |
|          |                |                 |                           |                |
|          | 3              |                 |                           v                |
|          v                |                 |              +-->  num_nodes_waiting() > 0 |
|   next_rendezvous         |                 |              |            +                |
|          +                |                 |              |            |                |
|          | 4              |                 |              | 12         | 13             |
|          |                +   +----------+  |              |            v                |
|          v      cluster info  |          |  |              |       _restart_workers      |
|        sync  <------------+-> | KV Store |  |              |            +                |
|          +                |   |          |  |              |            |                |
|          | 5              |   |          |  |              |            | 14             |
|          v                |   |          |  |              |            v                |
|  Add to local waiting_list|   |          |  |              |        next_rendezvous      |
|          +                |   |          |  |              |                             |
|          |                |   |          |  |              |                             |
|          | 6              |   |          |  |              v                             |
|          v                |   |          |  |                                            |
|     mark_dirty            |   |          |  |  Add to local waiting_list                 |
|          +                |   |          |  |              ^                             |
|          |                |   |          |  |              |                             |
|          | 7              |   |          |  |            9 | waiting_list                |
|          v         7      |   |          |  |    8         +                             |
|        sync +---------------> |          +--------------> sync                           |
|              waiting_list |   |          |  |waiting_list                                |
|                           |   +----------+  |                                            |
+---------------------------+                 +--------------------------------------------+


This concludes our initial introduction to the monitoring mechanism. Due to space constraints, the next article will continue with how scale up/down is handled.

0xFF References

雲原生的彈性 AI 訓練系列之二:PyTorch 1.9.0 彈性分散式訓練的設計與實現

PyTorch Elastic原始碼閱讀
