[原始碼解析] PyTorch 分散式之彈性訓練(7)---節點變化

羅西的思考發表於2021-12-31

[原始碼解析] PyTorch 分散式之彈性訓練(7)---節點變化

0x00 摘要

本文分析如何處理節點變化。即對成員更改作出反應,並使用新的成員來重啟所有workers,從而實現彈性訓練。

總體思路是和當工作程式失敗時的處理一樣:相應elastic agent將殺死該節點上的所有工作程式,與其他代理建立會合(rendezvous),並使用新的會合(rendezvous)資訊重新啟動所有工作程式。

彈性訓練系列文章如下:

[原始碼解析] PyTorch 分散式之彈性訓練(1) --- 總體思路

[原始碼解析] PyTorch 分散式之彈性訓練(2)---啟動&單節點流程

[原始碼解析] PyTorch 分散式之彈性訓練(3)---代理

[原始碼解析] PyTorch 分散式之彈性訓練(4)---Rendezvous 架構和邏輯

[原始碼解析] PyTorch 分散式之彈性訓練(5)---Rendezvous 引擎

[原始碼解析] PyTorch 分散式之彈性訓練(6)---監控/容錯

0x01 變化方式

節點變化有兩點方式。

1.1 Scale-down

節點離開(scale-down)的處理如下:

  • 當Scale down事件發生時,rendezvous將不會通知 torchelastic agent。
  • torchelastic agent 自己會監控到有程式錯誤,從而進行處理。
  • 如果TE agent以max_restarts=0配置啟動,它依賴於底層排程程式來處理作業重新啟動。
  • 如果max_restarts>0,TE代理將終止workers並開始新一輪rendezvous。
    • 代理得到離開的通知,於是現有workers(所有節點上的)都全部停止。
    • 這些workers將形成一個新的“WorkerGroup”,所有worker都將以新的RANKWORLD_SIZE 執行。

1.2 Scale-up

節點加入(scale-up)的處理如下:

  • 當Scale up事件發生時,新節點被提交到作業,torchelastic rendezvous將檢測到有新節點試圖加入。
    • 如果rendezvous已經達到最多節點數,新節點將不會新增到等待列表,因為已經滿了,所以沒有必要拆除已經完全體的rendezvous。新節點將一直等待直到超時(預設為600秒)。
    • 新節點將定期檢查參與節點數目。如果數目變為小於max_nodes,等待節點將被加入到等待列表中。否則它將在600秒之後超時。
  • 當代理決定處理 Scale up時:
    • torchelastic rendezvous將停止所有workers並執行新一輪的 re-rendezvous。
    • 這些workers(現有以及新加入的)將形成一個新的“WorkerGroup”,所有worker都將以新的RANKWORLD_SIZE 執行。

注:scale up發生時,max_restarts 將不會減少。

0x02 節點加入

2.1 新節點加入

假設目前已經有了一個彈性訓練叢集正在執行,彈性區間為 (min=1, max=4)。目前已經有2個節點在執行,使用者想啟動第三個節點,於是使用如下方法啟動一個新程式。

python -m torch.distributed.run
        --nnodes=1:4
        --nproc_per_node=$NUM_TRAINERS
        --rdzv_id=$JOB_ID
        --rdzv_backend=c10d
        --rdzv_endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

新程式會啟動一個代理。代理經過一系列操作,呼叫 next_rendezvous,其中啟動一個 ExitOp,一個 JoinOp 。

def next_rendezvous(self) -> Tuple[Store, int, int]:
    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()
    
    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)    

2.2 處理 Join 操作

以下操作是在 _DistributedRendezvousOpExecutor 之中。

有了前文分析,我們知道,業務流程是 run 呼叫 Join 運算元來分析出來下一個 Action,然後根據 Action 來執行對應的業務操作

2.2.1 run處理

_DistributedRendezvousOpExecutor.run 函式實現了基礎邏輯,就是依據 action 型別進行各種操作。對於我們示例,state_handler 就是_RendezvousJoinOp。

    def run(
        self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
    ) -> None:
        """See base class."""
        action = None

        while action != _Action.FINISH: # 一直迴圈,直到結束
            
            # 這裡很重要,在所有node之間做資訊同步
            has_set = self._state_holder.sync() # 因為最新狀態在 rendezvous。
            self._state = self._state_holder.state
            # 利用最新狀態構建了 ctx
            ctx = _RendezvousContext(self._node, self._state, self._settings)

            # Determine the next action to take based on the current state of
            # the rendezvous.
            action = state_handler(ctx, deadline) # 呼叫_RendezvousJoinOp,決定下一個操作

            # 省略後續部分

2.2.2 Join操作

因為之前做了同步,所以這裡的ctx就包括了最新的state,這就是Rendezvous的全域性狀態。因為此時,Rendezvous 已經結束了,所以 state 的狀態是 complete,進入如下流程,返回 _Action.ADD_TO_WAIT_LIST。

    if state.complete:
        # If we are here, it means we are not part of the rendezvous. In
        # case the rendezvous has capacity for additional participants add
        # ourself to the wait list for the next round.
        if len(state.participants) < ctx.settings.max_nodes: # 如果當前節點數目小於最大配置
            if ctx.node not in state.wait_list: # 如果當前node不在等待列表之中
                return _Action.ADD_TO_WAIT_LIST  # 傳送一個等待action

總體程式碼如下:

class _RendezvousJoinOp:
    """Represents a rendezvous join operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        state = ctx.state # 從上下文之中提取 _RendezvousState 狀態

        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            return _Action.ERROR_CLOSED # 如果已經結束,就返回 _Action.ERROR_CLOSED

        is_participant = ctx.node in state.participants # 看看是參與者

        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant: # 如果是參與者且狀態結束,就返回 _Action.FINISH
            return _Action.FINISH

        now = time.monotonic()
        if now > deadline: # 如果已經超時
            rollback_period = 5  # 5 seconds

            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period: # 如果還有時間來 rollback
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant: # 已經是參與者了
                    return _Action.REMOVE_FROM_PARTICIPANTS # 需要從參與者列表移除
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list: # 已經在等待列表之中
                    return _Action.REMOVE_FROM_WAIT_LIST # 需要從等待列表移除
            return _Action.ERROR_TIMEOUT # 返回超時

        if state.complete: # 如果 rendezvous 已經結束
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes: # 如果還沒有達到最大節點數
                if ctx.node not in state.wait_list: # 如果當前node不在等待列表之中
                    return _Action.ADD_TO_WAIT_LIST # 就加入到等待列表,傳送一個等待action
        elif is_participant: # 如果已經在參與者列表
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes: # 如果達到了最小節點數
                if cast(datetime, state.deadline) < datetime.utcnow(): # 如果達到了超時
                    return _Action.MARK_RENDEZVOUS_COMPLETE # 標示 rendezvous 已經結束
        else: # 否則就直接加入到參與者
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS

        if _should_keep_alive(ctx): # 如果需要保持心跳,就返回 _Action.KEEP_ALIVE
            return _Action.KEEP_ALIVE

        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC # 否則返回同步狀態 _Action.SYNC

2.2.3 等待業務操作

_DistributedRendezvousOpExecutor 之中,run 函式實現了基礎邏輯,就是依據 action 型別進行各種操作。

    def run(
        self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
    ) -> None:
        """See base class."""
        action = None

        while action != _Action.FINISH: # 一直迴圈,直到結束
     
            # 這裡很重要,在所有node之間做資訊同步
            has_set = self._state_holder.sync() # 因為最新狀態在 rendezvous。
            self._state = self._state_holder.state
					  # 使用最新state構建ctx
            ctx = _RendezvousContext(self._node, self._state, self._settings)

            # Determine the next action to take based on the current state of
            # the rendezvous.
            action = state_handler(ctx, deadline) # 呼叫_RendezvousJoinOp,決定下一個操作,這裡得到了 _Action.ADD_TO_WAIT_LIST

            if action == _Action.SYNC:
                _delay(seconds=1)
            else:
                if action == _Action.KEEP_ALIVE:
                    self._keep_alive()
                elif action == _Action.ADD_TO_WAIT_LIST: # 從 Join 運算元得到了_Action.ADD_TO_WAIT_LIST
                    self._add_to_wait_list() # 進行業務邏輯
                # 省略其他action

                # Attempt to sync our changes back to other nodes.
                self._state_holder.mark_dirty() # 同步回其他節點

具體處理等待操作就是加入到等待列表。

def _add_to_wait_list(self) -> None:
    self._state.wait_list.add(self._node)
    self._keep_alive()

我們回憶一下 _RendezvousState。_RendezvousState 是rendezvous的狀態。是動態資訊。

  • round:Rendezvous的當前輪次
  • complete:一個布林值,指示rendezvous當前一輪是否完成了。
  • deadline:截止時間,如果如果當前輪次一直在等待節點加入,如果這個引數設定了,就是等待的截至時間。
  • closed:一個布林值,指示rendezvous是否結束了。
  • participants:字典,存放參與者和它們對應ranks。
  • wait_list:set結構,存放等待參與下一輪rendezvous操作的一組節點
  • last_heartbeats:字典,包含每個節點上次心跳時間。
class _RendezvousState:
    round: int
    complete: bool
    deadline: Optional[datetime]
    closed: bool
    participants: Dict[_NodeDesc, int] # 參與者,未來會用到的成員變數
    wait_list: Set[_NodeDesc]  # 等待者,這裡用到的成員變數
    last_heartbeats: Dict[_NodeDesc, datetime]

    def __init__(self) -> None:
        self.round = 0
        self.complete = False
        self.deadline = None
        self.closed = False
        self.participants = {}
        self.wait_list = set() # 這裡用到的成員變數
        self.last_heartbeats = {}

目前邏輯如下:

  1. 啟動一個新 worker。此時下圖右側上方的 _RendezvousState 之中,wait_list 為空。
  2. 呼叫 next_rendezvous,發起新一輪 rendezvous。
  3. _RendezvousJoinOp 內部執行,生成 ADD_TO_WAIT_LIST。
  4. executor . run 內部執行 _add_to_wait_list。
  5. 往 wait_list 新增一個新的 node。此時下圖右側上方的 _RendezvousState 之中,wait_list 多了一個 1。
  python -m torch.distributed.run             +-------------------------+     +
      --nnodes=xxx TRAINING_SCRIPT.py         | _RendezvousState        |     |
                 +                            |                         |     |
                 |                            |    participants = [1,2] |     |
                 | 1                          |                         |     |
                 v                            |    wait_list = []       |     |
          next_rendezvous                     |                         |     |
                 +                            +------------+------------+     |
                 | 2                                       |                  |
                 |                                         |                  |
                 v                                         |                  |
+----------------+-----------------------+                 |                  |
| _op_executor.run(_RendezvousJoinOp)    |                 |                  |
|           +              +             |                 |                  |
|           |              | 3           |                 |                  |
|           |              |             |                 |                  |
|           |              v             |                 |                  |
|           |   _Action.ADD_TO_WAIT_LIST |                 v                  |
|           |              +             |                                    |
|           |              |             |    +--------------------------+    |
|           +<-------------+             |    | _RendezvousState         |    |
|           |                            |    |                          |    |
|           |                            |    |    participants = [1,2]  |    |
|           v       4                    | 5  |                          |    |
|      self._add_to_wait_list() +----------------> wait_list = [3]       |    |
|                                        |    |                          |    |
+----------------------------------------+    +--------------------------+    |
                                                                              |
                                                                              v

                                                                         Timeline

2.3 Agent 處理

_DistributedRendezvousOpExecutor . run 處理之後,操作回到了代理之中。代理主迴圈之中,程式會進入 while 迴圈,然後通過 _monitor_workers 定期輪訓使用者程式執行情況,依據情況作出判斷。

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        self._initialize_workers(self._worker_group) # 啟動worker
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            # 定期監控
            time.sleep(monitor_interval)
            # 監控客戶程式執行情況
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state # 程式執行情況
            self._worker_group.state = state

            if state == WorkerState.SUCCEEDED:
                # 程式正常結束
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # 程式出錯
                if self._remaining_restarts > 0: # 重試
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group) # 重試次數達到,結束workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
								# 程式正常執行
                # 節點成員關係有變化,比如scale up
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # 如果有新的節點在waiting,就重啟所有workers
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

所以,代理定期執行 _monitor_workers 監控worker執行情況才是關鍵。run_result.state 是程式執行情況,當狀態是 WorkerState.HEALTHY,說明原有程式正常執行,接下來看看節點成員關係是否有變化。

呼叫 rdzv_handler.num_nodes_waiting() 拿到等待列表數目,如果有新的節點在waiting,就說明有新的節點試圖加入叢集,這時就會發生一個Re-rendezvous。代理將重啟所有workers。重啟時候,會把等待列表中的節點加入到參與列表之中。我們依次看看如何處理。

2.3.1 檢查等待列表

處理時候,首先會呼叫 num_nodes_waiting 看看還有多少節點在等待,具體是看看 state.wait_list 的長度。我們通過之前 Join 操作知道,如果有新節點,會插入到這個列表之中。

num_nodes_waiting 方法的作用是 返回在 rendezvous barrier 上等待的節點數目(這些節點不會在當前工作組被包括)。呼叫者應該週期呼叫這個方法,來確定是否有新節點等候加入當前工作組,因此需要呼叫next_rendezvous() 來提交他們。

def num_nodes_waiting(self) -> int:
    """See base class."""
    with self._heartbeat_lock:
        self._state_holder.sync()

        return len(self._state_holder.state.wait_list)

目前邏輯如下:

  1. 啟動一個新 worker。
  2. 呼叫 next_rendezvous,發起新一輪 rendezvous。
  3. _RendezvousJoinOp 內部執行,生成 ADD_TO_WAIT_LIST。
  4. executor.run 內部執行 _add_to_wait_list。
  5. 往 wait_list 新增一個新的 node。
  6. Agent 之中,定期(比如 30S)執行一次 _monitor_workers,獲取worker 子程式狀態。
  7. 如果是 HEALTHY,則呼叫num_nodes_waiting 獲取 wait_list 個數。
  8. 如果 wait_list 之中等待節點數目大於 0,則:
  9. 呼叫 _restart_workers 重啟程式組。
  python -m torch.distributed.run             +-------------------------+     +
      --nnodes=xxx TRAINING_SCRIPT.py         | _RendezvousState        |     |
                 +                            |                         |     |
                 |                            |    participants = [1,2] |     |
                 | 1                          |                         |     |
                 v                            |    wait_list = []       |     |
          next_rendezvous                     |                         |     |
                 +                            +------------+------------+     |
                 | 2                                       |                  |
                 |                                         |                  |
                 v                                         |                  |
+----------------+-----------------------+                 |                  |
| _op_executor.run(_RendezvousJoinOp)    |                 |                  |
|           +              +             |                 |                  |
|           |              | 3           |                 |                  |
|           |              |             |                 |                  |
|           |              v             |                 |                  |
|           |   _Action.ADD_TO_WAIT_LIST |                 v                  |
|           |              +             |                                    |
|           |              |             |    +--------------------------+    |
|           +<-------------+             |    | _RendezvousState         |    |
|           |                            |    |                          |    |
|           |                            |    |    participants = [1,2]  |    |
|           v       4                    | 5  |                          |    |
|      self._add_to_wait_list() +----------------> wait_list = [3]       |    |
|                                        |    |                          |    |
+----------------------------------------+    +------------+-------------+    |
                                                           |                  |
+----------------------------------------+                 |                  |
| agent._invoke_run                      |                 |                  |
|                                        |                 |                  |
|                                        |                 |                  |
|        _monitor_workers Every 30S      |                 |                  |
|                +                       |                 |                  |
|                | 6                     |                 |                  |
|                |                       |                 v                  |
|                v                       |                                    |
|         WorkerState.HEALTHY            |     +--------------------------+   |
|                +                       |     | _RendezvousState         |   |
|                |                       |     |                          |   |
|                | 7                     |     |     participants = [1,2] |   |
|                v                       |  8  |                          |   |
|        num_nodes_waiting   <-------------------->  wait_list = [3]      |   |
|                +                       |     |                          |   |
|                | 9                     |     |                          |   |
|                |                       |     +--------------------------+   |
|                v                       |                                    |
|        _restart_workers                |                                    v
|                                        |
+----------------------------------------+                               Timeline

2.3.3 重啟worker組

如果等待列表之中有節點,就會重啟workers。我們走一下這個流程。

@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)
2.3.3.1 _stop_workers

首先會停止目前 workers,程式碼在torch/distributed/elastic/agent/server/local_elastic_agent.py。

@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()
2.3.3.2 _shutdown

_shutdown 就是讓上下文關閉。

def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()
2.3.3.3 關閉上下文

在 MultiprocessContext 之中,close 方法是關閉所有子程式,然後等待其全部停止。

    def _close(self) -> None:
        if self._pc:
            for proc in self._pc.processes:
                proc.terminate()
                proc.join()
2.3.3.4 _initialize_workers

當關閉了所有當前執行的子程式之後,會重新全部初始化。

@prof
def _initialize_workers(self, worker_group: WorkerGroup) -> None:
    r"""
    Starts a fresh set of workers for the worker_group.
    Essentially a rendezvous followed by a start_workers.

    The caller should first call ``_stop_workers()`` to stop running workers
    prior to calling this method.

    Optimistically sets the state of the worker group that
    just started as ``HEALTHY`` and delegates the actual monitoring
    of state to ``_monitor_workers()`` method
    """
    role = worker_group.spec.role

    # TODO after stopping workers, wait at least monitor_interval*2 for
    # workers on different nodes to fail on a collective op before waiting
    # on the rdzv barrier, this way we ensure that nodes enter rdzv
    # at around the same time and reduce false positive rdzv timeout errors
    self._rendezvous(worker_group)

    worker_ids = self._start_workers(worker_group)
    for local_rank, w_id in worker_ids.items():
        worker = worker_group.workers[local_rank]
        worker.id = w_id

    worker_group.state = WorkerState.HEALTHY

_rendezvous經過一系列操作,呼叫 next_rendezvous,在其中啟動一個 ExitOp,一個 JoinOp 。

def next_rendezvous(self) -> Tuple[Store, int, int]:

    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()
    
    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)    
2.3.3.5 _RendezvousJoinOp

我們又回來了,這是新一輪 Rendezvous 操作。_DistributedRendezvousOpExecutor 之中,run 函式實現了基礎邏輯,就是依據 action 型別進行各種操作。對於我們示例,state_handler 就是_RendezvousJoinOp

def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None

    while action != _Action.FINISH:
        # Reads or writes the latest rendezvous state shared by all nodes in
        # the rendezvous. Note that our local changes might get overridden
        # by another node if that node synced its changes before us.
        has_set = self._state_holder.sync()
        self._state = self._state_holder.state
        ctx = _RendezvousContext(self._node, self._state, self._settings)

        # Determine the next action to take based on the current state of
        # the rendezvous.
        # 呼叫到_RendezvousJoinOp,大家可以過一下 _RendezvousJoinOp 程式碼,發現此時將返回 ADD_TO_PARTICIPANTS
        action = state_handler(ctx, deadline) 

        if action == _Action.SYNC:
            # Delay the execution by one second to avoid overloading the
            # backend if we are asked to poll for state changes.
            _delay(seconds=1)
        else:
            if action == _Action.KEEP_ALIVE:
                self._keep_alive()
            elif action == _Action.ADD_TO_PARTICIPANTS: # 執行到這裡
                self._add_to_participants()
            elif action == _Action.ADD_TO_WAIT_LIST:
                self._add_to_wait_list()
            elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                self._remove_from_participants()
            elif action == _Action.REMOVE_FROM_WAIT_LIST:
                self._remove_from_wait_list()
            elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                self._mark_rendezvous_complete()
            elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                self._mark_rendezvous_closed()

            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty()

這次會生成 ADD_TO_PARTICIPANTS。

class _RendezvousJoinOp:
    """Represents a rendezvous join operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        state = ctx.state # 從上下文之中提取 _RendezvousState 狀態

        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            return _Action.ERROR_CLOSED # 如果已經結束,就返回 _Action.ERROR_CLOSED

        is_participant = ctx.node in state.participants # 看看是參與者

        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant: # 如果是參與者且狀態結束,就返回 _Action.FINISH
            return _Action.FINISH

        now = time.monotonic()
        if now > deadline: # 如果已經超時
            rollback_period = 5  # 5 seconds

            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period: # 如果還有時間來 rollback
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant: # 已經是參與者了
                    return _Action.REMOVE_FROM_PARTICIPANTS # 需要從參與者列表移除
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list: # 已經在等待列表之中
                    return _Action.REMOVE_FROM_WAIT_LIST # 需要從等待列表移除
            return _Action.ERROR_TIMEOUT # 返回超時

        if state.complete: # 如果 rendezvous 已經結束
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes: # 如果還沒有達到最大節點數
                if ctx.node not in state.wait_list: # 如果當前node不在等待列表之中
                    return _Action.ADD_TO_WAIT_LIST # 就加入到等待列表,傳送一個等待action
        elif is_participant: # 如果已經在參與者列表
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes: # 如果達到了最小節點數
                if cast(datetime, state.deadline) < datetime.utcnow(): # 如果達到了超時
                    return _Action.MARK_RENDEZVOUS_COMPLETE # 標示 rendezvous 已經結束
        else: # 否則就直接加入到參與者
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS

        if _should_keep_alive(ctx): # 如果需要保持心跳,就返回 _Action.KEEP_ALIVE
            return _Action.KEEP_ALIVE

        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC # 否則返回同步狀態 _Action.SYNC
2.3.3.6 _add_to_participants

引擎收到 ADD_TO_PARTICIPANTS 之後,會呼叫 _add_to_participants 從 wait_list 移除節點,插入到 participants。

def _add_to_participants(self) -> None:
    log.debug(
        f"The node '{self._node}' added itself to the participants of round "
        f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
    )

    state = self._state
    state.wait_list.remove(self._node) # 移除節點

    # The ranks of the participants will be set once the rendezvous is
    # complete.
    state.participants[self._node] = 0 # 重新插入

    self._keep_alive()

    if len(state.participants) == self._settings.min_nodes:
        state.deadline = datetime.utcnow() + self._settings.timeout.last_call

    if len(state.participants) == self._settings.max_nodes:
        self._mark_rendezvous_complete()

我們這次從 _restart_workers 開始繪製。

  1. 呼叫 _stop_workers 來關閉worker子程式。此時下圖右側上方 _RendezvousState之中,participants=[1,2]。
  2. 通過 MultiprocessContext.close() 完成關閉操作。
  3. 通過 _initialize_workers 重新初始化 worker。
  4. 呼叫 next_rendezvous 完成新的同步操作。
  5. _RendezvousJoinOp 這次返回ADD_TO_PARTICIPANTS。
  6. 呼叫 _add_to_participants 進行狀態切換。
  7. wait_list 之中的Node被移動到 participants。此時下圖右側上方 _RendezvousState之中,participants=[1,2,3]。
                         +-----------------------------+   +------------------------+  |
                         |  agent._invoke_run          |   | _RendezvousState       |  |
                         |                             |   |                        |  |
                         |       _restart_workers      |   |   participants = [1,2] |  |
                         |              +              |   |                        |  |
+----------------------+ |              |              |   |   wait_list = [3]      |  |
| MultiprocessContext  | |              | 1            |   |                        |  |
|                      | | 2            v              |   +------------------------+  |
|        close()  <-----------+  _stop_workers         |                               |
|                      | |              +              |                               |
+----------------------+ |              |              |                               |
                         |              | 3            |                               |
                         |              v              |                               |
                         |     _initialize_workers     |                               |
                         |              +              |                               |
                         |              |              |                               |
                         +-----------------------------+                               |
                                        |                                              |
                                        | 4                                            |
                                        v                                              |
                                 next_rendezvous                                       |
                                        +                                              |
                                        |                                              |
                                        v                                              |
            +---------------------------+---------------+                              |
            | _op_executor.run(_RendezvousJoinOp)       |                              |
            |           +               +               |                              |
            |           |               |               |                              |
            |           |               | 5             |                              |
            |           |               v               |                              |
            |           |       ADD_TO_PARTICIPANTS     |                              |
            |           |               +               |   +-----------------------+  |
            |           |               |               |   | _RendezvousState      |  |
            |           | <-------------+               |   |                       |  |
            |           |                               |   | participants = [1,2,3]|  |
            |           v     6                  7      |   |                       |  |
            |        _add_to_participants  +--------------> | wait_list = []        |  |
            |                                           |   |                       |  |
            +-------------------------------------------+   +-----------------------+  v

                                                                                 Timeline


0x03 節點離開

3.1 處理機制

節點離開(scale-down)的處理如下:

  • 當Scale down事件發生時,rendezvous將不會通知 torchelastic agent。
  • 如果TE agent以“max_restarts=0”啟動,它依賴於底層排程程式來處理作業重新啟動。
  • 如果“max_restarts>0”,TE代理將終止workers並開始新一輪rendezvous。
    • 代理得到離開的通知,於是現有workers(所有節點上)都全部停止。
    • 這些workers將形成一個新的“WorkerGroup”,所有worker都將以新的RANKWORLD_SIZE 執行。、

3.2 如何模擬

如果想模擬除錯的同學,可以在 test/distributed/elastic/agent/server/test/local_elastic_agent_test.py 之中找到示例程式碼。

def test_double_agent_elastic(self):
    """
    start ``nnodes`` agents, kill odd ones (do not restart), validate
    elasticity (scale-down) works. (scale-up covered in fault_tolerance test)
    """
    min_nodes = 1
    max_nodes = 2
    wait = 2
    node_conf = Conf(entrypoint=_dist_sum, args=(wait,), local_world_size=2)
    agent_results = mp.Queue()
    agent_args = {
        "conf": node_conf,
        "agent_results": agent_results,
        "min_nodes": min_nodes,
        "max_nodes": max_nodes,
        "max_restarts": 2,
    }

    procs = []
    for _ in range(max_nodes):
        p = mp.Process(
            target=self.run_agent,
            kwargs=agent_args,
        )
        procs.append(p)
        p.start()

    # kill odd agents
    for i in range(max_nodes):
        if i % 2 != 0:
            procs[i].kill()

    for i in range(max_nodes):
        p = procs[i]
        p.join()
        if i % 2 == 0:
            self.assertEqual(0, p.exitcode)
        else:
            self.assertEqual(-signal.SIGKILL, p.exitcode)

3.3 如何處理

節點離開,與錯誤處理是同一個程式碼。錯誤處理程式碼如下,如果重試尚未達到最大次數,則試圖重啟workers。如果已經達到了最大次數,則停止 workers。

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        
        # 省略
     
        while True:

            # 定期監控
            time.sleep(monitor_interval)
            # 監控客戶程式執行情況
            run_result = self._monitor_workers(self._worker_group)
            
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            # 程式出錯
            
            if self._remaining_restarts > 0: # 重試
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group) # 進行重啟
            else:
                self._stop_workers(self._worker_group) # 重試次數達到,結束workers
                self._worker_group.state = WorkerState.FAILED
                self._exit_barrier()
                return run_result

3.3.1 重啟

_restart_workers 會停掉所有 workers,然後重新一輪 rendezvous 。

@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)

3.3.2 停止

停止 workers 就是關閉上下文。

def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()
        
@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()

在 MultiprocessContext 之中,close 方法是關閉所有子程式,然後等待其全部停止。

    def _close(self) -> None:
        if self._pc:
            for proc in self._pc.processes:
                proc.terminate()
                proc.join()

流程圖如下:

  1. 監控子程式狀態。
  2. 發現 UNHEALTHY 或者 FAILED,看看重啟次數是否還有。我們假定是3號程式失敗。
  3. 如果沒有,就呼叫 _stop_workers 結束子程式。
  4. 呼叫 MultiprocessContext.close 進行具體結束操作。
  5. 如果還可以重啟,呼叫_restart_workers。
  6. 呼叫 _stop_workers 結束子程式。
  7. 呼叫 MultiprocessContext.close 進行具體結束操作。
  8. 呼叫 _initialize_workers 重新初始化worker。
  9. 呼叫 next_rendezvous 重新同步。
  10. 進行後續操作。
                                                                                 +
+-------------------------------------------+    +---------------------------+   |
| agent._invoke_run                         |    | _RendezvousState          |   |
|                                           |    |                           |   |
|                                           |    |                           |   |
|     _monitor_workers Every 30S            |    |    participants = [1,2,3] |   |
|             +                             |    |                           |   |
|             | 1                           |    |    wait_list = [ ]        |   |
|             |                             |    |                           |   |
|             v                             |    +---------------------------+   |
|     WorkerState.UNHEALTHY,FAILED          |                                    |
|             +                             |                                    |
|             |                             |                                    |
|             | 2                           |                                    |
|             v                             |                                    |
|   self._remaining_restarts > 0 ? +--+     |                                    |
|             +                       |     |                                    |
|          5  | YES                NO | 3   |                                    |
|             |                       |     |                                    |
|             v                       v     |    +----------------------+        |
|     _restart_workers        _stop_workers |    | MultiprocessContext  |        |
|             +                       +     |    |                      |        |
|             | 6                     |  4  |    |                      |        |
|             |                       +--------> |                      |        |
|             v                             |    |        close()       |        |
|      _stop_workers +-------------------------> |                      |        |
|             +                 7           |    +----------------------+        |
|             |                             |                                    |
|             | 8                           |                                    |
|             v                             |                                    |
|    _initialize_workers                    |                                    |
|             +                             |                                    |
|             |                             |                                    |
+-------------------------------------------+                                    |
              | 9                                                                |
              |                                                                  |
              v                                +--------------------------+      |
        next_rendezvous                        | _RendezvousState         |      |
              +                                |                          |      |
              |               10               |     participants = [1,2] |      |
              +---------------------------->   |                          |      |
              |                                |     wait_list = [ ]      |      v
              | 10                             +--------------------------+
              v                                                             Timeline

至此,彈性訓練全部分析完畢,或者說PyTorch分散式分析就告一段落,我們下文會介紹其他框架/庫的分散式實現,敬請期待。

0xFF 參考

[原始碼解析] PyTorch 分散式之彈性訓練(1) --- 總體思路

[原始碼解析] PyTorch 分散式之彈性訓練(2)---啟動&單節點流程

[原始碼解析] PyTorch 分散式之彈性訓練(3)---代理

[原始碼解析] PyTorch 分散式之彈性訓練(4)---Rendezvous 架構和邏輯

[原始碼解析] PyTorch 分散式之彈性訓練(5)---Rendezvous 引擎

相關文章