[原始碼解析] PyTorch 分散式之彈性訓練(3)---代理

羅西的思考發表於2021-12-25

[原始碼解析] PyTorch 分散式之彈性訓練(3)---代理

0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分散式的基本模組,介紹了官方的幾個例子,我們接下來會介紹PyTorch的彈性訓練,本文是第三篇,看看彈性代理的基本功能。

彈性訓練系列文章如下:

[原始碼解析] PyTorch 分散式之彈性訓練(1) --- 總體思路

[原始碼解析] PyTorch 分散式之彈性訓練(2)---啟動&單節點流程

0x01 總體背景

我們先總述一下,TE 最重要的是 Agent 和 Rendezvous 這兩個概念。

  • Agent是執行在單節點上的獨立後臺程式,可以認為是 worker manager 或者 process supervisor,其負責啟動worker,監控 worker 執行,捕獲woker異常,通過 rendezvous 實現 worker 間的相互發現,當有成員變動時候負責基於 rendezvous 進行變更同步。
  • 為了實現彈性訓練,需要有一個節點/程式之間彼此發現的機制。rendezvous就是這個發現機制或者說同步元件。當系統啟動或者成員變更時候,所有worker會(重新)集合(rendezvous)以建立一個新的程式組。

1.1 功能分離

TE 是圍繞在 Rendezvous 基礎之上的多個elastic agent構成,這是一種功能分離,讓我們對比一下看看。

  • Agent 偏重具體節點上的邏輯
    • Agent 負責具體業務邏輯相關操作,比如啟動程式執行使用者程式,監控使用者程式執行情況,如果有異常就通知 Rendezvous。
    • Agent 是一個 worker manager,負責啟動/管理 workers 程式,組成一個 worker group,監控 workers 執行狀態,捕獲失效 workers,如果有故障/新加入worker,則重啟 worker group。
    • Agent負責維護 WORLD_SIZE 以及 RANK 資訊。使用者不需要再手動提供,Agent會自動處理這些。
    • Agent 是具體節點上的後臺程式,是獨立個體。Agent自己無法實現整體上的彈性訓練,所以需要一個機制來完成 worker 之間的相互發現,變更同步等等(WORLD_SIZE 和 RANK 這些資訊其實也需要多個節點同步才能確定),這就是下面的 Rendezvous 概念。
  • Rendezvous 負責叢集邏輯,保證節點之間對於""有哪些節點參與訓練"達成強一致共識。
    • 每一個 Agent 內部包括一個 Rendezvous handler,這些 handler 總體上構成了一個 Rendezvous 叢集,從而構成了一個 Agent 叢集。
    • Rendezvous 完成之後,會建立一個共享鍵值儲存(shared key-value store),這個store實現了一個torch.distributed.Store API。此儲存僅由已完成Rendezvous的成員共享,它旨在讓Torch Distributed Elastic在初始化作業過程之中交換控制和資料資訊。
    • Rendezvous 負責在每個agent之上維護當前 group 所有相關資訊。每個 agent 之上有一個 rendezvous,它們會互相通訊,總體維護一套資訊,這些資訊儲存在上面提到的Store 之中。
    • Rendezvous 負責叢集邏輯相關,比如新加入節點,移除節點,分配rank等等。

我們首先從原始碼中取出示意圖看看,大家先有一個總體概念。

1.2 Rendezvous

我們本文只是簡單介紹一下 rendezvous,重點在於介紹 agent。

在 Torch Distributed Elastic 上下文之中,我們使用 rendezvous 這個術語來特指一個特定功能:一個結合了對等發現(peer discovery)的分散式同步(distributed synchronization)原語。

Rendezvous 被Torch Distributed Elastic用來收集一個訓練job的參與者(節點),這樣,參與者們可以商議得到參與者列表和每個參與者的角色,也可以對訓練何時開始/恢復做出一致的集體決定。

Rendezvous 把功能分割解耦,業務邏輯被抽象成為一系列運算元,比如 _RendevzousJoinOp。而 Rendezvous 內部維護了一套狀態機,由運算元決定下一步操作。比如 _RendezvousOpExecutor 來執行各種運算元,依據運算元結果得到下一步應該執行的 Action,從而對本身進行操作。

比如在 _DistributedRendezvousOpExecutor 之中,如果發現了當前 action 是 ADD_TO_WAIT_LIST,會執行 _add_to_wait_list,進而呼叫 self._state.wait_list.add(self._node)

if action == _Action.KEEP_ALIVE:
    self._keep_alive()
elif action == _Action.ADD_TO_PARTICIPANTS:
    self._add_to_participants()
elif action == _Action.ADD_TO_WAIT_LIST: # 發現當前Action
    self._add_to_wait_list() # 然後執行
elif action == _Action.REMOVE_FROM_PARTICIPANTS:
    self._remove_from_participants()
elif action == _Action.REMOVE_FROM_WAIT_LIST:
    self._remove_from_wait_list()
elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
    self._mark_rendezvous_complete()
elif action == _Action.MARK_RENDEZVOUS_CLOSED:
    self._mark_rendezvous_closed()

0x02 Agent 總體邏輯

2.1 功能

Elastic agent 是 torchelastic 的控制檯(control plane),他是一個獨立程式,負責啟動和管理底層 worker 程式,代理具體負責:

  • 與PyTorch原生分散式協同工作:使每個worker都能獲得所有需要的資訊,以便成功呼叫 torch.distributed.init_process_group()
  • 容錯:監控每個worker,當出現錯誤或者異常時能及時終止所有worker並重啟它們。
  • 彈性:對成員更改作出反應,並使用新的成員來重啟所有workers。

下圖來自知乎,算是對上一個圖的細化。

img

2.2 工作基礎

Torchelast agent 和 使用者worker 依據故障切換契約來工作:

  • TE(torchelastic)希望使用者worker以5分鐘為誤差完成工作。
  • 設計DDP應用程式時,最好讓所有worker都失敗,而不只是一個worker失敗。
  • TE不會在代理之間同步重啟次數。
  • TE re-rendezvous不會減少重啟次數。
  • 當單個代理完成其工作(成功或失敗)時,它將關閉rendezvous。如果其他代理仍有worker在工作,他們將被終止。
  • 基於上述情況,如果至少有一個代理完成了任務,則縮容(scale down)不起作用。
  • 當代理檢測到Scale up時,它不會減少 "max_restarts"。
  • Torchelast agent 之間通過etcd或者類似後端來保持協同工作。

2.3 部署

簡單的agent部署在每個節點上,並與本地程式協同工作。更高階的agent可以遠端啟動和管理workers。Agent可以做到徹底的去中心化,與其他agents(管理同一個job的workers)進行溝通協調做出一個集體性決策,決策是基於其管理的 workers 情況來完成。

對於如何配置,原始碼中也給出了示例,如果在GPU上啟動訓練一個擁有 8 個 trainer(每GPU一個trainer)的 job,我們可以做如下配置。

1. Use 8 x single GPU instances, place an agent per instance, managing 1 worker per agent.
2. Use 4 x double GPU instances, place an agent per instance, managing 2 workers per agent.
3. Use 2 x quad GPU instances, place an agent per instance, managing 4 workers per agent.
4. Use 1 x 8 GPU instance, place an agent per instance, managing 8 workers per agent.

2.4 基類

基類ElasticAgent 是一個 Abstract Class,真正執行的代理都需要由此派生。從 ElasticAgent 的註釋可知,代理程式負責管理一個或多個worker 程式。工作程式被假定為常規分散式PyTorch指令碼。當worker程式由代理建立時,代理將為worker程式提供必要的資訊,以便正確初始化torch程式組。部署時,精確的拓撲和 agent-to-worker 比率取決於代理的具體實現和使用者作業放置偏好。

class ElasticAgent(abc.ABC):
    """
    Agent process responsible for managing one or more worker processes.
    The worker processes are assumed to be regular distributed PyTorch scripts.
    When the worker process is created by the agent, the agent provides the
    necessary information for the worker processes to properly initialize
    a torch process group.

    The exact deployment topology and ratio of agent-to-worker is dependent
    on the specific implementation of the agent and the user's job placement
    preferences. 

    Usage
    ::

     group_result = agent.run()
      if group_result.is_failed():
        # workers failed
        failure = group_result.failures[0]
        log.exception(f"worker 0 failed with exit code : {failure.exit_code}")
      else:
        return group_result.return_values[0] # return rank 0's results

    """

    @abc.abstractmethod
    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
        """
        Runs the agent, retrying the worker group on failures up to
        ``max_restarts``.

        Returns:
            The result of the execution, containing the return values or
            failure details for each worker mapped by the worker's global rank.

        Raises:
            Exception - any other failures NOT related to worker process
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
        """
        Returns:
            The ``WorkerGroup`` for the given ``role``.
            Note that the worker group is a mutable object and hence in a
            multi-threaded/process environment it may change state.
            Implementors are encouraged (but not required) to return
            a defensive read-only copy.
        """
        raise NotImplementedError()

ElasticAgent 有兩個派生類:

  • SimpleElasticAgent 實現了基類的部分函式,其目的是為了方便擴充套件新代理的實現。
  • LocalElasticAgent 派生了SimpleElasticAgent ,是目前彈性訓練最終使用的代理,主要用於在本地進行操作,負責管理單機上所有的worker程式。

0x03 Worker

我們首先要看看 worker,這是 Agent 所管理的主體。

3.1 Worker 定義

Worker 類代表了一個worker例項,我們上文介紹了WorkerSpec,Worker 就是依據 WorkerSpec 構建出來的,其重點成員變數如下:

  • id(任意):唯一標識一個worker,具體是由ElasticAgent的特定實現來解釋,對於本地代理,它可以是worker的pid(int),對於遠端代理,它可以被編碼為``host:port(string)`。

  • local_rank :worker的local rank。

  • global_rank:worker的global rank。

  • role_rank:具有相同角色的所有worker的rank。

  • world_size:全域性worker數量。

  • role_world_size:具有相同角色的worker數量。

class Worker:
    """
    Represents a worker instance. Contrast this with ``WorkerSpec`` that
    represents the specifications of a worker. A ``Worker`` is created from
    a ``WorkerSpec``. A ``Worker`` is to a ``WorkerSpec`` as an object is to
    a class.

    The ``id`` of the worker is interpreted
    by the specific implementation of ``ElasticAgent``. For a local
    agent, it could be the ``pid (int)`` of the worker, for a remote
    agent it could be encoded as ``host:port (string)``.

    Args:
        id (Any): uniquely identifies a worker (interpreted by the agent)
        local_rank (int): local rank of the worker
        global_rank (int): global rank of the worker
        role_rank (int): rank of the worker across all workers that have the same role
        world_size (int): number of workers (globally)
        role_world_size (int): number of workers that have the same role
    """

    def __init__(
        self,
        local_rank: int,
        global_rank: int = -1,
        role_rank: int = -1,
        world_size: int = -1,
        role_world_size: int = -1,
    ):
        # unique identifier for this worker
        self.id: Any = None

        # rank of the worker among workers with the same role being monitored
        # by the same ``agent`` instance.
        self.local_rank: int = local_rank

        #  rank of the worker among all the workers across all roles
        #  across all ``agent`` instances.
        #  Global rank is not stable between re-rendezvous.
        self.global_rank: int = global_rank

        #  rank of the worker among all the workers with the same role
        #  across all ``agent`` instances.
        #  Global rank is not stable between re-rendezvous.
        self.role_rank: int = role_rank

        # total number of workers (globally). Due to elasticity
        # the world size may change between re-rendezvous.
        self.world_size: int = world_size

        # total number of workers that share the same role. Due to elasticity
        # the role world size may change between re-rendezvous.
        self.role_world_size: int = role_world_size

3.2 WorkerGroup

WorkerGroup 代表了一個工作組,作為一個整體來管理多個 workers,進行批量處理。

class WorkerGroup:
    """
    Represents the set of ``Worker`` instances for the given ``WorkerSpec``
    managed by ``ElasticAgent``. Whether the worker group contains cross
    instance workers or not depends on the implementation of the agent.
    """
    def __init__(self, spec: WorkerSpec):
        self.spec = spec
        self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]

        # assigned after rdzv
        self.store = None
        self.group_rank = None
        self.group_world_size = None

        self.state = WorkerState.INIT

在SimpleElasticAgent 初始化之中,會建立一個 WorkerGroup。

class SimpleElasticAgent(ElasticAgent):
    """
    An ``ElasticAgent`` that manages workers (``WorkerGroup``)
    for a single ``WorkerSpec`` (e.g. one particular type of worker role).
    """

    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
        self._worker_group = WorkerGroup(spec)
        self._remaining_restarts = self._worker_group.spec.max_restarts
        self._store = None
        self._exit_barrier_timeout = exit_barrier_timeout
        self._total_execution_time = 0

3.3 WorkerState

WorkerState 表示 WorkerGroup的狀態。工作組中的所有工作人員作為一個整體來維護/更改狀態。如果工作組中的一個worker失敗,則整個工作組被認為是失敗:

  UNKNOWN - agent lost track of worker group state, unrecoverable
  INIT - worker group object created not yet started
  HEALTHY - workers running and healthy
  UNHEALTHY - workers running and unhealthy
  STOPPED - workers stopped (interruped) by the agent
  SUCCEEDED - workers finished running (exit 0)
  FAILED - workers failed to successfully finish (exit !0)

具體這些狀態意義如下:

  • UNKNOWN-代理丟失了對工作組狀態的跟蹤,無法恢復

  • INIT-建立的工作組物件尚未啟動

  • HEALTHY-worker健康執行

  • UNHEALTHY-worker在執行但是不健康

  • STOPPED-代理停止(中斷)worker

  • SUCCEEDED-worker已完成執行(exit數值為0)

  • FAILED-worker未能成功完成(exit數值不等於0)

工作組從初始的INIT狀態開始,然後進入"健康"或"不健康"狀態,最後到達終端"成功"或"失敗"狀態。工作組可以被代理打斷並且臨時置於"停止"狀態。處於"已停止"狀態的工作程式可以在不久的將來被排程重啟,被設定為已停止的狀態的例子為:

  • 觀察到工作組故障|不健康
  • 檢測到成員更改

當工作組上的操作(啟動、停止、rdzv、重試等)失敗,並導致操作部分應用於工作組時,狀態將為"未知"。這通常發生在狀態改變期間發生異常,而且異常未捕獲/未處理的情況下。當工作組處於"未知"狀態,代理不會恢復工作組,因此最好終止作業,並且由job manager重試節點。

WorkerState 具體定義如下:

class WorkerState(str, Enum):
    """
    State of the ``WorkerGroup``. Workers in a worker group change state as a unit.
    If a single worker in a worker group fails the entire set is considered
    failed::

      UNKNOWN - agent lost track of worker group state, unrecoverable
      INIT - worker group object created not yet started
      HEALTHY - workers running and healthy
      UNHEALTHY - workers running and unhealthy
      STOPPED - workers stopped (interruped) by the agent
      SUCCEEDED - workers finished running (exit 0)
      FAILED - workers failed to successfully finish (exit !0)


    A worker group starts from an initial ``INIT`` state,
    then progresses to ``HEALTHY`` or ``UNHEALTHY`` states,
    and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state.

    Worker groups can be interrupted and temporarily put into ``STOPPED`` state
    by the agent. Workers in ``STOPPED`` state are scheduled to be restarted
    in the near future by the agent. Some examples of workers being put into
    ``STOPPED`` state are:

    1. Worker group failure|unhealthy observed
    2. Membership change detected

    When actions (start, stop, rdzv, retry, etc) on worker group fails
    and results in the action being partially applied to the worker group
    the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled
    exceptions during state change events on the agent. The agent is not
    expected to recover worker groups in ``UNKNOWN`` state and is better off
    self terminating and allowing the job manager to retry the node.
    """

    UNKNOWN = "UNKNOWN"
    INIT = "INIT"
    HEALTHY = "HEALTHY"
    UNHEALTHY = "UNHEALTHY"
    STOPPED = "STOPPED"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"

    @staticmethod
    def is_running(state: "WorkerState") -> bool:
        """
        Returns:
             True if the worker state represents workers still running
             (e.g. that the process exists but not necessarily healthy).
        """
        return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY}

0x04 SimpleElasticAgent

SimpleElasticAgent 是 Agent 的實現類之一。此抽象是為了方便擴充套件新的 agent 實現。從後面可知,目前內建的 LocalElasticAgent 負責管理單機上的所有 worker 程式,如果使用者希望只用一個代理就管理多機上所有的 worker,而不僅僅是本機 worker,那麼可以通過擴充套件 SimpleElasticAgent 來實現一個自定義 Agent。

class SimpleElasticAgent(ElasticAgent):
    """
    An ``ElasticAgent`` that manages workers (``WorkerGroup``)
    for a single ``WorkerSpec`` (e.g. one particular type of worker role).
    """

    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
        self._worker_group = WorkerGroup(spec)
        self._remaining_restarts = self._worker_group.spec.max_restarts
        self._store = None
        self._exit_barrier_timeout = exit_barrier_timeout
        self._total_execution_time = 0

4.1 總體執行

SimpleElasticAgent 主迴圈 _invoke_run 是核心邏輯(這裡預設代理和worker在同一個機器之上),其中做如下操作:

  • 使用 self._initialize_workers(self._worker_group) 完成初始化工作,比如來啟動 worker,為每個worker 分配 rank 等等。
  • 然後進入 while True 迴圈,在迴圈之中通過 _monitor_workers 定期輪訓使用者程式執行情況,得到 worker 程式執行結果,然後依據情況進行不同處理。
    • 如果程式正常結束,則返回。
    • 如果程式出錯,則重試,如果重試次數達到,結束workers。
    • 如果節點成員關係有變化,比如scale up就會有新的節點在waiting,這時候就重啟所有workers。
    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        self._initialize_workers(self._worker_group) # 啟動worker
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            # 定期監控
            time.sleep(monitor_interval)
            # 監控客戶程式執行情況
            run_result = self._monitor_workers(self._worker_group) # 得到程式執行結果
            state = run_result.state
            self._worker_group.state = state

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)

            if state == WorkerState.SUCCEEDED:
                # 程式正常結束
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # 程式出錯
                if self._remaining_restarts > 0: # 重試
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group) # 重試次數達到,結束workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # 節點成員關係有變化,比如scale up,就會有新節點waiting
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # 如果有新的節點在waiting,就重啟所有workers
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

上面只是大概講了下這個總體流程,我們接下來對這個總體流程逐一分析。

4.2 初始化workers

代理主迴圈之中,首先使用 self._initialize_workers(self._worker_group) 來啟動 worker。在 _initialize_workers之中:

  • 首先使用 self._rendezvous(worker_group) 進行節點之間的同步共識操作以及rank處理等等。
  • 其次呼叫 _start_workers 啟動 workers。這裡的 _start_workers 是虛擬函式,需要派生類實現。
    @prof
    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
        r"""
        Starts a fresh set of workers for the worker_group.
        Essentially a rendezvous followed by a start_workers.

        The caller should first call ``_stop_workers()`` to stop running workers
        prior to calling this method.

        Optimistically sets the state of the worker group that
        just started as ``HEALTHY`` and delegates the actual monitoring
        of state to ``_monitor_workers()`` method
        """
        role = worker_group.spec.role

        # TODO after stopping workers, wait at least monitor_interval*2 for
        # workers on different nodes to fail on a collective op before waiting
        # on the rdzv barrier, this way we ensure that nodes enter rdzv
        # at around the same time and reduce false positive rdzv timeout errors
        self._rendezvous(worker_group) # 同步共識操作 

        worker_ids = self._start_workers(worker_group) # 啟動worker
        for local_rank, w_id in worker_ids.items():
            worker = worker_group.workers[local_rank]
            worker.id = w_id

        worker_group.state = WorkerState.HEALTHY

4.2.1 _rendezvous

我們首先看看_rendezvous,其做如下操作:

  • 呼叫 next_rendezvous() 來處理成員關係變化,其會返回 world size,store等。
  • 會把 store 配置到 workgroup 之中,後續worker 之間就可以通過這個kvstore進行溝通
  • 呼叫 _assign_worker_ranks 會生成 worker,並且為 worker 建立 ranks,返回的 workers 都賦值在代理的 worker_group.workers 之中。

以上兩點都是利用 rendezvous 的資訊來進行處理,比如從 rendezvous 之中提取 ranks。

    @prof
    def _rendezvous(self, worker_group: WorkerGroup) -> None:
        r"""
        Runs rendezvous for the workers specified by worker spec.
        Assigns workers a new global rank and world size.
        Updates the rendezvous store for the worker group.
        """

        spec = worker_group.spec

        # 處理成員關係變化,注意,這裡得到的是 group rank!
        store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
        self._store = store # store被設定到 Agent之中,store可以被認為是遠端KV儲存

        # 依據 group rank 為 worker 建立 ranks
        workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec)
        worker_group.workers = workers
        worker_group.store = store
        worker_group.group_rank = group_rank
        worker_group.group_world_size = group_world_size

        if group_rank == 0:
            self._set_master_addr_port(store, spec.master_addr, spec.master_port)
        master_addr, master_port = self._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts
4.2.2.1 處理成員關係變化

Elastic 呼叫 rdzv_handler.next_rendezvous() 來處理成員關係變化,目的是啟動下一輪 rendezvous 操作(因為本worker已經啟動,需要加入叢集)。

注意,next_rendezvous 是 RendezvousHandler 的內部函式。這一函式呼叫會被阻塞,直到 worker 的數量達到了要求。在 worker 被初始化,或者重啟的時候,這一函式都會被呼叫。當函式返回時,不同的 worker group 會以返回中的 rank 作為唯一的標示。其內部邏輯是:

  • 先使用_RendezvousExitOp讓該node退出。
  • 然後再使用_RendezvousJoinOp把該node重新加入。
  • 最後啟動心跳,返回world size,store等。
    def next_rendezvous(self) -> Tuple[Store, int, int]:
        """See base class."""

        self._stop_heartbeats()

        # Delay the execution for a small random amount of time if this is our
        # first run. This will slightly skew the rendezvous attempts across the
        # nodes and reduce the load on the backend.
        if self._state_holder.state.round == 0:
            _delay(seconds=(0, 0.3))

        exit_op = _RendezvousExitOp()
        join_op = _RendezvousJoinOp()

        deadline = self._get_deadline(self._settings.timeout.join)

        self._op_executor.run(exit_op, deadline)
        self._op_executor.run(join_op, deadline)

        self._start_heartbeats()

        rank, world_size = self._get_world()
        store = self._get_store()

        return store, rank, world_size # 返回的是 worker group 的rank
4.2.3.2 為 worker 分配 ranks

接著是呼叫 _assign_worker_ranks 為 worker 建立 ranks。分配 rank 演算法如下:

  1. 每個代理將其配置(group_rank, group_world_size , num_workers)寫入公共儲存
  2. 每個代理檢索所有代理的配置,並使用角色和rank執行兩級排序。
  3. 確定全域性rank:當前代理的global rank是 本代理 的 group_rank 在infos陣列的偏移量(offset)。偏移量的計算方法是,排名低於group_rank的所有代理的local_world之和。workers 的等級為:[offset, offset+local_world_size]。
  4. 確定role rank:使用第3點中的演算法確定role rank,不同之處是:偏移量計算是從與當前角色相同且具有最小 group rank 的第一個代理開始。
  5. 因為所有代理都使用同樣演算法,所以其計算出的 ranks 陣列都是相同的。

然後生成 workers,把 worker 都賦值在 worker_group.workers 之中。

@prof
def _assign_worker_ranks(
    self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
) -> List[Worker]:
    """
    Determines proper ranks for worker processes. The rank assignment
    is done according to the following algorithm:

    1. Each agent writes its configuration(group_rank, group_world_size
       , num_workers) to the common store.
    2. Each agent retrieves configuration for all agents
       and performs two level sort using role and rank.
    3. Determine the global rank: the global rank of the workers for the current
       agent is the offset of the infos array up to group_rank of the agent.
       The offset is computed as a sum of local_world_size of all agents that
       have rank less than the group_rank. The workers would have the ranks:
       [offset, offset+local_world_size)
    4. Determine the role rank: The role rank is determined using the algorithms
       in the point 3 with the exception that the offset is done from the first
       agent that has the same role as current one and has the minimum group rank.
    """

    # 每個代理將其配置(group_rank, group_world_size, num_workers)寫入公共儲存。
    role_infos = self._share_and_gather(store, group_rank, group_world_size, spec)
    # 每個代理檢索所有代理的配置,並使用角色和rank執行兩級排序。
    my_role_info = role_infos[group_rank]
    # 確定全域性rank:當前代理的global rank是 本代理 的 group_rank 在infos陣列的偏移量(offset)。偏移量的計算方法是,排名低於group_rank的所有代理的local_world之和。workers 的等級為:[offset, offset+local_world_size]。
    worker_world_size, worker_global_ranks = self._get_ranks(role_infos, group_rank)
    role_infos = sorted(
        role_infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare)
    )
    role_start_idx, role_end_idx = _RoleInstanceInfo.find_role_boundaries(
        role_infos, my_role_info.role
    )
    role_pos = next(
        idx
        for idx, role_info in enumerate(role_infos)
        if _RoleInstanceInfo.compare(role_info, my_role_info) == 0
    )
    # 確定role rank:使用第3點中的演算法確定role rank,不同之處是:偏移量計算是從與當前角色相同且具有最小 group rank 的第一個代理開始。
    role_world_size, role_ranks = self._get_ranks(
        role_infos, role_pos, role_start_idx, role_end_idx + 1
    )
    # 生成 workers,把 worker 都賦值在 worker_group.workers 之中。
    workers = []
    for ind in range(spec.local_world_size):
        worker = Worker(
            local_rank=ind,
            global_rank=worker_global_ranks[ind],
            role_rank=role_ranks[ind],
            world_size=worker_world_size,
            role_world_size=role_world_size,
        )
        workers.append(worker)
    return workers

4.2.4 啟動 workers 程式

呼叫 派生類的 _start_workers 來啟動 worker 程式,因此基類這裡沒有實現,我們後續會看到派生類如何實現。

    @abc.abstractmethod
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        r"""
        Starts ``worker_group.spec.local_world_size`` number of workers
        according to worker spec for the worker group .

        Returns a map of ``local_rank`` to worker ``id``.
        """
        raise NotImplementedError()

目前邏輯如下,具體是:

  1. 呼叫 rdzv_handler.next_rendezvous 來與其他 Node 進行同步。
  2. rdzv_handler.next_rendezvous 返回 ranks 等資訊給_assign_worker_ranks。
  3. _assign_worker_ranks會生成一些Workers,其中每個 Worker都被自動分配了 rank。這些 workers 被 Agent的worker_group.workers所指向。
+--------------------------------------------------+
| LocalElasticAgent                                |         _initialize_workers
|                                                  |                 +
|                                                  |                 |
|                                                  |                 |
|   +----------------------+                       |                 v
|   |WorkerGroup           |                       |         _rendezvous(worker_group)
|   |                      |                       |                 +
|   |     spec             |                       |                 |
|   |                      |                       |                 | 1
|   |     group_world_size |                       |                 v
|   |                      |                       |        rdzv_handler.next_rendezvous()
|   |     store            |                       |                 +
|   |                      |    +----------------+ |                 |
|   |     group_rank       |    | Worker0(rank 0)| |               2 | ranks
|   |                      |    | Worker1(rank 1)| |  Workers        v
|   |     workers  +----------> | ...            | | <----+ _assign_worker_ranks
|   |                      |    | Workern(rank n)| |    3
|   +----------------------+    +----------------+ |
|                                                  |
+--------------------------------------------------+

接下來會分別把 rank 相關和 worker 相關的函式都分別羅列出來,以便大家更好的理解。

4.3 ranks相關

前面的 _assign_worker_ranks 為 worker 建立 ranks,但是其內部有些細節我們還需要梳理一下。

4.3.1 _RoleInstanceInfo

這裡要介紹一下 _RoleInstanceInfo 這個資料結構。代理使用該類與其他代理交換資訊。該資訊用於確定本代理workers的rank。這些代理工作在異構環境下,不同代理也許有不同數量的workers。其構建引數是:

  • role (str) : 使用者定義的role。
  • rank (int) : 代理的rank。
  • local_world_size (int) : 本地 workers 的數目。
class _RoleInstanceInfo:
    """
    The class is used by the agent to exchange the information with other agents.
    The information is used to determine the rank of the workers that agent
    manages in heterogeneous environments, where different agents can have
    different number of workers.
    """

    __slots__ = ["role", "rank", "local_world_size"]

    def __init__(self, role: str, rank: int, local_world_size: int):
        r"""

        Args:
            role (str): user-defined role for the workers with this spec
            rank (int): the rank of the agent
            local_world_size (int): number of local workers to run
        """

        self.role = role
        self.rank = rank
        self.local_world_size = local_world_size

    def serialize(self) -> bytes:
        dict_data = {
            "role": self.role,
            "rank": self.rank,
            "local_world_size": self.local_world_size,
        }
        return json.dumps(dict_data).encode(encoding="UTF-8")

    @staticmethod
    def deserialize(data: bytes):
        dict_data = json.loads(data.decode(encoding="UTF-8"))
        return _RoleInstanceInfo(
            dict_data["role"], dict_data["rank"], dict_data["local_world_size"]
        )

    @staticmethod
    def compare(obj1, obj2) -> int:
        if obj1.role == obj2.role:
            return obj1.rank - obj2.rank
        elif obj1.role > obj2.role:
            return 1
        else:
            return -1

    @staticmethod
    def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]:
        start_idx, end_idx = -1, -1
        for idx, role_info in enumerate(roles_infos):
            if role_info.role == role:
                if start_idx == -1:
                    start_idx = idx
                end_idx = idx
        return (start_idx, end_idx)

4.3.2 _share_and_gather

_share_and_gather 的作用是在各個代理之間同步,得到角色的總體資訊。每個代理將其配置(group_rank, group_world_size , num_workers)寫入公共儲存這裡就是使用之前 Rendezvous 返回的 store 來進行資訊共享

    def _share_and_gather(
        self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
    ) -> List:
        agent_role_info = _RoleInstanceInfo(
            spec.role, group_rank, spec.local_world_size
        )
        key_prefix = "torchelastic/role_info"
        agent_config_enc = agent_role_info.serialize()
        role_infos_bytes = store_util.synchronize(
            store, agent_config_enc, group_rank, group_world_size, key_prefix
        )
        role_infos = [
            _RoleInstanceInfo.deserialize(role_info_bytes)
            for role_info_bytes in role_infos_bytes
        ]
        return role_infos

4.3.3 _get_ranks

依據 role infos 來確定全域性rank:當前代理的global rank是 本代理 的 group_rank 在infos陣列的偏移量(offset)。偏移量的計算方法是,排名低於group_rank的所有代理的local_world之和。workers 的等級為:[offset, offset+local_world_size]。

def _get_ranks(
    self,
    role_infos: List[_RoleInstanceInfo],
    role_idx: int,
    start_idx: int = 0,
    end_idx: int = -1,
) -> Tuple[int, List[int]]:
    if end_idx == -1:
        end_idx = len(role_infos)
    prefix_sum = 0
    total_sum = 0
    for idx in range(start_idx, end_idx):
        if role_idx > idx:
            prefix_sum += role_infos[idx].local_world_size
        total_sum += role_infos[idx].local_world_size
    return (
        total_sum,
        list(range(prefix_sum, prefix_sum + role_infos[role_idx].local_world_size)),
    )

目前邏輯擴充如下:

  1. 呼叫 rdzv_handler.next_rendezvous() 來和其他節點進行同步,獲得資訊。
  2. 獲得資訊中的store(可以認為就是遠端的KV儲存),group_world_size,group_rank 傳給 Agent。
  3. ranks 等資訊傳給 _assign_worker_ranks方法。
  4. _assign_worker_ranks 之中,呼叫 _share_and_gather 在各個代理之間同步,得到角色的總體資訊。每個代理將其配置(group_rank, group_world_size , num_workers)寫入公共KV儲存。
  5. 依據 role infos 來確定全域性rank:當前代理的global rank是 本代理 的 group_rank 在infos陣列的偏移量(offset)。偏移量的計算方法是,排名低於group_rank的所有代理的local_world之和。
  6. 使用各種資訊建立一系列的 Workers。
  7. Workers 被複制給 Agent 的 WorkerGroup 之中。
                                                              _initialize_workers
                                                                      +
                                                                      |
                                                                      |
                                                                      v
                                                              _rendezvous(worker_group)
                                                                      +
+----------------------------------------------+                      |
| LocalElasticAgent                            |                      | 1
|                                              |   2                  v
|                                         +--------------+  rdzv_handler.next_rendezvous()
| +--------------------+                  |    |                      +
| | WorkerGroup        |                  |    |                      |
| |                    |                  |    |                    3 | ranks
| |                    |                  |    |                      v
| |  spec              |                  |    |       +--------------+------------------+
| |                    |                  |    |       | _assign_worker_ranks            |
| |                    |                  |    |       |                                 |
| |  store   <----------------------------+    |       |                        4        |
| |                    |                  |    |       | role_infos = _share_and_gather( |
| |                    |                  |    |       |               +          store) |
| |  group_world_size<--------------------+    |       |               | 5               |
| |                    |                  |    |       |               |                 |
| |                    |                  |    |       |               v                 |
| |  group_rank <-------------------------+    |       |          _get_ranks(world...)   |
| |                    |                       |       |          _get_ranks(role...)    |
| |                    |   +----------------+  |       |               +                 |
| |  workers  +----------->+ Worker0(rank 0)|  |       |               |                 |
| |                    |   | Worker1(rank 1)|  |       |               | 6               |
| |                    |   | ...            |  |Workers|               v                 |
| |                    |   | Workern(rank n)+<------------+ new Worker(local_rank,       |
| +--------------------+   +----------------+  |    7  |               global_rank,      |
|                                              |       |               role_rank,        |
+----------------------------------------------+       |               world_size,       |
                                                       |               role_world_size)  |
                                                       |                                 |
                                                       +---------------------------------+

_rendezvous 操作之後,Worker 例項已經生成了,接下來就看看如何生成 Worker 程式。但是因為這些方法在 SimpleElasticAgent 之中並沒有實現,所以我們需要在其派生類 LocalElasticAgent 分析小節才能繼續擴充我們的邏輯圖。

4.4 Worker 相關

我們先看看 SimpleElasticAgent 剩餘兩個 worker 相關函式。

4.4.1 重啟

_restart_workers 是重啟 workers。

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
#  `torch.distributed.elastic.metrics.prof`.
@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)

4.4.2 barrier

實際上,幾乎不可能保證DDP的所有 worker 都能保證同時結束,所以因此TE提供了一個finalization barrier,這個barrier的作用是對worker finalization 實施等待超時(5分鐘)。

    def _exit_barrier(self):
        """
        Wait for ``exit_barrier_timeout`` seconds for all agents to finish
        executing their local workers (either successfully or not). This
        acts as a safety guard against user scripts that terminate at different
        times. This barrier keeps the agent process alive until all workers finish.
        """
        start = time.time()
        try:
            store_util.barrier(
                self._store,
                self._worker_group.group_rank,
                self._worker_group.group_world_size,
                key_prefix=_TERMINAL_STATE_SYNC_ID,
                barrier_timeout=self._exit_barrier_timeout,
            )
        except Exception:
            log.exception(
                f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
            )

0x05 LocalElasticAgent

LocalElasticAgent 是彈性訓練最終使用的代理,主要用於在本地進行操作,負責管理單機上所有的worker程式,其派生了 SimpleElasticAgent

此代理在每個主機之上部署,並配置為生成n個工作程式。當使用GPU時,n是主機上可用的GPU數量。本地代理不會與部署在其他主機上的其他本地代理通訊,即使worker可以在主機間通訊。Worker id被解釋為本地程式。代理作為把本機所有工作程式作為一個整體啟動和停止。

傳遞給worker的函式和引數必須與python multiprocessing相容。要將multiprocessing資料結構傳遞給worker,使用者可以在與指定的start_method相同的多處理multiprocessing中建立資料結構,並將其作為函式引數傳遞。

exit_barrier_timeout用來指定等待其他代理完成的時間量(以秒為單位)。這起到了一個安全網的作用,可以處理worker在不同時間完成的情況,以防止代理將提前完成的worker視為scale-down事件。強烈建議使用者程式碼確保worker以同步方式終止,而不是依賴於exit_barrier_timeout。

SimpleElasticAgent 主要是提供給了初始化和總體執行方式,但是遺留了一些抽象函式沒有被實現,比如_start_workers_stop_workers_monitor_workers_shutdown。LocalElasticAgent 就補齊了這些函式。

class LocalElasticAgent(SimpleElasticAgent):
    """
    An implementation of :py:class:`torchelastic.agent.server.ElasticAgent`
    that handles host-local workers.
    This agent is deployed per host and is configured to spawn ``n`` workers.
    When using GPUs, ``n`` maps to the number of GPUs available on the host.

    The local agent does not communicate to other local agents deployed on
    other hosts, even if the workers may communicate inter-host. The worker id
    is interpreted to be a local process. The agent starts and stops all worker
    processes as a single unit.


    The worker function and argument passed to the worker function must be
    python multiprocessing compatible. To pass multiprocessing data structures
    to the workers you may create the data structure in the same multiprocessing
    context as the specified ``start_method`` and pass it as a function argument.

    The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait
    for other agents to finish. This acts as a safety net to handle cases where
    workers finish at different times, to prevent agents from viewing workers
    that finished early as a scale-down event. It is strongly advised that the
    user code deal with ensuring that workers are terminated in a synchronous
    manner rather than relying on the exit_barrier_timeout.
    """

    def __init__(
        self,
        spec: WorkerSpec,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_dir: Optional[str] = None,
    ):
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        self._pcontext: Optional[PContext] = None
        rdzv_run_id = spec.rdzv_handler.get_run_id()
        self._log_dir = self._make_log_dir(log_dir, rdzv_run_id)

    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        return dir

5.1 使用

我們先從其註釋中提取程式碼,看看如何使用。以下是如何把function作為入口來啟動。

    def trainer(args) -> str:
        return "do train"

    def main():
        start_method="spawn"
        shared_queue= multiprocessing.get_context(start_method).Queue()
        spec = WorkerSpec(
                    role="trainer",
                    local_world_size=nproc_per_process,
                    entrypoint=trainer,
                    args=("foobar",),
                    ...<OTHER_PARAMS...>)
        agent = LocalElasticAgent(spec, start_method)
        results = agent.run()

        if results.is_failed():
            print("trainer failed")
        else:
            print(f"rank 0 return value: {results.return_values[0]}")
            # prints -> rank 0 return value: do train

以下是如何把binary作為入口來啟動。

    def main():
        spec = WorkerSpec(
                    role="trainer",
                    local_world_size=nproc_per_process,
                    entrypoint="/usr/local/bin/trainer",
                    args=("--trainer_args", "foobar"),
                    ...<OTHER_PARAMS...>)
        agent = LocalElasticAgent(spec)
        results = agent.run()

        if not results.is_failed():
            print("binary launches do not have return values")

_rendezvous 操作之後,Worker 例項已經生成了,接下來就看看如何生成 Worker 程式。

5.2 停止

以下函式會停止workers。

    @prof
    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        self._shutdown()
        
    def _shutdown(self) -> None:
        if self._pcontext:
            self._pcontext.close()        

5.3 初始化

我們接著前文來說,_rendezvous 操作之後,Worker 例項已經生成了,接下來就看看如何生成 Worker 程式。之前因為這些方法在 SimpleElasticAgent 之中並沒有實現,所以我們在本小結繼續擴充我們的邏輯圖。

我們先再看看初始化workers。在 _initialize_workers之中,首先使用 _rendezvous 建立 workers 例項,其次呼叫 _start_workers 啟動 workers。

    @prof
    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
        r"""
        Starts a fresh set of workers for the worker_group.
        Essentially a rendezvous followed by a start_workers.

        The caller should first call ``_stop_workers()`` to stop running workers
        prior to calling this method.

        Optimistically sets the state of the worker group that
        just started as ``HEALTHY`` and delegates the actual monitoring
        of state to ``_monitor_workers()`` method
        """
        role = worker_group.spec.role

        # TODO after stopping workers, wait at least monitor_interval*2 for
        # workers on different nodes to fail on a collective op before waiting
        # on the rdzv barrier, this way we ensure that nodes enter rdzv
        # at around the same time and reduce false positive rdzv timeout errors
        self._rendezvous(worker_group) # Worker例項已經生成了

        worker_ids = self._start_workers(worker_group) # 啟動Worker程式
        for local_rank, w_id in worker_ids.items():
            worker = worker_group.workers[local_rank]
            worker.id = w_id # 得到程式ID

        worker_group.state = WorkerState.HEALTHY

5.4 啟動 worker 程式

_start_workers 方法會呼叫 start_processes 來啟動 worker 程式,預設_start_method 是 "spawn"。也就是啟動了多個程式,並行執行使用者程式。同時這些程式的執行結果會被監控。start_processes 引數之中,entrypointargs 是使用者命令和引數,entrypoint可以是函式或者字串。

_start_workers 把 start_processes 方法啟動多執行緒的結果儲存在 _pcontext 之中,後續就用 _pcontext 來繼續控制,比如結束 worker 就是直接呼叫 _pcontext 的 close方法。

    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": str(1),
            }
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # scaling events do not count towards restarts (gets same attempt #)
        # remove existing log dir if this restart is due to a scaling event
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        self._pcontext = start_processes( # 把啟動多執行緒的結果儲存在 _pcontext 之中。
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()

5.5 監控

執行之後,TE 會呼叫 _monitor_workers 對workers進行監控。之前把啟動多執行緒的結果儲存在 _pcontext 之中,現在就用 _pcontext 對執行情況進行監控。

    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}
        assert self._pcontext is not None
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            return RunResult(state=WorkerState.UNKNOWN)

        result = self._pcontext.wait(0) # 對執行結構進行監控
        if result:
            if result.is_failed(): # 如果程式失敗
                # map local rank failure to global rank
                worker_failures = {}
                #  返回的結果內部就包括每個程式的執行結果
                for local_rank, failure in result.failures.items():
                    worker = worker_group.workers[local_rank]
                    worker_failures[worker.global_rank] = failure
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures, # 返回執行結果
                )
            else:
                # copy ret_val_queue into a map with a global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals, # 返回執行結果
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)

因為啟動和監控涉及到系統整體執行邏輯,需要和 rendezvous 一起才能更好理解,所以我們把這部分的分析推遲,等到 Rendezvous 之後再來做整體分析。

目前總體邏輯如下:

  1. 呼叫 rdzv_handler.next_rendezvous() 來和其他節點進行同步,獲得資訊。
  2. 獲得資訊中的store(可以認為就是遠端的KV儲存),group_world_size,group_rank 傳給 Agent。
  3. ranks 等資訊傳給 _assign_worker_ranks方法。
  4. _assign_worker_ranks 之中,呼叫 _share_and_gather 在各個代理之間同步,得到角色的總體資訊。每個代理將其配置(group_rank, group_world_size , num_workers)寫入公共KV儲存。
  5. 依據 role infos 來確定全域性rank:當前代理的global rank是 本代理 的 group_rank 在infos陣列的偏移量(offset)。偏移量的計算方法是,排名低於group_rank的所有代理的local_world之和。
  6. 使用各種資訊建立一系列的 Workers。
  7. Workers 被複制給 Agent 的 WorkerGroup 之中。
  8. 使用 _start_workers 來啟動 worker 程式。
  9. 把 worker 程式 id 賦值給 Agent 的 worker.id 之中,這樣以後就可以用 worker.id 來操作程式。
  10. 使用 _monitor_workers 監控 worker 程式。
  11. 使用 _exit_barrier 來等待 worker 程式結束。
                                                              _initialize_workers
                                                                      +
                                                                      |
                                                                      |
                                                                      v
                                                              _rendezvous(worker_group)
                                                                      +
+----------------------------------------------+                      |
| LocalElasticAgent                            |                      | 1
|                                              |   2                  v
|                                         +--------------+  rdzv_handler.next_rendezvous()
| +--------------------+                  |    |                      +
| | WorkerGroup        |                  |    |                      |
| |                    |                  |    |                    3 | ranks
| |                    |                  |    |                      v
| |  spec              |                  |    |       +--------------+------------------+
| |                    |                  |    |       | _assign_worker_ranks            |
| |                    |                  |    |       |                                 |
| |  store   <----------------------------+    |       |                        4        |
| |                    |                  |    |       | role_infos = _share_and_gather( |
| |                    |                  |    |       |               +          store) |
| |  group_world_size<--------------------+    |       |               | 5               |
| |                    |                  |    |       |               |                 |
| |                    |                  |    |       |               v                 |
| |  group_rank <-------------------------+    |       |          _get_ranks(world...)   |
| |                    |                       |       |          _get_ranks(role...)    |
| |                    |   +----------------+  |       |               +                 |
| |  workers  +----------->+ Worker0(rank 0)|  |       |               |                 |
| |                    |   | Worker1(rank 1)|  |       |               | 6               |
| |                    |   | ...            |  |Workers|               v                 |
| |                    |   | Workern(rank n)+<------------+ new Worker(local_rank,       |
| +--------------------+   +---------+------+  |    7  |               global_rank,      |
|                                    ^         |       |               role_rank,        |
|                                    |         |       |               world_size,       |
|                                    |         |       |               role_world_size)  |
+----------------------------------------------+       |                                 |
                                     |                 +---------------+-----------------+
                                     |                                 |
                                     |                                 | 8
                                     |              9                  v
                                     +-----------------------+   _start_workers
                                                                       +
                                                                       | 10
                                                                       |
                                                                       v
                                                       +---------------+--------------+
                                                       | state = _monitor_workers     |
                                                  +--> |                              +-->
                                                  |    +---------------+--------------+  |
                                                  |                    |                 |
                                                  <--------------------------------------+
                                                     LOOP  Every 30S   |
                                                                       | 11
                                                                       v
                                                                    _exit_barrier

手機如下:

0xFF 參考

TorchElastic - 彈性、容錯的分散式訓練

相關文章