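[Source Code Analysis] PyTorch Distributed Elastic Training (5) --- The Rendezvous Engine
0x00 Abstract
In the previous articles we studied the basic modules of PyTorch distributed and walked through several official examples. We now turn to PyTorch elastic training. This is the fifth article on that topic; it looks at the internal engine of Rendezvous, for example how it handles nodes joining, nodes leaving, waiting, and heartbeats.
The articles in the elastic training series are:
[Source Code Analysis] PyTorch Distributed Elastic Training (1) --- Overall Approach
[Source Code Analysis] PyTorch Distributed Elastic Training (2) --- Startup & Single-Node Flow
[Source Code Analysis] PyTorch Distributed Elastic Training (3) --- Agent
[Source Code Analysis] PyTorch Distributed Elastic Training (4) --- Rendezvous Architecture and Logic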
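0x01 Introduction
1.1 Overall system
Elastic training can be understood as an execution system built on top of Rendezvous.
- The Agent focuses on the logic on a specific node.
  - The Agent handles the concrete business logic, such as launching processes to run the user program and monitoring its execution, and it notifies Rendezvous if something goes wrong.
  - The Agent is a worker manager: it starts and manages the worker processes, forms them into a worker group, monitors their state, catches failed workers, and restarts the worker group when a worker fails or a new worker joins.
  - The Agent maintains WORLD_SIZE and RANK. Users no longer need to supply them by hand; the Agent handles them automatically.
  - The Agent is a background process on a specific node and is an independent entity. An Agent alone cannot provide elastic training as a whole, so a mechanism is needed for workers to discover one another and synchronize changes (WORLD_SIZE and RANK can in fact only be determined once multiple nodes have synchronized). That mechanism is Rendezvous, described next.
- Rendezvous is responsible for the cluster logic, ensuring that the nodes reach strongly consistent agreement on "which nodes are taking part in training".
  - Each Agent contains a Rendezvous handler. Together these handlers form a Rendezvous cluster, and thereby an Agent cluster.
  - Once Rendezvous completes, a shared key-value store is created; the store implements the torch.distributed.Store API. It is shared only by the members that completed the Rendezvous, and it is intended for Torch Distributed Elastic to exchange control and data information during job initialization.
  - Rendezvous maintains, on each agent, all the information about the current group. Each agent has a rendezvous; they communicate with one another and jointly maintain a single set of information, which is kept in the Store mentioned above.
  - Rendezvous handles cluster-level logic such as adding nodes, removing nodes, and assigning ranks.
1.2 Rendezvous
So far the picture of Rendezvous is as follows. DynamicRendezvousHandler carries the dynamic logic, while _RendezvousStateHolder stores the state and other metadata (a static structure). The diagram also contains a _RendezvousOpExecutor that has not been introduced yet. It is the runtime engine, so this article looks at how _RendezvousOpExecutor works.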
+-----------------------------+ +------------------------------------------------+
| LocalElasticAgent | | WorkerSpec |
| | | |
| +------------------------+ | | rdzv_handler = {DynamicRendezvousHandler} -------+
| |WorkerGroup | | | | |
| | spec +--------------> | entry = worker_fn | |
| | workers | | | | |
| | store | | | role = {str} 'trainer' | |
| | group_rank | | | | |
| | group_world_size | | +------------------------------------------------+ |
| | | | |
| +------------------------+ | |
| | |
| rdzv_run_id | |
| store | +-----------------------------------------+ |
| | |DynamicRendezvousHandler | |
+-----------------------------+ | | |
| | |
| _settings: RendezvousSettings | <--+
| |
| _store: Store |
| |
| _state_holder: _RendezvousStateHolder |
| |
| _op_executor: _RendezvousOpExecutor |
| |
+-----------------------------------------+
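1.3 Decoupling
_RendezvousOpExecutor splits the functionality into decoupled parts:
- The business logic is abstracted into a set of operators, such as _RendezvousJoinOp.
- Rendezvous internally maintains a state machine made up of business functions; for example, the function _add_to_participants adds a participant.
- The _RendezvousOpExecutor engine runs the operators, obtains an Action from each operator's result, and uses that Action to call the matching business function.
This article focuses on the Rendezvous engine for the C10d backend.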
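To make the operator / Action / business-function split concrete, here is a minimal sketch. It is not the PyTorch implementation: `Action`, `ToyExecutor`, `join_op`, and `exit_op` are hypothetical names invented for illustration, the action set is reduced to three members, and a dictionary stands in for the if/elif dispatch that the real engine uses.

```python
from enum import Enum
from typing import Callable, Dict, Set


class Action(Enum):
    """Hypothetical, reduced action set for illustration only."""
    ADD_TO_PARTICIPANTS = 1
    REMOVE_FROM_PARTICIPANTS = 2
    FINISH = 3


class ToyExecutor:
    """Runs an operator until it returns FINISH, dispatching each action."""

    def __init__(self, node: str) -> None:
        self.node = node
        self.participants: Set[str] = set()
        # Action -> business function table; this is the decoupling layer.
        self._handlers: Dict[Action, Callable[[], None]] = {
            Action.ADD_TO_PARTICIPANTS: lambda: self.participants.add(self.node),
            Action.REMOVE_FROM_PARTICIPANTS: lambda: self.participants.discard(self.node),
        }

    def run(self, op: Callable[["ToyExecutor"], Action]) -> None:
        action = None
        while action != Action.FINISH:
            action = op(self)  # the operator only decides what to do next
            if action != Action.FINISH:
                self._handlers[action]()  # the executor performs the change


def join_op(ex: ToyExecutor) -> Action:
    """Operator: request membership until the node is a participant."""
    return Action.FINISH if ex.node in ex.participants else Action.ADD_TO_PARTICIPANTS


def exit_op(ex: ToyExecutor) -> Action:
    """Operator: request removal while the node is still a participant."""
    return Action.REMOVE_FROM_PARTICIPANTS if ex.node in ex.participants else Action.FINISH


executor = ToyExecutor("node-a")
executor.run(join_op)
joined = set(executor.participants)
executor.run(exit_op)
left = set(executor.participants)
```

The operators never touch the state directly; they only return an action, which is what lets the same small set of business functions serve several operators.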
0x02 Engine implementation
2.1 Base class
_RendezvousOpExecutor is the base class of the engine. It only declares the abstract method run.
class _RendezvousOpExecutor(ABC):
    """Executes rendezvous operations."""

    @abstractmethod
    def run(
        self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
    ) -> None:
        """Executes a rendezvous operation.

        An operation is run inside a state machine and is expected to transition
        the rendezvous from one state to another.

        Args:
            state_handler:
                A callable that is expected to return the next state transition
                action based on the current state of the rendezvous.
            deadline:
                The time, in seconds, at which the operation will be considered
                timed-out.
        """
_RendezvousContext is used here. It packages the various pieces of Rendezvous information and hands them to the operation engine; this is where _RendezvousState and RendezvousSettings come into use.
class _RendezvousContext:
    """Holds the context of the rendezvous.

    Attributes:
        node:
            The node descriptor associated with the current rendezvous handler
            instance.
        state:
            The current state of the rendezvous.
        settings:
            The rendezvous settings.
    """

    node: _NodeDesc
    state: _RendezvousState
    settings: RendezvousSettings

    def __init__(
        self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings
    ) -> None:
        self.node = node
        self.state = state
        self.settings = settings
2.2 Distributed operation engine
_DistributedRendezvousOpExecutor extends _RendezvousOpExecutor and is the actual executor in ElasticTorch. Much like a Looper, it dispatches messages, calls the business functions, and maintains state.
2.2.1 Definition
Compared with its base class, _DistributedRendezvousOpExecutor adds member variables such as the node information, state, and settings.
class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
    """Executes rendezvous operations using a shared state.

    Args:
        node:
            The node descriptor associated with the current rendezvous handler
            instance.
        state_holder:
            The ``RendezvousStateHolder`` to use to sync the rendezvous state
            with other nodes.
        settings:
            The rendezvous settings.
    """

    _node: _NodeDesc
    _state: _RendezvousState
    _state_holder: _RendezvousStateHolder
    _settings: RendezvousSettings

    def __init__(
        self,
        node: _NodeDesc,
        state_holder: _RendezvousStateHolder,
        settings: RendezvousSettings,
    ) -> None:
        self._node = node
        self._state_holder = state_holder
        self._settings = settings
The structure is as follows:
+---------------------------------------------------------------+
| _DistributedRendezvousOpExecutor |
| |
| +------------------------+ |
| _state +---> | _RendezvousState | |
| | | |
| | participants | |
| | wait_list | |
| | last_heartbeats | |
| | deadline | |
| +------------------------+ |
| |
| +-------------------------+ |
| _settings +--> | RendezvousSettings | |
| | | |
| +-------------------------+ |
| |
| +--------------------------------------+ |
| _state_holder +---> | _BackendRendezvousStateHolder | |
| | | |
| | _backend: RendezvousBackend | |
| | _state: _RendezvousState | |
| | _settings: RendezvousSettings | |
| | | |
| +--------------------------------------+ |
| +--------------------------------------+ |
| | _NodeDesc | |
| _node +-------> | fqdn: str | |
| | pid: int | |
| | local_id: int | |
| | | |
| +--------------------------------------+ |
+---------------------------------------------------------------+
2.2.2 Invocation
A few examples show how the engine is invoked. In each case an operator is created first, and then the engine's run function is called.
2.2.2.1 _RendezvousKeepAliveOp
def _keep_alive(self) -> None:
    # Excerpt: the full method releases the lock in a finally block.
    self._heartbeat_lock.acquire()
    op = _RendezvousKeepAliveOp()  # create the operator
    deadline = self._get_deadline(self._settings.timeout.heartbeat)
    self._op_executor.run(op, deadline)  # run it on the engine
2.2.2.2 _RendezvousCloseOp
def _close(self) -> None:
    op = _RendezvousCloseOp()  # create the operator
    deadline = self._get_deadline(self._settings.timeout.close)
    self._op_executor.run(op, deadline)  # run it on the engine
2.2.2.3 _RendezvousJoinOp
def next_rendezvous(self) -> Tuple[Store, int, int]:
    """See base class."""
    self._stop_heartbeats()
    # Delay the execution for a small random amount of time if this is our
    # first run. This will slightly skew the rendezvous attempts across the
    # nodes and reduce the load on the backend.
    if self._state_holder.state.round == 0:
        _delay(seconds=(0, 0.3))
    exit_op = _RendezvousExitOp()  # create the exit operator
    join_op = _RendezvousJoinOp()  # create the join operator
    deadline = self._get_deadline(self._settings.timeout.join)
    self._op_executor.run(exit_op, deadline)  # first leave any previous round
    self._op_executor.run(join_op, deadline)  # then join
    self._start_heartbeats()
    rank, world_size = self._get_world()
    store = self._get_store()
    return store, rank, world_size
2.2.3 Functionality
In _DistributedRendezvousOpExecutor, the run function implements the basic logic: it performs different operations depending on the action type.
2.2.3.1 Main loop
The code of run is as follows:
def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None
    while action != _Action.FINISH:  # loop until a FINISH action is obtained
        # Reads or writes the latest rendezvous state shared by all nodes in
        # the rendezvous. Note that our local changes might get overridden
        # by another node if that node synced its changes before us.
        # Important: this is where state is synchronized across all nodes.
        has_set = self._state_holder.sync()  # the latest state lives in the rendezvous backend
        self._state = self._state_holder.state
        ctx = _RendezvousContext(self._node, self._state, self._settings)
        # Determine the next action to take based on the current state of
        # the rendezvous.
        action = state_handler(ctx, deadline)  # state_handler is the operator; it picks the next action
        if action == _Action.FINISH:
            continue
        if action == _Action.ERROR_CLOSED:
            raise RendezvousClosedError()
        if action == _Action.ERROR_TIMEOUT:
            raise RendezvousTimeoutError()
        if action == _Action.SYNC:
            # Delay the execution by one second to avoid overloading the
            # backend if we are asked to poll for state changes.
            _delay(seconds=1)
        else:
            if action == _Action.KEEP_ALIVE:
                self._keep_alive()
            elif action == _Action.ADD_TO_PARTICIPANTS:
                self._add_to_participants()
            elif action == _Action.ADD_TO_WAIT_LIST:
                self._add_to_wait_list()
            elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                self._remove_from_participants()
            elif action == _Action.REMOVE_FROM_WAIT_LIST:
                self._remove_from_wait_list()
            elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                self._mark_rendezvous_complete()
            elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                self._mark_rendezvous_closed()
            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty()
The diagram below shows how the handler drives the engine.
+-----------------------------------------+ +---------------------------------------------------------------+
|DynamicRendezvousHandler | | _DistributedRendezvousOpExecutor |
| | | |
| | | +------------------------+ |
| _settings: RendezvousSettings | | _state +---> | _RendezvousState | |
| | | | | |
| | | | participants | |
| _store: Store | | | wait_list | |
| | | | last_heartbeats | |
| | | | deadline | |
| _state_holder: _RendezvousStateHolder | | +------------------------+ |
| | run(_RendezvousJoinOp()) | +-------------------------+ |
| | | _settings +--> | RendezvousSettings | |
| _op_executor +------------------------------------------------> | | | |
| | | +-------------------------+ |
| | | +--------------------------------------+ |
+-----------------------------------------+ | _state_holder +---> | _BackendRendezvousStateHolder | |
| | | |
| | _backend: RendezvousBackend | |
| | _state: _RendezvousState | |
| | _settings: RendezvousSettings | |
| | | |
| +--------------------------------------+ |
| +--------------------------------------+ |
| | _NodeDesc | |
| _node +-------> | fqdn: str | |
| | pid: int | |
| | local_id: int | |
| | | |
| +--------------------------------------+ |
+---------------------------------------------------------------+
2.2.3.2 Synchronization
One thing to note in run: before any operator is executed, self._state_holder.sync() is called to synchronize state among the workers and reach consensus.
def sync(self) -> Optional[bool]:
    """See base class."""
    state_bits: Optional[bytes] = None
    token = None
    has_set: Optional[bool]
    if self._dirty:  # the local state of this node has changed
        has_set = False
        state_bits = pickle.dumps(self._state)
        # Write our own state to the backend.
        set_response = self._backend.set_state(state_bits, self._token)
        if set_response is not None:
            state_bits, token, has_set = set_response
    else:  # nothing changed locally, so only read from the backend
        has_set = None
        if self._cache_duration > 0:
            # Avoid overloading the backend if we are asked to retrieve the
            # state repeatedly. Try to serve the cached state.
            if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0):
                return None
        get_response = self._backend.get_state()  # fetch the latest shared state from the backend
        if get_response is not None:
            state_bits, token = get_response
    if state_bits is not None:
        try:
            self._state = pickle.loads(state_bits)  # replace the local state with the backend state
        except pickle.PickleError as exc:
            raise RendezvousStateError(
                "The rendezvous state is corrupt. See inner exception for details."
            ) from exc
    else:
        self._state = _RendezvousState()
    if has_set and self._dead_nodes and log.isEnabledFor(logging.DEBUG):
        node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)
        msg = (
            f"As part of the sync operation the node(s) {node_list} have been removed from the "
            f"rendezvous '{self._settings.run_id}' since they had no heartbeat."
        )
        self._record(message=msg)
    self._token = token
    self._dirty = False
    self._last_sync_time = time.monotonic()
    self._sanitize()
    return has_set
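Backend
The corresponding backend code is in torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py.
The backend uses the store as centralized storage, acting as the master. Each node is a client that writes its own state to the master and reads the state of the other nodes. In this way all nodes share what they know and reach consensus. Clients that stop updating their metadata are also removed periodically.
get_state simply reads from the store.
def get_state(self) -> Optional[Tuple[bytes, Token]]:
    """See base class."""
    base64_state: bytes = self._call_store("get", self._key)
    return self._decode_state(base64_state)
set_state performs a compare-and-set. It returns the new state together with a flag saying whether our write updated the state.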
def set_state(
    self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
    """See base class."""
    base64_state_str: str = b64encode(state).decode()
    if token:
        # Shortcut if we know for sure that the token is not valid.
        if not isinstance(token, bytes):
            result = self.get_state()
            if result is not None:
                tmp = *result, False
                # Python 3.6 does not support tuple unpacking in return
                # statements.
                return tmp
            return None
        token = token.decode()
    else:
        token = self._NULL_SENTINEL
    base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)
    state_token_pair = self._decode_state(base64_state)
    if state_token_pair is None:
        return None
    new_state, new_token = state_token_pair
    # C10d Store's compare_set method does not offer an easy way to find out
    # whether our write attempt was successful. As a brute-force solution we
    # perform a bitwise comparison of our local state and the remote state.
    return new_state, new_token, new_state == state
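The compare-and-set idea behind set_state can be illustrated with a small in-memory stand-in. This is a sketch under assumptions, not the C10d backend: `ToyBackend` is a hypothetical class, an integer version number plays the role of the token, and the success flag is returned directly instead of being inferred from a byte comparison as in the listing above.

```python
import pickle
from typing import Optional, Tuple


class ToyBackend:
    """In-memory stand-in for a rendezvous backend (hypothetical)."""

    def __init__(self) -> None:
        self._state: bytes = pickle.dumps({})
        self._version = 0  # plays the role of the token

    def get_state(self) -> Tuple[bytes, int]:
        return self._state, self._version

    def set_state(self, state: bytes, token: Optional[int] = None) -> Tuple[bytes, int, bool]:
        # Accept the write only if the caller has seen the latest version.
        expected = token if token is not None else 0
        if expected == self._version:
            self._state = state
            self._version += 1
            return self._state, self._version, True
        # Stale token: reject the write and hand back what is stored now.
        return self._state, self._version, False


backend = ToyBackend()
state_a, token_a = backend.get_state()
state_b, token_b = backend.get_state()

# Node A adds itself and writes first.
view_a = pickle.loads(state_a)
view_a["node-a"] = 0
_, token_a, a_won = backend.set_state(pickle.dumps(view_a), token_a)

# Node B writes with the token it read earlier, which is now stale.
view_b = pickle.loads(state_b)
view_b["node-b"] = 0
latest, token_b, b_won = backend.set_state(pickle.dumps(view_b), token_b)

# Node B adopts the latest state, reapplies its change, and retries.
view_b = pickle.loads(latest)
view_b["node-b"] = 0
final, token_b, b_retry_won = backend.set_state(pickle.dumps(view_b), token_b)
final_view = pickle.loads(final)
```

The retry at the end mirrors what the engine loop does implicitly: a node whose write lost the race receives the newer state, the operator is evaluated again against it, and the change is attempted once more.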
_sanitize
The _sanitize method acts on what it learns about the other nodes, for example by cleaning up failed nodes. If a node's last heartbeat is older than a threshold (keep_alive_interval * keep_alive_max_attempt), the node is recorded as a dead node and removed from participants or from the wait list.
def _sanitize(self) -> None:
    state = self._state
    expire_time = datetime.utcnow() - (
        self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
    )
    # Filter out the dead nodes.
    self._dead_nodes = [
        node
        for node, last_heartbeat in state.last_heartbeats.items()
        if last_heartbeat < expire_time
    ]
    participant_removed = False
    for dead_node in self._dead_nodes:
        del state.last_heartbeats[dead_node]  # drop the heartbeat record of the dead node
        try:
            del state.participants[dead_node]  # drop it from the participants
            participant_removed = True
        except KeyError:
            pass
        try:
            state.wait_list.remove(dead_node)  # drop it from the wait list
        except KeyError:
            pass
    if participant_removed:
        # Common epilogue shared with the _remove_from_participants()
        # function of _DistributedRendezvousOpExecutor.
        _remove_participant_epilogue(state, self._settings)
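The expiry rule in _sanitize can be tried out in isolation. The following is a simplified sketch: `prune_dead_nodes` is a hypothetical helper, nodes are plain strings instead of _NodeDesc objects, and the current time is passed in explicitly so the result is reproducible.

```python
from datetime import datetime, timedelta
from typing import Dict, List, Set


def prune_dead_nodes(
    last_heartbeats: Dict[str, datetime],
    participants: Dict[str, int],
    wait_list: Set[str],
    now: datetime,
    keep_alive_interval: timedelta,
    keep_alive_max_attempt: int,
) -> List[str]:
    """Remove every node whose last heartbeat is older than the threshold."""
    expire_time = now - keep_alive_interval * keep_alive_max_attempt
    dead = [node for node, hb in last_heartbeats.items() if hb < expire_time]
    for node in dead:
        del last_heartbeats[node]
        participants.pop(node, None)  # no-op if the node was not a participant
        wait_list.discard(node)       # no-op if the node was not waiting
    return dead


now = datetime(2024, 1, 1, 12, 0, 0)
heartbeats = {
    "node-a": now - timedelta(seconds=10),  # within the 15 s threshold
    "node-b": now - timedelta(seconds=30),  # expired participant
    "node-c": now - timedelta(seconds=16),  # expired waiting node
}
participants = {"node-a": 0, "node-b": 1}
wait_list = {"node-c"}
dead = prune_dead_nodes(
    heartbeats, participants, wait_list, now, timedelta(seconds=5), 3
)
```

With a 5 second interval and 3 allowed attempts, a node is dropped once it has been silent for more than 15 seconds.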
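Having seen how the engine runs, we next look at the individual operators.
0x03 Operators
The business logic of the _RendezvousOpExecutor engine is split into two layers, user operations and internal business logic, which are decoupled from each other.
- User operations are expressed as operators: heartbeat, join, close, and exit. The join operator, for example, is _RendezvousJoinOp.
- The internal business logic is split into business functions. For example, the _add_to_participants method removes the node from the wait list and adds it to participants.
- Operators and business functions do not map one to one, so a mechanism resembling a state machine is needed to control them.
  - For example, the heartbeat operator may result in a timeout, a keep-alive, or a normal finish, and a different business function must be called depending on the result. This mapping is expressed through Actions.
  - Taken together, the operators form a state machine.
  - Inside an operator, Actions are produced, and they decide the next step of the state machine.
- The engine executes the concrete business logic according to the Action; put differently, the decoupling is done through Actions.
As shown below, the engine can be divided logically into three layers: the operator layer at the top, the Action layer in the middle, and the business-function layer at the bottom.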
+-----------------------------------------------------------------------------------------+
| |
| _RendezvousKeepAliveOp _RendezvousCloseOp _RendezvousExitOp _RendezvousJoinOp |
| |
+-------------+---------------------+--------------------+------------------+-------------+
| | | |
| | | |
| | | |
| | | |
v v v v
+-----------------------------------------------------------------------------------------+
| |
| KEEP_ALIVE ADD_TO_PARTICIPANTS ADD_TO_WAIT_LIST REMOVE_FROM_WAIT_LIST ...... |
| |
+-------------+----------+----------+----------+---------+---------+---------+------------+
| | | | | | |
| | | | | | |
| | | | | | |
| | | | | | |
v v v v v v v
+-----------------------------------------------------------------------------------------+
| |
| _add_to_participants _remove_from_participants _add_to_wait_list ...... |
| |
| |
+-----------------------------------------------------------------------------------------+
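We go through them one by one.
3.1 Actions
We start with the middle layer and see which Actions exist. Based on the state of the rendezvous, the engine's actions are as follows. The code is in torch/distributed/elastic/rendezvous/dynamic_rendezvous.py
class _Action(Enum):
    """Specifies the possible actions based on the state of the rendezvous."""

    KEEP_ALIVE = 1
    ADD_TO_PARTICIPANTS = 2
    ADD_TO_WAIT_LIST = 3
    REMOVE_FROM_PARTICIPANTS = 4
    REMOVE_FROM_WAIT_LIST = 5
    MARK_RENDEZVOUS_COMPLETE = 6
    MARK_RENDEZVOUS_CLOSED = 7
    SYNC = 8
    ERROR_CLOSED = 9
    ERROR_TIMEOUT = 10
    FINISH = 11
3.2 Operators
The engine implements several operators; broadly, one operation corresponds to one operator. Below are a few examples. An operator chooses the action type based on the state of the rendezvous.
3.2.1 Heartbeat
3.2.1.1 Checking the heartbeat
_RendezvousKeepAliveOp determines the next Action from the current state and time. Its main job is to check whether this node's heartbeat is due for a refresh: if it is, the operator returns KEEP_ALIVE (or ERROR_TIMEOUT once the deadline has passed), otherwise FINISH.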
class _RendezvousKeepAliveOp:
    """Represents a rendezvous keep-alive update operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        if _should_keep_alive(ctx):
            if time.monotonic() > deadline:
                return _Action.ERROR_TIMEOUT
            return _Action.KEEP_ALIVE
        return _Action.FINISH
The _should_keep_alive method is:
def _should_keep_alive(ctx: _RendezvousContext) -> bool:
    """Determines whether a keep-alive heartbeat should be sent."""
    try:
        last_heartbeat = ctx.state.last_heartbeats[ctx.node]
    except KeyError:
        return False
    return last_heartbeat <= datetime.utcnow() - ctx.settings.keep_alive_interval
3.2.1.2 Periodic invocation
Note that sync is called before any operator runs, and sync synchronizes state between nodes. Because heartbeats are periodic, state synchronization is periodic as well.
DynamicRendezvousHandler starts a timer that calls the _keep_alive_weak method periodically.
def _start_heartbeats(self) -> None:
    self._keep_alive_timer = _PeriodicTimer(
        self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self)
    )
    self._keep_alive_timer.set_name(f"RendezvousKeepAliveTimer_{self._this_node.local_id}")
    self._keep_alive_timer.start()
Next, _keep_alive_weak calls self._keep_alive().
@staticmethod
def _keep_alive_weak(weak_self) -> None:
    self = weak_self()
    if self is not None:
        self._keep_alive()
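The timer is handed a weak reference rather than the handler itself, presumably so that a running timer does not keep the handler alive. A minimal sketch of that pattern follows; `ToyHandler` is a hypothetical class, the timer thread is left out, and the callback returns a bool (unlike the original) only so the effect can be observed.

```python
import gc
import weakref


class ToyHandler:
    """Hypothetical handler that counts the heartbeats it has sent."""

    def __init__(self) -> None:
        self.beats = 0

    def _keep_alive(self) -> None:
        self.beats += 1

    @staticmethod
    def _keep_alive_weak(weak_self: "weakref.ReferenceType[ToyHandler]") -> bool:
        handler = weak_self()  # resolves to None once the handler is gone
        if handler is None:
            return False
        handler._keep_alive()
        return True


handler = ToyHandler()
ref = weakref.ref(handler)

fired_while_alive = ToyHandler._keep_alive_weak(ref)
beats = handler.beats

# Drop the only strong reference; the callback then becomes a no-op.
del handler
gc.collect()
fired_after_delete = ToyHandler._keep_alive_weak(ref)
```

Because the callback is a static method that receives only the weak reference, the timer holds no strong reference to the handler, and a late tick after the handler has been released does nothing.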
_keep_alive in turn runs _RendezvousKeepAliveOp.
def _keep_alive(self) -> None:
    self._heartbeat_lock.acquire()
    op = _RendezvousKeepAliveOp()
    deadline = self._get_deadline(self._settings.timeout.heartbeat)
    try:
        self._op_executor.run(op, deadline)
        msg = (
            f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous "
            f"'{self._settings.run_id}'."
        )
        self._record(message=msg)
        log.debug(msg)
    except RendezvousError as ex:
        msg = (
            f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the "
            f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}."
        )
        self._record(message=msg, node_state=NodeState.FAILED)
    finally:
        self._heartbeat_lock.release()
3.2.1.3 Setting the heartbeat
In addition, _DistributedRendezvousOpExecutor has a function with the same name, _keep_alive, which implements the internal logic; we come to it later.
3.2.2 Close
_RendezvousCloseOp determines the next Action from the current state and time.
class _RendezvousCloseOp:
    """Represents a rendezvous close operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        if ctx.state.closed:
            return _Action.FINISH
        if time.monotonic() > deadline:
            return _Action.ERROR_TIMEOUT
        return _Action.MARK_RENDEZVOUS_CLOSED
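3.2.3 Exit
_RendezvousExitOp determines the next Action from the current state and time. If this node is not in participants, there is nothing to do and it returns FINISH. Otherwise it returns the Action that removes the node from the participants list, or the timeout Action if the deadline has passed.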
class _RendezvousExitOp:
    """Represents a rendezvous exit operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        if ctx.node in ctx.state.participants:
            if time.monotonic() > deadline:
                return _Action.ERROR_TIMEOUT
            return _Action.REMOVE_FROM_PARTICIPANTS
        return _Action.FINISH
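3.2.4 Join
_RendezvousJoinOp behaves differently depending on the system state: it may try to add this node to the participants or to the wait list, or keep waiting. See the code comments for details.
- Extract the _RendezvousState from the context and store it in state.
- If the state is closed, the rendezvous has already ended, so return _Action.ERROR_CLOSED.
- Check whether this node is a participant and store the result in is_participant.
- If the rendezvous is complete and this node is already a participant, nothing more is needed; return _Action.FINISH.
- Get the current time, now.
- If now > deadline, the operation has timed out.
  - If there is still time to roll back, this node should revert to its previous state.
    - If this node is a participant, the total number of nodes did not reach min in time. Although it is a participant, it must be removed from the participants list, so return _Action.REMOVE_FROM_PARTICIPANTS.
    - If this node is in the wait list, it could not wait until the next round of the rendezvous. It must be removed from the wait list, so return _Action.REMOVE_FROM_WAIT_LIST.
  - Otherwise return _Action.ERROR_TIMEOUT.
- Otherwise there is no timeout and processing continues.
  - If state.complete is set and this node is not a participant (the participant case was handled above), the rendezvous has already finished. If the maximum number of nodes has not been reached and this node is not yet in the wait list, it is added to the wait list; at the next monitoring interval the rendezvous is redone and the waiting nodes can be moved into the participants list. So return _Action.ADD_TO_WAIT_LIST.
  - If this node is a participant and the state is not complete (the complete case was handled above), and the minimum number of nodes has been reached and the last-call deadline has passed, the rendezvous can be finished, so return _Action.MARK_RENDEZVOUS_COMPLETE.
  - Otherwise the rendezvous is not complete and this node is not a participant, so it joins the participants list directly: return _Action.ADD_TO_PARTICIPANTS.
- If a heartbeat is due, return _Action.KEEP_ALIVE.
- Otherwise return _Action.SYNC.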
class _RendezvousJoinOp:
    """Represents a rendezvous join operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        state = ctx.state  # extract the _RendezvousState from the context
        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            return _Action.ERROR_CLOSED  # already closed
        is_participant = ctx.node in state.participants  # is this node a participant?
        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant:  # complete and we are part of it
            return _Action.FINISH
        now = time.monotonic()
        if now > deadline:  # the operation has timed out
            rollback_period = 5  # 5 seconds
            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period:  # still time to roll back
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant:  # min was not reached in time; leave the participants
                    return _Action.REMOVE_FROM_PARTICIPANTS
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list:  # the next round did not start in time; leave the wait list
                    return _Action.REMOVE_FROM_WAIT_LIST
            return _Action.ERROR_TIMEOUT  # report the timeout
        if state.complete:  # the rendezvous is already complete
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes:  # max not reached yet
                if ctx.node not in state.wait_list:  # not waiting yet
                    return _Action.ADD_TO_WAIT_LIST  # join the wait list
        elif is_participant:  # already in the participants list
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes:  # min reached
                if cast(datetime, state.deadline) < datetime.utcnow():  # last-call deadline passed
                    return _Action.MARK_RENDEZVOUS_COMPLETE  # mark the rendezvous complete
        else:  # not complete and not a participant: join
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS
        if _should_keep_alive(ctx):  # a heartbeat is due
            return _Action.KEEP_ALIVE
        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC  # otherwise poll for state changes
The logic is as follows:
state.closed
+--------------------------> _Action.ERROR_CLOSED
|
|
| complete & participant
+--------------------------> _Action.FINISH
|
|
| timeout & participant
+--------------------------> _Action.REMOVE_FROM_PARTICIPANTS
|
|
| timeout & wait
+--------------------------> _Action.REMOVE_FROM_WAIT_LIST
|
+-------------------+ |
| | | timeout
| _RendezvousJoinOp +------------------------------> _Action.ERROR_TIMEOUT
| | |
+-------------------+ | complete & < max & not wait
|
+--------------------------> _Action.ADD_TO_WAIT_LIST
|
| not complete & participant & >= min & deadline
|
+--------------------------> _Action.MARK_RENDEZVOUS_COMPLETE
|
| not complete & not participant
|
+--------------------------> _Action.ADD_TO_PARTICIPANTS
|
| _should_keep_alive
|
+--------------------------> _Action.KEEP_ALIVE
|
| else
|
+--------------------------> _Action.SYNC
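The decision tree above can be exercised with a few hand-built states. The sketch below is a simplified re-implementation for experimentation, not the PyTorch code: `ToyState` and `decide_join` are hypothetical, actions are plain strings, the last-call check is reduced to a boolean flag, and the keep-alive branch is omitted.

```python
from dataclasses import dataclass, field
from typing import Dict, Set


@dataclass
class ToyState:
    closed: bool = False
    complete: bool = False
    participants: Dict[str, int] = field(default_factory=dict)
    wait_list: Set[str] = field(default_factory=set)
    last_call_expired: bool = False  # stands in for state.deadline < utcnow()


def decide_join(node: str, state: ToyState, now: float, deadline: float,
                min_nodes: int, max_nodes: int, rollback_period: float = 5.0) -> str:
    if state.closed:
        return "ERROR_CLOSED"
    is_participant = node in state.participants
    if state.complete and is_participant:
        return "FINISH"
    if now > deadline:
        if now <= deadline + rollback_period:
            if is_participant:
                return "REMOVE_FROM_PARTICIPANTS"
            if node in state.wait_list:
                return "REMOVE_FROM_WAIT_LIST"
        return "ERROR_TIMEOUT"
    if state.complete:
        if len(state.participants) < max_nodes and node not in state.wait_list:
            return "ADD_TO_WAIT_LIST"
    elif is_participant:
        if len(state.participants) >= min_nodes and state.last_call_expired:
            return "MARK_RENDEZVOUS_COMPLETE"
    else:
        return "ADD_TO_PARTICIPANTS"
    return "SYNC"  # nothing to change yet; poll again


fresh = ToyState()
waiting = ToyState(participants={"a": 0})
ready = ToyState(participants={"a": 0, "b": 0}, last_call_expired=True)
done = ToyState(complete=True, participants={"a": 0, "b": 1})

results = [
    decide_join("a", fresh, now=0, deadline=10, min_nodes=2, max_nodes=3),
    decide_join("a", waiting, now=0, deadline=10, min_nodes=2, max_nodes=3),
    decide_join("a", ready, now=0, deadline=10, min_nodes=2, max_nodes=3),
    decide_join("c", done, now=0, deadline=10, min_nodes=2, max_nodes=3),
    decide_join("a", waiting, now=12, deadline=10, min_nodes=2, max_nodes=3),
    decide_join("a", waiting, now=16, deadline=10, min_nodes=2, max_nodes=3),
]
```

Read in order, the six results trace a node joining, polling while it waits for peers, completing the round, a latecomer queueing for the next round, and the two timeout outcomes (inside and outside the rollback window).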
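The source code contains a state diagram for the etcd backend's Rendezvous, which serves as a rough reference for comparing against the c10d states.
According to that diagram, the etcd backend's join can be divided into 4 stages:
- setup stage: a value is written to a fixed directory. This acts as an exclusive lock; if the write fails, a rendezvous is already in progress.
- join (joinable) stage: if the write succeeds, the join stage begins. When the waiting time ends or the number of nodes taking part in training reaches the maximum, the frozen stage begins.
- frozen (confirm) stage: all nodes must confirm before entering the last stage, final.
- final stage: ranks are assigned, and the instance with RANK 0 becomes the master.
Following that diagram, we expand the c10d flow as follows.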
+
|
|
v
+-----+------+
| |
| closed +---------------> ERROR_CLOSED
| |
+-----+------+
|
|
v
+-----+------+ is_participant
| |
| complete +---------------> FINISH
| |
+-----+------+
| is_participant
|
v +----> REMOVE_FROM_PARTICIPANTS
+-----+-------+ now > deadline +-----------+ now < rollback +-----------+ |
| | | | | | |
| join +----------------> | timeout +---------------------->+ rollback +-----+
| | | | | | |
+-----+-------+ +----+------+ +-----------+ |
| | | in state.wait_list
| | now > rollback |
| now < deadline | +----> REMOVE_FROM_WAIT_LIST
| +----------> ERROR_TIMEOUT
|
| complete && not is_participant && < max && not in state.wait_list
|
+------------------------------------------------------------------> ADD_TO_WAIT_LIST
|
| not complete && is_participant && > min && > deadline
|
+------------------------------------------------------------------> MARK_RENDEZVOUS_COMPLETE
|
| not complete && not is_participant
|
+-----------------------------------------> ADD_TO_PARTICIPANTS
|
| _should_keep_alive
|
+---------------------------> KEEP_ALIVE
|
|
v
SYNC
0x04 Business operations
Inside _DistributedRendezvousOpExecutor.run, a business function is chosen and executed according to the action.
if action == _Action.KEEP_ALIVE:
    self._keep_alive()
elif action == _Action.ADD_TO_PARTICIPANTS:
    self._add_to_participants()
elif action == _Action.ADD_TO_WAIT_LIST:
    self._add_to_wait_list()
elif action == _Action.REMOVE_FROM_PARTICIPANTS:
    self._remove_from_participants()
elif action == _Action.REMOVE_FROM_WAIT_LIST:
    self._remove_from_wait_list()
elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
    self._mark_rendezvous_complete()
elif action == _Action.MARK_RENDEZVOUS_CLOSED:
    self._mark_rendezvous_closed()
Next we look at the logic of these internal functions.
4.1 Adding a participant
On receiving ADD_TO_PARTICIPANTS, _add_to_participants is called. It removes the node from the wait list and adds it to participants. When the participant count reaches min_nodes it starts the last-call deadline, and when it reaches max_nodes it completes the rendezvous immediately.
def _add_to_participants(self) -> None:
    state = self._state
    try:
        state.wait_list.remove(self._node)
    except KeyError:
        pass
    # The ranks of the participants will be set once the rendezvous is
    # complete.
    state.participants[self._node] = 0
    self._keep_alive()
    if len(state.participants) == self._settings.min_nodes:
        state.deadline = datetime.utcnow() + self._settings.timeout.last_call
    if len(state.participants) == self._settings.max_nodes:
        self._mark_rendezvous_complete()
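The interplay of min_nodes, max_nodes, and the last-call timeout in the listing above is easier to see with concrete numbers. The sketch below is a simplified model: `ToyRendezvous` is a hypothetical class, the current time is passed in explicitly, and heartbeats and the wait list are left out.

```python
from datetime import datetime, timedelta
from typing import Dict, Optional


class ToyRendezvous:
    """Hypothetical model of the last-call window (not the PyTorch class)."""

    def __init__(self, min_nodes: int, max_nodes: int, last_call: timedelta) -> None:
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        self.last_call = last_call
        self.participants: Dict[str, int] = {}
        self.deadline: Optional[datetime] = None
        self.complete = False

    def add(self, node: str, now: datetime) -> None:
        self.participants[node] = 0  # ranks are assigned on completion
        if len(self.participants) == self.min_nodes:
            # Enough nodes to train: give late arrivals one last window.
            self.deadline = now + self.last_call
        if len(self.participants) == self.max_nodes:
            # No room left: complete right away, the window is moot.
            self.complete = True
            self.deadline = None


t0 = datetime(2024, 1, 1, 0, 0, 0)
rdzv = ToyRendezvous(min_nodes=2, max_nodes=3, last_call=timedelta(seconds=30))

rdzv.add("a", t0)
deadline_after_one = rdzv.deadline

rdzv.add("b", t0 + timedelta(seconds=1))
deadline_after_two = rdzv.deadline

rdzv.add("c", t0 + timedelta(seconds=2))
complete_after_three = rdzv.complete
deadline_after_three = rdzv.deadline
```

With min 2 and max 3, the second node opens a 30 second window; if the third node had not arrived, the join operator would have completed the rendezvous with two nodes once that window expired.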
4.2 Removing a participant
On receiving REMOVE_FROM_PARTICIPANTS, _remove_from_participants is called to delete this node from participants and last_heartbeats.
def _remove_from_participants(self) -> None:
    state = self._state
    del state.participants[self._node]
    del state.last_heartbeats[self._node]
    if state.complete:
        # If we do not have any participants left, move to the next round.
        if not state.participants:
            state.complete = False
            state.round += 1
    else:
        if len(state.participants) < self._settings.min_nodes:
            state.deadline = None
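The two branches after the deletion handle different situations, which the following sketch makes visible. `RoundState` and `remove_participant` are hypothetical simplifications: nodes are strings, the deadline is a plain float, and the heartbeat bookkeeping is omitted.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class RoundState:
    participants: Dict[str, int] = field(default_factory=dict)
    complete: bool = False
    round: int = 0
    deadline: Optional[float] = None


def remove_participant(state: RoundState, node: str, min_nodes: int) -> None:
    del state.participants[node]
    if state.complete:
        # A completed round stays complete until its last member leaves.
        if not state.participants:
            state.complete = False
            state.round += 1
    elif len(state.participants) < min_nodes:
        # Below the minimum again: cancel the last-call window.
        state.deadline = None


# Case 1: members leave a completed round one by one.
finished = RoundState(participants={"a": 0, "b": 1}, complete=True, round=3)
remove_participant(finished, "a", min_nodes=2)
after_first = (finished.complete, finished.round)
remove_participant(finished, "b", min_nodes=2)
after_last = (finished.complete, finished.round)

# Case 2: a member leaves while the round is still forming.
forming = RoundState(participants={"a": 0, "b": 0}, deadline=100.0)
remove_participant(forming, "a", min_nodes=2)
```

Only the departure of the last participant of a completed round advances the round counter; a departure during formation merely resets the last-call deadline.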
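4.3 Adding to the wait list
On receiving ADD_TO_WAIT_LIST, _add_to_wait_list is called to add the node to wait_list.
def _add_to_wait_list(self) -> None:
    self._state.wait_list.add(self._node)
    self._keep_alive()
4.4 Removing from the wait list
On receiving REMOVE_FROM_WAIT_LIST, _remove_from_wait_list is called to remove the node from wait_list.
def _remove_from_wait_list(self) -> None:
    self._state.wait_list.remove(self._node)
    del self._state.last_heartbeats[self._node]
4.5 Marking completion
On receiving MARK_RENDEZVOUS_COMPLETE, the rendezvous has finished gathering nodes, and each participant is assigned a rank.
Every node sorts the participants with the same algorithm, so the ranks are identical on every node. The listing also shows _mark_rendezvous_closed, which handles MARK_RENDEZVOUS_CLOSED by setting the closed flag.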
def _mark_rendezvous_complete(self) -> None:
    state = self._state
    state.complete = True
    state.deadline = None
    # Assign the ranks.
    for rank, node in enumerate(sorted(state.participants)):
        state.participants[node] = rank

def _mark_rendezvous_closed(self) -> None:
    self._state.closed = True
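Why sorting gives every node the same ranks can be checked with a short sketch. It assumes that node descriptors are totally ordered; here plain (fqdn, pid, local_id) tuples stand in for _NodeDesc, and `assign_ranks` is a hypothetical helper.

```python
from typing import Dict, Tuple

Node = Tuple[str, int, int]  # (fqdn, pid, local_id), a stand-in for _NodeDesc


def assign_ranks(participants: Dict[Node, int]) -> Dict[Node, int]:
    """Rank nodes by their sort order, independent of insertion order."""
    return {node: rank for rank, node in enumerate(sorted(participants))}


# Two nodes may have built the participants dict in different orders.
seen_by_x = {("host-b", 10, 0): 0, ("host-a", 20, 0): 0, ("host-a", 7, 1): 0}
seen_by_y = {("host-a", 7, 1): 0, ("host-b", 10, 0): 0, ("host-a", 20, 0): 0}

ranks_x = assign_ranks(seen_by_x)
ranks_y = assign_ranks(seen_by_y)
```

Both views yield the same mapping, so no extra coordination round is needed to agree on ranks once the participant set itself is agreed.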
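4.6 Heartbeat
On receiving the KEEP_ALIVE action, _keep_alive is called to maintain the heartbeat. _keep_alive is also called from methods such as _add_to_participants. It updates last_heartbeats in the local state; at the next sync, last_heartbeats is written to the key-value store, so the other nodes learn the status of this node. Locally, _sanitize then acts on last_heartbeats, as described earlier.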
def _keep_alive(self) -> None:
    msg = (
        f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
        f"'{self._settings.run_id}'. Pending sync."
    )
    self._record(message=msg)
    self._state.last_heartbeats[self._node] = datetime.utcnow()
The _record method is as follows:
def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
    construct_and_record_rdzv_event(
        name=f"{self.__class__.__name__}.{get_method_name()}",
        run_id=self._settings.run_id,
        message=message,
        node_state=node_state,
        hostname=self._node.fqdn,
        pid=self._node.pid,
        local_id=self._node.local_id,
    )
It writes a log entry by calling the following code.
def record_rdzv_event(event: RdzvEvent) -> None:
    _get_or_create_logger("dynamic_rendezvous").info(event.serialize())

def construct_and_record_rdzv_event(
    run_id: str,
    message: str,
    node_state: NodeState,
    name: str = "",
    hostname: str = "",
    pid: Optional[int] = None,
    master_endpoint: str = "",
    local_id: Optional[int] = None,
    rank: Optional[int] = None,
) -> None:
    # We don't want to perform an extra computation if not needed.
    if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler):
        return
    # Set up parameters.
    if not hostname:
        hostname = socket.getfqdn()
    if not pid:
        pid = os.getpid()
    # Determines which file called this function.
    callstack = inspect.stack()
    filename = "no_file"
    if len(callstack) > 1:
        stack_depth_1 = callstack[1]
        filename = os.path.basename(stack_depth_1.filename)
        if not name:
            name = stack_depth_1.function
    # Delete the callstack variable. If kept, this can mess with python's
    # garbage collector as we are holding on to stack frame information in
    # the inspect module.
    del callstack
    # Set up error trace if this is an exception
    if node_state == NodeState.FAILED:
        error_trace = traceback.format_exc()
    else:
        error_trace = ""
    # Initialize event object
    event = RdzvEvent(
        name=f"{filename}:{name}",
        run_id=run_id,
        message=message,
        hostname=hostname,
        pid=pid,
        node_state=node_state,
        master_endpoint=master_endpoint,
        rank=rank,
        local_id=local_id,
        error_trace=error_trace,
    )
    # Finally, record the event.
    record_rdzv_event(event)
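This completes the analysis of the engine. In the next article we see whether everything can be pulled together from an overall perspective.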