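[Source Code Analysis] PyTorch Distributed Elastic Training (7): Node Changes
0x00 Abstract
This article analyzes how node changes are handled, that is, how the agent reacts to a membership change and restarts all workers with the new membership, which is what makes training elastic.
The overall approach is the same as for a failed worker process: the elastic agent on the node kills all worker processes on that node, performs a rendezvous with the other agents, and restarts all workers with the new rendezvous information.
The other articles in the elastic training series are:
[Source Code Analysis] PyTorch Distributed Elastic Training (1): Overall Design
[Source Code Analysis] PyTorch Distributed Elastic Training (2): Launch and Single-Node Flow
[Source Code Analysis] PyTorch Distributed Elastic Training (3): Agent
[Source Code Analysis] PyTorch Distributed Elastic Training (4): Rendezvous Architecture and Logic
[Source Code Analysis] PyTorch Distributed Elastic Training (5): Rendezvous Engine
[Source Code Analysis] PyTorch Distributed Elastic Training (6): Monitoring and Fault Tolerance
0x01 Types of change
A node change takes one of two forms.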
1.1 Scale-down
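Node departure (scale-down) is handled as follows:
- When a scale-down event happens, rendezvous does not notify the torchelastic agent.
- The torchelastic agent detects the process failure through its own monitoring and handles it.
- If the TE agent was launched with max_restarts=0, it relies on the underlying scheduler to restart the job.
- If max_restarts>0, the TE agent terminates the workers and starts a new round of rendezvous.
- Once the agents learn of the departure, the existing workers (on all nodes) are all stopped.
- These workers form a new "WorkerGroup", and every worker runs with a new RANK and WORLD_SIZE.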
1.2 Scale-up
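Node arrival (scale-up) is handled as follows:
- When a scale-up event happens, a new node is submitted to the job and torchelastic rendezvous detects that a new node is trying to join.
- If the rendezvous already holds the maximum number of nodes, the new node is not added to the wait list: the rendezvous is full, so there is no reason to tear it down. The new node keeps waiting until it times out (600 seconds by default).
- The new node periodically checks the number of participating nodes. If that number drops below max_nodes, the waiting node is added to the wait list; otherwise it times out after 600 seconds.
- When the agent decides to handle the scale-up:
  - torchelastic rendezvous stops all workers and performs a new round of re-rendezvous.
  - These workers (existing and newly joined) form a new "WorkerGroup", and every worker runs with a new RANK and WORLD_SIZE.
Note: max_restarts is not decremented when a scale-up happens.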
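The waiting rule for a late node can be sketched as a small simulation. This is a toy model and not torchelastic code: `poll_until_admitted`, its parameters, and the one-poll-per-second cadence are assumptions made for illustration. Only the 600-second default and the comparison against `max_nodes` come from the description above.

```python
from typing import Iterable


def poll_until_admitted(
    participant_counts: Iterable[int],
    max_nodes: int,
    timeout_s: int = 600,
) -> str:
    """Toy model of a late node polling a rendezvous that may be full.

    participant_counts yields the participant count seen at each poll
    (one poll per simulated second; the cadence is an assumption).
    """
    for elapsed_s, count in enumerate(participant_counts):
        if elapsed_s >= timeout_s:
            break
        if count < max_nodes:
            # There is room again: the node may enter the wait list.
            return f"wait_list@{elapsed_s}s"
    # Either the timeout expired or the trace ended without any free slot.
    return "timeout"


# A 4-node job is full for three polls, then one node leaves.
freed = poll_until_admitted([4, 4, 4, 3], max_nodes=4)
# The job stays full for longer than the timeout.
stuck = poll_until_admitted([4] * 700, max_nodes=4)
print(freed, stuck)
```

In this model the late node never disturbs a full rendezvous; it only gets onto the wait list after an existing participant has gone.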
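0x02 Node arrival
2.1 A new node joins
Suppose an elastic training cluster is already running with an elastic range of (min=1, max=4). Two nodes are currently running and the user wants to start a third, so they launch a new process as follows.
python -m torch.distributed.run \
    --nnodes=1:4 \
    --nproc_per_node=$NUM_TRAINERS \
    --rdzv_id=$JOB_ID \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$HOST_NODE_ADDR \
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
The new process starts an agent. After a series of steps the agent calls next_rendezvous, which runs an ExitOp followed by a JoinOp.
def next_rendezvous(self) -> Tuple[Store, int, int]:
    # Excerpt: the computation of `deadline` and the code after the join are omitted.
    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()
    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)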
2.2 Handling the Join operation
The following happens inside _DistributedRendezvousOpExecutor.
From the earlier articles we know the flow: run calls the Join op to work out the next Action, then performs the operation that corresponds to that Action.
2.2.1 What run does
_DistributedRendezvousOpExecutor.run implements the basic logic: it performs different operations depending on the action type. In our example, state_handler is _RendezvousJoinOp.
def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None
    while action != _Action.FINISH:  # loop until the op reports FINISH
        # Important: synchronize the state across all nodes, because the
        # latest state lives in the rendezvous backend.
        has_set = self._state_holder.sync()
        self._state = self._state_holder.state
        # Build the context from the latest state.
        ctx = _RendezvousContext(self._node, self._state, self._settings)
        # Determine the next action to take based on the current state of
        # the rendezvous.
        action = state_handler(ctx, deadline)  # call _RendezvousJoinOp to decide the next action
        # (rest omitted)
2.2.2 The Join operation
Because of the preceding sync, ctx carries the latest state, which is the global state of the rendezvous. The current round of rendezvous has already completed, so state.complete is True and the following branch runs, returning _Action.ADD_TO_WAIT_LIST.
if state.complete:
    # If we are here, it means we are not part of the rendezvous. In
    # case the rendezvous has capacity for additional participants add
    # ourself to the wait list for the next round.
    if len(state.participants) < ctx.settings.max_nodes:  # fewer nodes than the configured maximum
        if ctx.node not in state.wait_list:  # this node is not on the wait list yet
            return _Action.ADD_TO_WAIT_LIST  # ask to be put on the wait list
The full code is as follows:
class _RendezvousJoinOp:
    """Represents a rendezvous join operation."""
    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        state = ctx.state  # take the _RendezvousState from the context
        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            return _Action.ERROR_CLOSED  # the rendezvous is closed
        is_participant = ctx.node in state.participants  # is this node a participant?
        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant:  # participant of a completed round: done
            return _Action.FINISH
        now = time.monotonic()
        if now > deadline:  # the operation has timed out
            rollback_period = 5  # 5 seconds
            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period:  # still time to roll back
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant:  # already a participant
                    return _Action.REMOVE_FROM_PARTICIPANTS  # leave the participant list
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list:  # already on the wait list
                    return _Action.REMOVE_FROM_WAIT_LIST  # leave the wait list
            return _Action.ERROR_TIMEOUT  # report the timeout
        if state.complete:  # the current round is already complete
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes:  # below the maximum node count
                if ctx.node not in state.wait_list:  # not on the wait list yet
                    return _Action.ADD_TO_WAIT_LIST  # ask to be put on the wait list
        elif is_participant:  # round still open and we are already in it
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes:  # minimum node count reached
                if cast(datetime, state.deadline) < datetime.utcnow():  # last-call deadline passed
                    return _Action.MARK_RENDEZVOUS_COMPLETE  # mark the round complete
        else:  # round still open and we are not in it: join
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS
        if _should_keep_alive(ctx):  # a heartbeat is due
            return _Action.KEEP_ALIVE
        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC  # nothing to change: sync and look again
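To make the branch taken by a late node easier to follow, here is a stripped-down model of the decision. It is a sketch under stated assumptions rather than the library's code: `ToyState` and `toy_join_op` are hypothetical names, node identities are plain strings, and deadlines, heartbeats, and the min_nodes check are left out on purpose.

```python
from dataclasses import dataclass, field
from typing import Dict, Set


@dataclass
class ToyState:
    """Minimal stand-in for _RendezvousState (only the fields used here)."""
    complete: bool = False
    closed: bool = False
    participants: Dict[str, int] = field(default_factory=dict)
    wait_list: Set[str] = field(default_factory=set)


def toy_join_op(node: str, state: ToyState, max_nodes: int) -> str:
    """Simplified join decision: no deadlines, heartbeats, or min_nodes."""
    if state.closed:
        return "ERROR_CLOSED"
    is_participant = node in state.participants
    if state.complete and is_participant:
        return "FINISH"
    if state.complete:
        # The round finished without us: queue for the next round if there
        # is capacity and we are not queued yet.
        if len(state.participants) < max_nodes and node not in state.wait_list:
            return "ADD_TO_WAIT_LIST"
    elif not is_participant:
        # The round is still open and we are not in it yet: try to join.
        return "ADD_TO_PARTICIPANTS"
    return "SYNC"


running = ToyState(complete=True, participants={"node1": 0, "node2": 1})
first_call = toy_join_op("node3", running, max_nodes=4)
running.wait_list.add("node3")          # what the executor does next
second_call = toy_join_op("node3", running, max_nodes=4)
print(first_call, second_call)
```

The second call shows why the new node then sits in a polling loop: once it is on the wait list, nothing changes for it until the existing agents reopen the rendezvous.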
2.2.3 Performing the wait operation
Inside _DistributedRendezvousOpExecutor, run implements the basic logic: it dispatches on the action type.
def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None
    while action != _Action.FINISH:  # loop until the op reports FINISH
        # Important: synchronize the state across all nodes, because the
        # latest state lives in the rendezvous backend.
        has_set = self._state_holder.sync()
        self._state = self._state_holder.state
        # Build the context from the latest state.
        ctx = _RendezvousContext(self._node, self._state, self._settings)
        # Determine the next action to take based on the current state of
        # the rendezvous.
        action = state_handler(ctx, deadline)  # _RendezvousJoinOp returns _Action.ADD_TO_WAIT_LIST here
        if action == _Action.SYNC:
            _delay(seconds=1)
        else:
            if action == _Action.KEEP_ALIVE:
                self._keep_alive()
            elif action == _Action.ADD_TO_WAIT_LIST:  # the action produced by the Join op
                self._add_to_wait_list()  # perform the actual operation
            # (other actions omitted)
            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty()  # the next sync pushes the change to the other nodes
The wait operation itself simply adds the node to the wait list.
def _add_to_wait_list(self) -> None:
    self._state.wait_list.add(self._node)
    self._keep_alive()
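The interplay between the loop and the op can be imitated without any backend. The sketch below is an illustration only: `toy_join`, `run_toy_executor`, and the `remote_updates` list are made-up names, the shared state is a plain dict, and "sync" is modelled as applying whatever other nodes changed since the previous iteration.

```python
from typing import Any, Dict, List


def toy_join(node: str, state: Dict[str, Any]) -> str:
    """Very small join op: only the branches a late node passes through."""
    if node in state["participants"]:
        return "FINISH"
    if state["complete"]:
        return "SYNC" if node in state["wait_list"] else "ADD_TO_WAIT_LIST"
    return "ADD_TO_PARTICIPANTS"


def run_toy_executor(
    node: str,
    state: Dict[str, Any],
    remote_updates: List[Dict[str, Any]],
    max_iterations: int = 10,
) -> List[str]:
    """Loop until the op returns FINISH, applying each action to `state`."""
    trace: List[str] = []
    for step in range(max_iterations):
        # "sync": pull in changes made by other nodes (assumed, scripted input).
        if step < len(remote_updates):
            state.update(remote_updates[step])
        action = toy_join(node, state)
        trace.append(action)
        if action == "FINISH":
            break
        if action == "ADD_TO_WAIT_LIST":
            state["wait_list"].add(node)
        elif action == "ADD_TO_PARTICIPANTS":
            state["wait_list"].discard(node)
            state["participants"].add(node)
        # "SYNC" changes nothing; the real loop sleeps one second here.
    return trace


state = {"complete": True, "participants": {"node1", "node2"}, "wait_list": set()}
# On the third sync the node sees that the other agents reopened the rendezvous.
trace = run_toy_executor("node3", state, [{}, {}, {"complete": False}])
print(trace)
```

The trace makes the two phases visible: the node first queues itself, then idles on SYNC until the round is reopened, and only then becomes a participant.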
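Let us recall _RendezvousState. It holds the state of the rendezvous and is dynamic information.
- round: the current round of the rendezvous.
- complete: a boolean indicating whether the current round of the rendezvous is complete.
- deadline: the time at which the current round stops waiting for more nodes to join, if it is set.
- closed: a boolean indicating whether the rendezvous is closed.
- participants: a dict mapping each participant to its rank.
- wait_list: a set of nodes waiting to take part in the next round of the rendezvous.
- last_heartbeats: a dict holding the time of each node's last heartbeat.
class _RendezvousState:
    round: int
    complete: bool
    deadline: Optional[datetime]
    closed: bool
    participants: Dict[_NodeDesc, int]  # participants; used later in this article
    wait_list: Set[_NodeDesc]  # waiting nodes; the member used here
    last_heartbeats: Dict[_NodeDesc, datetime]
    def __init__(self) -> None:
        self.round = 0
        self.complete = False
        self.deadline = None
        self.closed = False
        self.participants = {}
        self.wait_list = set()  # the member used here
        self.last_heartbeats = {}
The logic so far is:
- A new worker is started. At this point wait_list is empty in the _RendezvousState at the upper right of the figure below.
- next_rendezvous is called to start a new round of rendezvous.
- _RendezvousJoinOp runs and produces ADD_TO_WAIT_LIST.
- executor.run calls _add_to_wait_list.
- A new node is added to wait_list. The _RendezvousState at the lower right of the figure below now has one more entry in wait_list (node 3).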
python -m torch.distributed.run +-------------------------+ +
--nnodes=xxx TRAINING_SCRIPT.py | _RendezvousState | |
+ | | |
| | participants = [1,2] | |
| 1 | | |
v | wait_list = [] | |
next_rendezvous | | |
+ +------------+------------+ |
| 2 | |
| | |
v | |
+----------------+-----------------------+ | |
| _op_executor.run(_RendezvousJoinOp) | | |
| + + | | |
| | | 3 | | |
| | | | | |
| | v | | |
| | _Action.ADD_TO_WAIT_LIST | v |
| | + | |
| | | | +--------------------------+ |
| +<-------------+ | | _RendezvousState | |
| | | | | |
| | | | participants = [1,2] | |
| v 4 | 5 | | |
| self._add_to_wait_list() +----------------> wait_list = [3] | |
| | | | |
+----------------------------------------+ +--------------------------+ |
|
v
Timeline
2.3 Agent handling
After _DistributedRendezvousOpExecutor.run finishes, control returns to the agent. In the agent's main loop the program enters a while loop and periodically polls the user processes through _monitor_workers, acting on what it finds.
def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
    # NOTE: currently only works for a single role
    spec = self._worker_group.spec
    role = spec.role
    self._initialize_workers(self._worker_group)  # start the workers
    monitor_interval = spec.monitor_interval
    rdzv_handler = spec.rdzv_handler
    while True:
        assert self._worker_group.state != WorkerState.INIT
        # Poll periodically.
        time.sleep(monitor_interval)
        # Check how the worker processes are doing.
        run_result = self._monitor_workers(self._worker_group)
        state = run_result.state  # state of the worker processes
        self._worker_group.state = state
        if state == WorkerState.SUCCEEDED:
            # The processes finished normally.
            self._exit_barrier()
            return run_result
        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            # A process failed.
            if self._remaining_restarts > 0:  # retry
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group)
            else:
                self._stop_workers(self._worker_group)  # retries exhausted: stop the workers
                self._worker_group.state = WorkerState.FAILED
                self._exit_barrier()
                return run_result
        elif state == WorkerState.HEALTHY:
            # The processes are running normally.
            # Check for a membership change such as a scale-up.
            # membership changes do not count as retries
            num_nodes_waiting = rdzv_handler.num_nodes_waiting()
            group_rank = self._worker_group.group_rank
            # If new nodes are waiting, restart all workers.
            if num_nodes_waiting > 0:
                self._restart_workers(self._worker_group)
        else:
            raise Exception(f"[{role}] Worker group in {state.name} state")
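So the key is that the agent periodically runs _monitor_workers to check on the workers. run_result.state is the state of the worker processes; when it is WorkerState.HEALTHY the existing processes are running normally, and the agent goes on to check whether node membership has changed.
It calls rdzv_handler.num_nodes_waiting() to get the length of the wait list. A non-zero value means new nodes are trying to join the cluster, which triggers a re-rendezvous: the agent restarts all workers, and during the restart the nodes on the wait list are moved into the participant list. We walk through these steps in turn.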
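The difference between a failure restart and a membership restart can be replayed with a small model of the loop. This is a sketch and not the agent: `simulate_agent` and the tuple-based observations are assumptions for illustration, and a restart is reduced to a counter.

```python
from typing import List, Tuple


def simulate_agent(
    observations: List[Tuple[str, int]],
    max_restarts: int,
) -> Tuple[str, int, int]:
    """Replay (worker_state, num_nodes_waiting) pairs through the loop's
    decision logic.

    Returns (final_state, remaining_restarts, restart_count).
    """
    remaining = max_restarts
    restarts = 0
    for worker_state, num_nodes_waiting in observations:
        if worker_state == "SUCCEEDED":
            return "SUCCEEDED", remaining, restarts
        if worker_state in {"UNHEALTHY", "FAILED"}:
            if remaining == 0:
                return "FAILED", remaining, restarts
            remaining -= 1      # a failure consumes the retry budget
            restarts += 1
        elif worker_state == "HEALTHY" and num_nodes_waiting > 0:
            restarts += 1       # membership change: budget untouched
    return "RUNNING", remaining, restarts


# One scale-up, then one failure, then success.
result = simulate_agent(
    [("HEALTHY", 0), ("HEALTHY", 1), ("FAILED", 0), ("SUCCEEDED", 0)],
    max_restarts=3,
)
print(result)
```

Two restarts happen in this replay, yet only the failure reduces the remaining budget, which matches the note that a scale-up does not decrement max_restarts.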
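2.3.1 Checking the wait list
The agent first calls num_nodes_waiting to see how many nodes are waiting, which is the length of state.wait_list. As we saw in the Join operation, a new node is inserted into this list.
num_nodes_waiting returns the number of nodes waiting at the rendezvous barrier (nodes that are not included in the current worker group). Callers should call it periodically to find out whether new nodes are waiting to join the job, in which case they need to call next_rendezvous() to admit them.
def num_nodes_waiting(self) -> int:
    """See base class."""
    with self._heartbeat_lock:
        self._state_holder.sync()
        return len(self._state_holder.state.wait_list)
The logic so far is:
- A new worker is started.
- next_rendezvous is called to start a new round of rendezvous.
- _RendezvousJoinOp runs and produces ADD_TO_WAIT_LIST.
- executor.run calls _add_to_wait_list.
- A new node is added to wait_list.
- In the agent, _monitor_workers runs periodically (for example every 30 s) to get the state of the worker child processes.
- If the state is HEALTHY, num_nodes_waiting is called to get the size of wait_list.
- If the number of waiting nodes is greater than 0:
  - _restart_workers is called to restart the worker group.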
python -m torch.distributed.run +-------------------------+ +
--nnodes=xxx TRAINING_SCRIPT.py | _RendezvousState | |
+ | | |
| | participants = [1,2] | |
| 1 | | |
v | wait_list = [] | |
next_rendezvous | | |
+ +------------+------------+ |
| 2 | |
| | |
v | |
+----------------+-----------------------+ | |
| _op_executor.run(_RendezvousJoinOp) | | |
| + + | | |
| | | 3 | | |
| | | | | |
| | v | | |
| | _Action.ADD_TO_WAIT_LIST | v |
| | + | |
| | | | +--------------------------+ |
| +<-------------+ | | _RendezvousState | |
| | | | | |
| | | | participants = [1,2] | |
| v 4 | 5 | | |
| self._add_to_wait_list() +----------------> wait_list = [3] | |
| | | | |
+----------------------------------------+ +------------+-------------+ |
| |
+----------------------------------------+ | |
| agent._invoke_run | | |
| | | |
| | | |
| _monitor_workers Every 30S | | |
| + | | |
| | 6 | | |
| | | v |
| v | |
| WorkerState.HEALTHY | +--------------------------+ |
| + | | _RendezvousState | |
| | | | | |
| | 7 | | participants = [1,2] | |
| v | 8 | | |
| num_nodes_waiting <--------------------> wait_list = [3] | |
| + | | | |
| | 9 | | | |
| | | +--------------------------+ |
| v | |
| _restart_workers | v
| |
+----------------------------------------+ Timeline
2.3.3 Restarting the worker group
If there are nodes on the wait list, the workers are restarted. Let us walk through the flow.
@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """
    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)
2.3.3.1 _stop_workers
The current workers are stopped first. The code is in torch/distributed/elastic/agent/server/local_elastic_agent.py.
@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()
2.3.3.2 _shutdown
_shutdown closes the process context.
def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()
2.3.3.3 Closing the context
In MultiprocessContext, closing terminates every child process and then waits for all of them to stop; the work is done in _close.
def _close(self) -> None:
    if self._pc:
        for proc in self._pc.processes:
            proc.terminate()
            proc.join()
2.3.3.4 _initialize_workers
Once all running child processes have been shut down, everything is initialized again.
@prof
def _initialize_workers(self, worker_group: WorkerGroup) -> None:
    r"""
    Starts a fresh set of workers for the worker_group.
    Essentially a rendezvous followed by a start_workers.
    The caller should first call ``_stop_workers()`` to stop running workers
    prior to calling this method.
    Optimistically sets the state of the worker group that
    just started as ``HEALTHY`` and delegates the actual monitoring
    of state to ``_monitor_workers()`` method
    """
    role = worker_group.spec.role
    # TODO after stopping workers, wait at least monitor_interval*2 for
    # workers on different nodes to fail on a collective op before waiting
    # on the rdzv barrier, this way we ensure that nodes enter rdzv
    # at around the same time and reduce false positive rdzv timeout errors
    self._rendezvous(worker_group)
    worker_ids = self._start_workers(worker_group)
    for local_rank, w_id in worker_ids.items():
        worker = worker_group.workers[local_rank]
        worker.id = w_id
    worker_group.state = WorkerState.HEALTHY
After a series of steps, _rendezvous calls next_rendezvous, which runs an ExitOp followed by a JoinOp.
def next_rendezvous(self) -> Tuple[Store, int, int]:
    # Excerpt: the computation of `deadline` and the code after the join are omitted.
    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()
    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)
2.3.3.5 _RendezvousJoinOp
We are back where we started: this is a new round of rendezvous. In _DistributedRendezvousOpExecutor, run implements the basic logic and dispatches on the action type. In our example, state_handler is _RendezvousJoinOp.
def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None
    while action != _Action.FINISH:
        # Reads or writes the latest rendezvous state shared by all nodes in
        # the rendezvous. Note that our local changes might get overridden
        # by another node if that node synced its changes before us.
        has_set = self._state_holder.sync()
        self._state = self._state_holder.state
        ctx = _RendezvousContext(self._node, self._state, self._settings)
        # Determine the next action to take based on the current state of
        # the rendezvous.
        # This calls _RendezvousJoinOp; tracing through its code shows that
        # it returns ADD_TO_PARTICIPANTS at this point.
        action = state_handler(ctx, deadline)
        if action == _Action.SYNC:
            # Delay the execution by one second to avoid overloading the
            # backend if we are asked to poll for state changes.
            _delay(seconds=1)
        else:
            if action == _Action.KEEP_ALIVE:
                self._keep_alive()
            elif action == _Action.ADD_TO_PARTICIPANTS:  # this branch is taken
                self._add_to_participants()
            elif action == _Action.ADD_TO_WAIT_LIST:
                self._add_to_wait_list()
            elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                self._remove_from_participants()
            elif action == _Action.REMOVE_FROM_WAIT_LIST:
                self._remove_from_wait_list()
            elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                self._mark_rendezvous_complete()
            elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                self._mark_rendezvous_closed()
            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty()
This time the op produces ADD_TO_PARTICIPANTS. For that action to be returned the rendezvous must no longer be complete and the node must not yet be a participant, so the final else branch of _RendezvousJoinOp (listed in full in section 2.2.2) is taken.
2.3.3.6 _add_to_participants
When the executor receives ADD_TO_PARTICIPANTS it calls _add_to_participants, which removes the node from wait_list and inserts it into participants.
def _add_to_participants(self) -> None:
    log.debug(
        f"The node '{self._node}' added itself to the participants of round "
        f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
    )
    state = self._state
    # Remove the node from the wait list. set.remove raises KeyError for a
    # node that was never on the list (e.g. a restarting participant), so
    # that case is ignored.
    try:
        state.wait_list.remove(self._node)
    except KeyError:
        pass
    # The ranks of the participants will be set once the rendezvous is
    # complete.
    state.participants[self._node] = 0  # insert into participants with a placeholder rank
    self._keep_alive()
    if len(state.participants) == self._settings.min_nodes:
        state.deadline = datetime.utcnow() + self._settings.timeout.last_call
    if len(state.participants) == self._settings.max_nodes:
        self._mark_rendezvous_complete()
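The move from the wait list to the participant dict, followed by rank assignment, can be tried out on plain Python containers. The sketch below is illustrative: `add_to_participants` and `assign_ranks` are hypothetical helpers, node identities are strings, and assigning ranks in sorted node order is an assumption about what happens when the rendezvous is marked complete (the source comment only says that ranks are set at that point).

```python
from typing import Dict, Set


def add_to_participants(
    node: str,
    participants: Dict[str, int],
    wait_list: Set[str],
    max_nodes: int,
) -> bool:
    """Move `node` into participants; return True if the round is now full."""
    wait_list.discard(node)      # no error if the node was not waiting
    participants[node] = 0       # placeholder rank until completion
    return len(participants) == max_nodes


def assign_ranks(participants: Dict[str, int]) -> None:
    """Give every participant a rank by sorted node order (assumed rule)."""
    for rank, node in enumerate(sorted(participants)):
        participants[node] = rank


participants = {"node1": 0, "node2": 0}
wait_list = {"node3"}
is_full = add_to_participants("node3", participants, wait_list, max_nodes=4)
assign_ranks(participants)
print(participants, wait_list, is_full)
```

With max_nodes=4 the round is not full after the third node joins, so in the real flow completion would be left to the last-call deadline rather than triggered immediately.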
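This time we draw the figure starting from _restart_workers.
- _stop_workers is called to shut down the worker child processes. At this point the _RendezvousState at the upper right of the figure below has participants=[1,2].
- MultiprocessContext.close() performs the shutdown.
- _initialize_workers re-initializes the workers.
- next_rendezvous is called to carry out a new rendezvous.
- This time _RendezvousJoinOp returns ADD_TO_PARTICIPANTS.
- _add_to_participants is called to change the state.
- The node on wait_list is moved to participants. The _RendezvousState at the lower right of the figure below now has participants=[1,2,3].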
+-----------------------------+ +------------------------+ |
| agent._invoke_run | | _RendezvousState | |
| | | | |
| _restart_workers | | participants = [1,2] | |
| + | | | |
+----------------------+ | | | | wait_list = [3] | |
| MultiprocessContext | | | 1 | | | |
| | | 2 v | +------------------------+ |
| close() <-----------+ _stop_workers | |
| | | + | |
+----------------------+ | | | |
| | 3 | |
| v | |
| _initialize_workers | |
| + | |
| | | |
+-----------------------------+ |
| |
| 4 |
v |
next_rendezvous |
+ |
| |
v |
+---------------------------+---------------+ |
| _op_executor.run(_RendezvousJoinOp) | |
| + + | |
| | | | |
| | | 5 | |
| | v | |
| | ADD_TO_PARTICIPANTS | |
| | + | +-----------------------+ |
| | | | | _RendezvousState | |
| | <-------------+ | | | |
| | | | participants = [1,2,3]| |
| v 6 7 | | | |
| _add_to_participants +--------------> | wait_list = [] | |
| | | | |
+-------------------------------------------+ +-----------------------+ v
Timeline
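0x03 Node departure
3.1 Mechanism
Node departure (scale-down) is handled as follows:
- When a scale-down event happens, rendezvous does not notify the torchelastic agent.
- If the TE agent was launched with max_restarts=0, it relies on the underlying scheduler to restart the job.
- If max_restarts>0, the TE agent terminates the workers and starts a new round of rendezvous.
- Once the agents learn of the departure, the existing workers (on all nodes) are all stopped.
- These workers form a new "WorkerGroup", and every worker runs with a new RANK and WORLD_SIZE.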
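What "a new RANK and WORLD_SIZE" means after a departure can be shown with a small calculation. The sketch rests on assumptions that the text does not state: every node runs the same number of workers, nodes are ordered by name, and worker ranks are contiguous per node. `worker_ranks` is a hypothetical helper, not a torchelastic function.

```python
from typing import Dict, Iterable, List, Tuple


def worker_ranks(
    nodes: Iterable[str], nproc_per_node: int
) -> Tuple[Dict[str, List[int]], int]:
    """Return ({node: [global ranks of its workers]}, WORLD_SIZE)."""
    ranks: Dict[str, List[int]] = {}
    for group_rank, node in enumerate(sorted(nodes)):
        base = group_rank * nproc_per_node
        ranks[node] = list(range(base, base + nproc_per_node))
    return ranks, len(ranks) * nproc_per_node


before, world_before = worker_ranks(["node1", "node2", "node3"], nproc_per_node=2)
# node2 leaves; the remaining nodes re-rendezvous and are renumbered.
after, world_after = worker_ranks(["node1", "node3"], nproc_per_node=2)
print(world_before, before)
print(world_after, after)
```

Under these assumptions the workers on node3 move from ranks 4 and 5 to ranks 2 and 3, which is why a training script must read RANK and WORLD_SIZE again after every restart instead of caching them.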
3.2 How to simulate it
Readers who want to reproduce and debug this can find sample code in test/distributed/elastic/agent/server/test/local_elastic_agent_test.py.
def test_double_agent_elastic(self):
    """
    start ``nnodes`` agents, kill odd ones (do not restart), validate
    elasticity (scale-down) works. (scale-up covered in fault_tolerance test)
    """
    min_nodes = 1
    max_nodes = 2
    wait = 2
    node_conf = Conf(entrypoint=_dist_sum, args=(wait,), local_world_size=2)
    agent_results = mp.Queue()
    agent_args = {
        "conf": node_conf,
        "agent_results": agent_results,
        "min_nodes": min_nodes,
        "max_nodes": max_nodes,
        "max_restarts": 2,
    }
    procs = []
    for _ in range(max_nodes):
        p = mp.Process(
            target=self.run_agent,
            kwargs=agent_args,
        )
        procs.append(p)
        p.start()
    # kill odd agents
    for i in range(max_nodes):
        if i % 2 != 0:
            procs[i].kill()
    for i in range(max_nodes):
        p = procs[i]
        p.join()
        if i % 2 == 0:
            self.assertEqual(0, p.exitcode)
        else:
            self.assertEqual(-signal.SIGKILL, p.exitcode)
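The last assertion relies on how a signal-terminated child reports its exit status: as the negative signal number. The standalone sketch below shows the same convention with `subprocess` instead of `multiprocessing` (an assumption made to keep the example free of start-method concerns). It assumes a POSIX system, since `signal.SIGKILL` does not exist on Windows.

```python
import signal
import subprocess
import sys

# Start a child that would otherwise sleep for a minute.
child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])

# On POSIX, kill() sends SIGKILL, which the child cannot catch or ignore.
child.kill()
exit_code = child.wait(timeout=30)

# A child ended by signal N reports the exit code -N.
print(exit_code == -signal.SIGKILL)
```

An agent killed this way gets no chance to clean up, and the test expects the surviving agent to still exit with code 0 after the scale-down.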
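3.3 How it is handled
Node departure goes through the same code as error handling, shown below. If the maximum number of retries has not been reached, the agent tries to restart the workers; once it has been reached, the agent stops the workers.
def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
    # (setup omitted)
    while True:
        # Poll periodically.
        time.sleep(monitor_interval)
        # Check how the worker processes are doing.
        run_result = self._monitor_workers(self._worker_group)
        state = run_result.state
        # (SUCCEEDED and HEALTHY branches omitted)
        if state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            # A process failed.
            if self._remaining_restarts > 0:  # retry
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group)  # restart
            else:
                self._stop_workers(self._worker_group)  # retries exhausted: stop the workers
                self._worker_group.state = WorkerState.FAILED
                self._exit_barrier()
                return run_result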
3.3.1 Restart
_restart_workers stops all workers and then runs a new round of rendezvous.
@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """
    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)
3.3.2 Stop
Stopping the workers means closing the process context.
def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()
@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()
In MultiprocessContext, closing terminates every child process and then waits for all of them to stop; the work is done in _close.
def _close(self) -> None:
    if self._pc:
        for proc in self._pc.processes:
            proc.terminate()
            proc.join()
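The flow is shown below:
- Monitor the state of the child processes.
- On UNHEALTHY or FAILED, check whether any restarts remain. We assume that process 3 failed.
- If none remain, call _stop_workers to end the child processes.
- MultiprocessContext.close does the actual shutdown.
- If a restart is still allowed, call _restart_workers.
- _stop_workers ends the child processes.
- MultiprocessContext.close does the actual shutdown.
- _initialize_workers re-initializes the workers.
- next_rendezvous re-synchronizes.
- The subsequent steps follow.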
+
+-------------------------------------------+ +---------------------------+ |
| agent._invoke_run | | _RendezvousState | |
| | | | |
| | | | |
| _monitor_workers Every 30S | | participants = [1,2,3] | |
| + | | | |
| | 1 | | wait_list = [ ] | |
| | | | | |
| v | +---------------------------+ |
| WorkerState.UNHEALTHY,FAILED | |
| + | |
| | | |
| | 2 | |
| v | |
| self._remaining_restarts > 0 ? +--+ | |
| + | | |
| 5 | YES NO | 3 | |
| | | | |
| v v | +----------------------+ |
| _restart_workers _stop_workers | | MultiprocessContext | |
| + + | | | |
| | 6 | 4 | | | |
| | +--------> | | |
| v | | close() | |
| _stop_workers +-------------------------> | | |
| + 7 | +----------------------+ |
| | | |
| | 8 | |
| v | |
| _initialize_workers | |
| + | |
| | | |
+-------------------------------------------+ |
| 9 |
| |
v +--------------------------+ |
next_rendezvous | _RendezvousState | |
+ | | |
| 10 | participants = [1,2] | |
+----------------------------> | | |
| | wait_list = [ ] | v
| 10 +--------------------------+
v Timeline
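This completes the analysis of elastic training and, with it, this part of the PyTorch distributed series. Upcoming articles will cover the distributed implementations of other frameworks and libraries.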