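[Source Code Analysis] PyTorch Distributed Elastic Training (4): Rendezvous Architecture and Logic
0x00 Abstract
In the previous articles we covered the basic modules of PyTorch distributed and walked through several official examples. We now turn to PyTorch elastic training. This is the fourth article in that series and looks at the structure and overall logic of Rendezvous.
The articles in the elastic training series are:
[Source Code Analysis] PyTorch Distributed Elastic Training (1): Overall Approach
[Source Code Analysis] PyTorch Distributed Elastic Training (2): Launch & Single-Node Flow
[Source Code Analysis] PyTorch Distributed Elastic Training (3): Agent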
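0x01 Overall Background
TE (Torch Elastic) is made up of multiple elastic agents built on top of Rendezvous. This is a separation of responsibilities, so let us compare the two.
- Agent focuses on the logic of a specific node.
  - Agent handles the concrete operational logic, such as launching the user program, monitoring how it runs, and notifying Rendezvous when something goes wrong.
  - Agent is a worker manager: it starts and manages the worker processes, forms them into a worker group, monitors their state, catches failed workers, and restarts the worker group when a worker fails or a new worker joins.
  - Agent maintains the WORLD_SIZE and RANK information. Users no longer need to provide these by hand; the Agent handles them automatically.
  - Agent is a background process on a specific node and is an independent entity. An Agent cannot achieve elastic training as a whole on its own, so a mechanism is needed for workers to discover each other and synchronize changes (WORLD_SIZE and RANK can also only be determined once multiple nodes have synchronized). That mechanism is the Rendezvous concept below.
- Rendezvous handles cluster logic and ensures that the nodes reach a strongly consistent consensus on "which nodes are participating in training".
  - Every Agent contains a Rendezvous handler. Together these handlers form a Rendezvous cluster, which in turn forms an Agent cluster.
  - Once Rendezvous completes, a shared key-value store is created that implements the torch.distributed.Store API. The store is shared only by the members that completed the Rendezvous, and Torch Distributed Elastic uses it to exchange control and data information while initializing the job.
  - Rendezvous maintains all information about the current group on every agent. Each agent has a rendezvous; they communicate with each other and jointly maintain one set of information, which is kept in the Store mentioned above.
  - Rendezvous handles cluster-level operations such as adding nodes, removing nodes, and assigning ranks.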
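0x02 Basic Concepts
In the context of Torch Distributed Elastic, the term rendezvous refers to one specific piece of functionality: a distributed synchronization primitive combined with peer discovery.
It can be understood as a distributed governance process. Torch Distributed Elastic uses Rendezvous to gather the participants (nodes) of a training job so that they can agree on the list of participants and each participant's role, and make a consistent collective decision on when training starts or resumes. In other words, through rendezvous the system reaches consensus on the participants, assigns each one a rank and local rank, announces the world size, and so on. Whenever the job scales elastically or a failure occurs, rendezvous is performed again.
Elastic training requires a mechanism for nodes/processes to discover each other. In TorchElastic, rendezvous is that discovery mechanism, or synchronization component. It synchronizes and collects information about each worker, including the node list and the worker roles on each node, so that the Agents can jointly decide when training starts, ends, or resumes.
The figure comes from the PyTorch source code.
Alternatively, the figure in the TE source code shows more clearly that there are three Nodes.
Rendezvous provides the following features.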
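2.1 Barrier
Nodes performing rendezvous all block until the rendezvous is complete, which happens once at least min nodes (for the same job) have joined the barrier. This also implies that the barrier is not necessarily of a fixed size.
After min nodes have arrived, rendezvous does not complete immediately. It waits for a short additional period so that it does not complete "too quickly" and shut out nodes that tried to join only slightly later. If max nodes gather at the barrier, however, the rendezvous completes immediately.
There is also an overall timeout: if min nodes are never reached within the timeout, the rendezvous fails. This is a simple fail-safe that helps release partially allocated job resources and prevents waste.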
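The completion rule described above can be sketched as a small pure function. This is only an illustration: the function name `barrier_is_complete`, its parameters, and the 30-second grace period are assumptions made for the example, not the names or values used inside TE.

```python
from datetime import datetime, timedelta
from typing import Optional


def barrier_is_complete(
    num_joined: int,
    min_nodes: int,
    max_nodes: int,
    now: datetime,
    last_call_deadline: Optional[datetime],
) -> bool:
    """Return True once the (hypothetical) barrier may be released."""
    if num_joined >= max_nodes:
        # The barrier is full, so there is nobody left to wait for.
        return True
    if num_joined >= min_nodes and last_call_deadline is not None:
        # Enough nodes joined; release only after the grace period expires.
        return now >= last_call_deadline
    return False


start = datetime(2021, 1, 1, 12, 0, 0)
grace = start + timedelta(seconds=30)  # assumed grace period after reaching min

below_min = barrier_is_complete(1, 2, 4, start, None)
min_but_early = barrier_is_complete(2, 2, 4, start, grace)
min_and_expired = barrier_is_complete(2, 2, 4, grace, grace)
full = barrier_is_complete(4, 2, 4, start, grace)
```

The overall timeout is deliberately left out of the sketch; it would be a separate check that fails the rendezvous when min is never reached.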
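2.2 Exclusivity
A simple distributed barrier is not enough, because we also need to ensure that only one group of nodes exists at any given time (for a given job). In other words, for the same job, new nodes (those that join late) must not form a second, parallel, independent worker group.
Torch Distributed Elastic ensures that if a group of nodes has already completed a rendezvous (and may already be training), other "late" nodes that try to join are only placed in a waiting state and must wait until the existing rendezvous is torn down.
2.3 Consistency
When a rendezvous completes, all of its members agree on the job membership and on each member's role in it. The role is represented by an integer from 0 up to (but not including) the world size, called the rank.
Note that ranks are not stable: the same node may be assigned a different rank in the next (re-)rendezvous.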
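To see why a rank can change between rounds, here is a minimal sketch. It assumes, purely for illustration, that ranks are handed out by sorting the participants; `assign_ranks` and the node names are hypothetical.

```python
from typing import Dict, Iterable


def assign_ranks(nodes: Iterable[str]) -> Dict[str, int]:
    """Give every node a rank in [0, world_size) based on sorted order."""
    return {node: rank for rank, node in enumerate(sorted(nodes))}


round_1 = assign_ranks(["node-b", "node-c"])
round_2 = assign_ranks(["node-a", "node-b", "node-c"])  # node-a joined late

# node-b keeps training across both rounds, but its rank shifts from 0 to 1.
rank_changed = round_1["node-b"] != round_2["node-b"]
```

Because of this, user code should not persist anything keyed by rank across restarts of the worker group.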
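2.4 Fault-tolerance
Torch Distributed Elastic rendezvous is fault-tolerant during the rendezvous process:
- If a process crashes (or the network fails, etc.) between joining the rendezvous and its completion, a re-rendezvous is triggered automatically and the remaining healthy nodes regroup.
- A node can also fail (or be observed by other nodes to have failed) after the rendezvous completes. This scenario is handled by the Torch Distributed Elastic train_loop, which likewise triggers a re-rendezvous, so the training job as a whole keeps going.
2.5 Shared key-value store
When the rendezvous completes, a shared key-value store is created and returned to the node. The store implements the torch.distributed.Store API (see https://pytorch.org/docs/stable/distributed.html).
The store is shared only by the members of the completed rendezvous. Torch Distributed Elastic uses it to exchange the information needed to initialize the job's control and data planes.
2.6 Waiting workers and rendezvous closing
The Torch Distributed Elastic rendezvous handler provides additional functionality:
- Querying how many workers arrived after the barrier (late arrivals); they will take part in the next rendezvous.
- Setting the rendezvous to closed, which signals all nodes not to take part in the next rendezvous.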
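The first of these features is typically used from a monitoring loop. The sketch below shows the idea with a hypothetical `FakeHandler` that only imitates two methods of the handler interface, `num_nodes_waiting()` and `next_rendezvous()`; the polling schedule is made up for the example.

```python
from typing import List


class FakeHandler:
    """Hypothetical stand-in that imitates two methods of a rendezvous handler."""

    def __init__(self, waiting_per_poll: List[int]) -> None:
        self._waiting = list(waiting_per_poll)
        self.rounds = 0

    def num_nodes_waiting(self) -> int:
        # Report the scripted number of late nodes for this poll.
        return self._waiting.pop(0) if self._waiting else 0

    def next_rendezvous(self) -> int:
        self.rounds += 1
        return self.rounds


def monitor(handler: FakeHandler, polls: int) -> int:
    """Poll ``polls`` times; re-rendezvous whenever late nodes are waiting."""
    handler.next_rendezvous()              # initial rendezvous
    for _ in range(polls):
        if handler.num_nodes_waiting() > 0:
            handler.next_rendezvous()      # admit the late arrivals
    return handler.rounds


rounds = monitor(FakeHandler([0, 0, 2, 0]), polls=4)
```

In this run, late nodes show up only on the third poll, so exactly one re-rendezvous follows the initial one.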
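2.7 DynamicRendezvousHandler
Torch Distributed Elastic provides the DynamicRendezvousHandler class, which implements the rendezvous mechanism described above.
The class requires a backend (RendezvousBackend) to be specified at construction time. Users can implement their own backend or use one of the implementations that ship with PyTorch:
- C10dRendezvousBackend uses a C10d store (TCPStore by default) as the rendezvous backend. Its advantage is that it needs no third-party dependency, such as etcd, to set up a rendezvous.
- EtcdRendezvousBackend is built on etcd (the related classes include EtcdRendezvousHandler and EtcdRendezvousBackend). Its drawback is that an etcd deployment has to be set up.
For example: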
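from torch.distributed import TCPStore
from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import C10dRendezvousBackend
from torch.distributed.elastic.rendezvous.dynamic_rendezvous import DynamicRendezvousHandler

# The original snippet passed only the host; the port and is_master values
# here are illustrative assumptions so that the store can actually start.
store = TCPStore("localhost", 29400, is_master=True)
backend = C10dRendezvousBackend(store, "my_run_id")
rdzv_handler = DynamicRendezvousHandler.from_backend(
    run_id="my_run_id",
    store=store,
    backend=backend,
    min_nodes=2,
    max_nodes=4,
)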
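2.8 Questions & Design
Knowing which features are required, we can think about which internal modules Rendezvous needs in order to provide them.
- A node concept, so that the system can be described.
- A state concept, that is, the state of the nodes.
- An overall static class that maintains the nodes, state, and other information in one place.
- A shared key-value store that holds this information centrally and lets nodes exchange information and reach consensus.
- A dynamic server, or handler, that exposes a set of APIs to the outside.
We follow this line of thought: first the static structure, then the dynamic logic.
0x03 Static Structure
Next we look at the supporting pieces. Note that elastic has its own Rendezvous, which is different from the original Rendezvous in distributed; do not confuse the two. The original distributed Rendezvous is just a simple KV store, whereas the elastic Rendezvous is far more complex.
Let us analyze the supporting pieces of Rendezvous in detail.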
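3.1 Launch parameters
RendezvousParameters holds the parameters needed to construct a RendezvousHandler.
- backend: the name of the backend.
- endpoint: the endpoint, in the form <hostname>[:<port>].
- run_id: the id of the rendezvous.
- min_nodes: the minimum number of nodes in the rendezvous.
- max_nodes: the maximum number of nodes in the rendezvous.
- kwargs: additional parameters for the backend.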
class RendezvousParameters:
    """Holds the parameters to construct a :py:class:`RendezvousHandler`.

    Args:
        backend:
            The name of the backend to use to handle the rendezvous.
        endpoint:
            The endpoint of the rendezvous, usually in form <hostname>[:<port>].
        run_id:
            The id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        **kwargs:
            Additional parameters for the specified backend.
    """

    def __init__(
        self,
        backend: str,
        endpoint: str,
        run_id: str,
        min_nodes: int,
        max_nodes: int,
        **kwargs,
    ):
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")
        if min_nodes < 1:
            raise ValueError(
                f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
            )
        if max_nodes < min_nodes:
            raise ValueError(
                f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
                f"equal to the minimum number of rendezvous nodes ({min_nodes})."
            )
        self.backend = backend
        self.endpoint = endpoint
        self.run_id = run_id
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        self.config = kwargs
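As an illustration of the endpoint format, the following sketch splits an endpoint string into host and port. `parse_endpoint` is a hypothetical simplified helper written for this article, not the parser TE uses; it ignores IPv6 literals, and falling back to "localhost" for an empty endpoint is an assumption.

```python
from typing import Tuple


def parse_endpoint(endpoint: str, default_port: int) -> Tuple[str, int]:
    """Split '<hostname>[:<port>]' into (hostname, port). Hypothetical helper."""
    endpoint = endpoint.strip()
    if not endpoint:
        return "localhost", default_port
    host, sep, port_str = endpoint.rpartition(":")
    if not sep:
        # No colon at all: the whole string is the hostname.
        return endpoint, default_port
    if not port_str.isdigit() or int(port_str) >= 2**16:
        raise ValueError(f"Invalid port in endpoint '{endpoint}'.")
    return host, int(port_str)


with_port = parse_endpoint("node1.example.com:29500", default_port=29400)
without_port = parse_endpoint("node1.example.com", default_port=29400)
```

The default port 29400 matches the value passed by the c10d backend code shown later in this article.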
3.2 Settings
The RendezvousSettings class holds the rendezvous configuration. It can be regarded as static metadata.
- run_id: the id of the rendezvous.
- min_nodes: the minimum number of nodes in the rendezvous.
- max_nodes: the maximum number of nodes in the rendezvous.
- timeout: the timeout configuration.
- keep_alive_interval: the amount of time a node waits between heartbeats.
- keep_alive_max_attempt: the maximum number of failed heartbeat attempts.
@dataclass(repr=False, eq=False, frozen=True)
class RendezvousSettings:
    """Holds the settings of the rendezvous.

    Attributes:
        run_id:
            The run id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        timeout:
            The timeout configuration of the rendezvous.
        keep_alive_interval:
            The amount of time a node waits before sending a heartbeat to keep
            it alive in the rendezvous.
        keep_alive_max_attempt:
            The maximum number of failed heartbeat attempts after which a node
            is considered dead.
    """

    run_id: str
    min_nodes: int
    max_nodes: int
    timeout: RendezvousTimeout
    keep_alive_interval: timedelta
    keep_alive_max_attempt: int
3.3 State
_RendezvousState is the state of the rendezvous. It is dynamic information, and every node maintains a local copy of it.
- round: the current round of the rendezvous.
- complete: a boolean indicating whether the current round of the rendezvous is complete.
- deadline: if set, the time at which the current round is considered complete while it is still waiting for nodes to join.
- closed: a boolean indicating whether the rendezvous is closed.
- participants: a dictionary of the participants and their ranks.
- wait_list: a set of nodes waiting to take part in the next round of the rendezvous.
- last_heartbeats: a dictionary containing each node's last heartbeat time.
class _RendezvousState:
    """Holds the state of a rendezvous.

    Attributes:
        round:
            The current round of the rendezvous.
        complete:
            A boolean value indicating whether the current round of the
            rendezvous is complete.
        deadline:
            The time at which the current round of the rendezvous will be
            considered complete if it is still waiting for nodes to join.
        closed:
            A boolean value indicating whether the rendezvous is closed.
        participants:
            A dictionary of the participants and their corresponding ranks.
        wait_list:
            A set of nodes that are waiting to participate in the next round of
            the rendezvous.
        last_heartbeats:
            A dictionary containing each node's last heartbeat time.
    """

    round: int
    complete: bool
    deadline: Optional[datetime]
    closed: bool
    participants: Dict[_NodeDesc, int]
    wait_list: Set[_NodeDesc]
    last_heartbeats: Dict[_NodeDesc, datetime]

    def __init__(self) -> None:
        self.round = 0
        self.complete = False
        self.deadline = None
        self.closed = False
        self.participants = {}
        self.wait_list = set()
        self.last_heartbeats = {}
3.4 Node
_NodeDesc describes a node in the rendezvous.
@dataclass(eq=True, order=True, frozen=True)
class _NodeDesc:
    """Describes a node in the rendezvous.

    Attributes:
        fqdn:
            The FQDN of the node.
        pid:
            The id of the process in which the rendezvous handler runs.
        local_id:
            A process-wide unique id.
    """

    fqdn: str
    pid: int
    local_id: int

    def __repr__(self) -> str:
        return f"{self.fqdn}_{self.pid}_{self.local_id}"
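The decorator arguments matter here: eq and frozen together make instances hashable, and order makes them sortable. The sketch below uses a stand-in class `NodeDesc` with the same decorator arguments to show why that suits the participants dictionary and the wait_list set; the host names and pids are invented.

```python
from dataclasses import dataclass


@dataclass(eq=True, order=True, frozen=True)
class NodeDesc:
    """Stand-in for _NodeDesc with the same decorator arguments."""

    fqdn: str
    pid: int
    local_id: int

    def __repr__(self) -> str:
        return f"{self.fqdn}_{self.pid}_{self.local_id}"


a = NodeDesc("host-a.example.com", 4242, 0)
b = NodeDesc("host-b.example.com", 1717, 0)

participants = {b: 0, a: 0}   # hashable, so usable as dictionary keys
wait_list = {a, NodeDesc("host-a.example.com", 4242, 0)}  # equal nodes collapse
ordered = sorted(participants)  # compared field by field: fqdn, pid, local_id
```

A deterministic ordering like this is what lets every node derive the same view from the same shared state.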
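3.5 Backend
In PyTorch, the backend concept usually refers to the communication backend the current process uses. The supported communication backends are generally gloo, mpi, and nccl, and nccl is the usual recommendation for GPU training.
In elastic training, DynamicRendezvousHandler requires a backend (RendezvousBackend) to be specified at construction time. Users can implement their own backend or use one of the implementations that ship with PyTorch:
- C10dRendezvousBackend uses a C10d store (TCPStore by default) as the rendezvous backend. Its advantage is that it needs no third-party dependency, such as etcd, to set up a rendezvous.
- EtcdRendezvousBackend is built on etcd (the related classes include EtcdRendezvousHandler and EtcdRendezvousBackend).
Because EtcdRendezvousBackend depends on etcd and requires an etcd cluster to be installed, the c10d backend is recommended for its better ease of use. The rest of this article focuses on the c10d backend.
The C10d backend is built on a TCPStore and synchronizes over TCP. As covered in an earlier article, TCPStore is a TCP-based distributed key-value store (similar to Redis). It has a typical client-server architecture: the server store holds the data, and store clients connect to it over TCP to perform operations such as set() to insert a key-value pair and get() to retrieve one.
So with the c10d backend, a TCPStore master runs on one of the agents. It listens on a port and serves the API, and every Rendezvous synchronization operation is carried out by the agents connecting to this central TCPStore master.
This is illustrated in the figure below, which comes from BobLiu on Zhihu.
3.5.1 Usage
The following code shows how to configure the backend.
from torch.distributed import TCPStore
from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import C10dRendezvousBackend
from torch.distributed.elastic.rendezvous.dynamic_rendezvous import DynamicRendezvousHandler

# The original snippet passed only the host; the port and is_master values
# here are illustrative assumptions so that the store can actually start.
store = TCPStore("localhost", 29400, is_master=True)
backend = C10dRendezvousBackend(store, "my_run_id")  # configures the backend
rdzv_handler = DynamicRendezvousHandler.from_backend(
    run_id="my_run_id",
    store=store,
    backend=backend,
    min_nodes=2,
    max_nodes=4,
)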
3.5.2 Base class
We first look at the backend base class, RendezvousBackend. It is an abstract class whose main job is to set and get the State.
class RendezvousBackend(ABC):
    """Represents a backend that holds the rendezvous state."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Gets the name of the backend."""

    @abstractmethod
    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """Gets the rendezvous state.

        Returns:
            A tuple of the encoded rendezvous state and its fencing token or
            ``None`` if no state is found in the backend.
        """

    @abstractmethod
    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """Sets the rendezvous state.

        The new rendezvous state is set conditionally:
          - If the specified ``token`` matches the fencing token stored in the
            backend, the state will be updated. The new state will be returned
            to the caller along with its fencing token.
          - If the specified ``token`` does not match the fencing token stored
            in the backend, the state won't be updated; instead the existing
            state along with its fencing token will be returned to the caller.
          - If the specified ``token`` is ``None``, the new state will be set
            only if there is no existing state in the backend. Either the new
            state or the existing state along with its fencing token will be
            returned to the caller.

        Args:
            state:
                The encoded rendezvous state.
            token:
                An optional fencing token that was retrieved by a previous call
                to :py:meth:`get_state` or ``set_state()``.

        Returns:
            A tuple of the serialized rendezvous state, its fencing token, and
            a boolean value indicating whether our set attempt succeeded.

        Raises:
            RendezvousConnectionError:
                The connection to the backend has failed.
            RendezvousStateError:
                The rendezvous state is corrupt.
        """
3.5.3 Creation
The following code creates the backend. It first builds a TCP store and then constructs the C10dRendezvousBackend.
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
    """Creates a new :py:class:`C10dRendezvousBackend` from the specified
    parameters.

    +--------------+-----------------------------------------------------------+
    | Parameter    | Description                                               |
    +==============+===========================================================+
    | store_type   | The type of the C10d store. As of today the only          |
    |              | supported type is "tcp" which corresponds to              |
    |              | :py:class:`torch.distributed.TCPStore`. Defaults to "tcp".|
    +--------------+-----------------------------------------------------------+
    | read_timeout | The read timeout, in seconds, for store operations.       |
    |              | Defaults to 60 seconds.                                   |
    +--------------+-----------------------------------------------------------+
    | is_host      | A boolean value indicating whether this backend instance  |
    |              | will host the C10d store. If not specified it will be     |
    |              | inferred heuristically by matching the hostname or the IP |
    |              | address of this machine against the specified rendezvous  |
    |              | endpoint. Defaults to ``None``.                           |
    |              |                                                           |
    |              | Note that this configuration option only applies to       |
    |              | :py:class:`torch.distributed.TCPStore`. In normal         |
    |              | circumstances you can safely skip it; the only time when  |
    |              | it is needed is if its value cannot be correctly          |
    |              | determined (e.g. the rendezvous endpoint has a CNAME as   |
    |              | the hostname or does not match the FQDN of the machine).  |
    +--------------+-----------------------------------------------------------+
    """
    # As of today we only support TCPStore. Other store types do not have the
    # required functionality (e.g. compare_set) yet.
    store_type = params.get("store_type", "tcp").strip().lower()
    if store_type != "tcp":
        raise ValueError("The store type must be 'tcp'. Other store types are not supported yet.")
    store = _create_tcp_store(params)
    return C10dRendezvousBackend(store, params.run_id), store
3.5.3.1 TCPStore
_create_tcp_store creates a TCPStore.
def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=29400)
    cfg_is_host = params.get_as_bool("is_host")  # read the configured value, if any
    # If the user has explicitly specified whether our process should host
    # the store, respect it.
    if cfg_is_host is not None:  # use the configured value
        is_host = cfg_is_host
    # Otherwise try to determine whether we are the host based on our hostname
    # and IP address.
    else:  # otherwise infer whether this machine is the host
        is_host = _matches_machine_hostname(host)
    # The timeout
    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
    if read_timeout <= 0:
        raise ValueError("The read timeout must be a positive integer.")
    # In specific cases we attempt to instantiate the store twice. For details
    # see the explanation in the except clause below.
    for is_server in [is_host, False]:
        try:
            store = TCPStore(  # type: ignore[call-arg]
                host, port, is_master=is_server, timeout=timedelta(seconds=read_timeout)
            )
            if is_server:
                log.info(
                    f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
                )
            break
        except (ValueError, RuntimeError) as exc:
            # If we heuristically inferred the value of is_host as True and our
            # first attempt to instantiate the TCP store has failed, try it one
            # more time with is_host set to False. As an edge case there can be
            # more than one process that is part of the same rendezvous on this
            # machine and only one of them will eventually host the store.
            if not is_server or cfg_is_host is not None:
                raise RendezvousConnectionError(
                    "The connection to the C10d store has failed. See inner exception for details."
                ) from exc
    return store
3.5.3.2 C10dRendezvousBackend
As we can see, the core of C10dRendezvousBackend is a Store that holds the relevant information. The code below has been abridged; it reads and writes the store through set_state and get_state.
class C10dRendezvousBackend(RendezvousBackend):
    """Represents a C10d-backed rendezvous backend.

    Args:
        store:
            The :py:class:`torch.distributed.Store` instance to use to
            communicate with the C10d store.
        run_id:
            The run id of the rendezvous.
    """

    # See the explanation in the __init__ method.
    _NULL_SENTINEL = "Y2FuaW1hZGFt"
    _store: Store
    _key: str

    def __init__(self, store: Store, run_id: str) -> None:
        if not run_id:
            raise ValueError("The run id must be a non-empty string.")
        self._store = store
        self._key = "torch.rendezvous." + run_id
        # The read operation of a store blocks the caller until the specified
        # key becomes available. This behavior makes it tricky to use a store
        # as a regular key-value dictionary.
        #
        # As a workaround we initially set a sentinel value as the rendezvous
        # state. Whenever this value gets returned we treat it as a None.
        self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)

    @property
    def name(self) -> str:
        """See base class."""
        return "c10d"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class."""
        # Read the data from the store.
        base64_state: bytes = self._call_store("get", self._key)
        return self._decode_state(base64_state)

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class."""
        base64_state_str: str = b64encode(state).decode()
        if token:
            # Shortcut if we know for sure that the token is not valid.
            if not isinstance(token, bytes):
                result = self.get_state()
                if result is not None:
                    tmp = *result, False
                    # Python 3.6 does not support tuple unpacking in return
                    # statements.
                    return tmp
                return None
            token = token.decode()
        else:
            token = self._NULL_SENTINEL
        # Write the data into the store.
        base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)
        state_token_pair = self._decode_state(base64_state)
        if state_token_pair is None:
            return None
        new_state, new_token = state_token_pair
        # C10d Store's compare_set method does not offer an easy way to find out
        # whether our write attempt was successful. As a brute-force solution we
        # perform a bitwise comparison of our local state and the remote state.
        return new_state, new_token, new_state == state

    def _call_store(self, store_op: str, *args, **kwargs) -> Any:
        return getattr(self._store, store_op)(*args, **kwargs)

    def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
        if base64_state == self._NULL_SENTINEL.encode():
            return None
        state = b64decode(base64_state)
        return state, base64_state
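The interplay of the sentinel, the base64 token, and compare_set can be tried out without a real store. In the sketch below, `FakeStore` is a hypothetical single-process stand-in whose compare_set follows my reading of the c10d semantics (write only when the current value equals the expected one, always return the resulting value); `decode` is likewise a helper invented for the example.

```python
from base64 import b64decode, b64encode
from typing import Dict, Optional

NULL_SENTINEL = "Y2FuaW1hZGFt"  # same placeholder string as in the class above


class FakeStore:
    """Dict-backed stand-in for a c10d store (hypothetical, single process)."""

    def __init__(self) -> None:
        self._data: Dict[str, bytes] = {}

    def compare_set(self, key: str, expected: str, desired: str) -> bytes:
        current = self._data.get(key)
        if (current is None and expected == "") or current == expected.encode():
            self._data[key] = desired.encode()
        return self._data.get(key, expected.encode())

    def get(self, key: str) -> bytes:
        return self._data[key]  # a real store would block instead of raising


def decode(raw: bytes) -> Optional[bytes]:
    """Map the sentinel back to None, otherwise undo the base64 encoding."""
    return None if raw == NULL_SENTINEL.encode() else b64decode(raw)


store = FakeStore()
key = "torch.rendezvous.my_run_id"
store.compare_set(key, "", NULL_SENTINEL)       # seed the key, as __init__ does
empty = decode(store.get(key))                  # no real state yet

first = b64encode(b"state-v1").decode()
store.compare_set(key, NULL_SENTINEL, first)    # first real write succeeds
# A second writer still holding the sentinel as its token loses the race.
stale = store.compare_set(key, NULL_SENTINEL, b64encode(b"state-v2").decode())
winner = decode(stale)
```

The losing writer gets the winning state back, which is exactly the information set_state needs to report that its attempt failed.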
3.6 StateHolder
3.6.1 _RendezvousStateHolder
This class holds the rendezvous state that is synced with other nodes, but it needs a derived class to provide the actual functionality.
class _RendezvousStateHolder(ABC):
    """Holds the shared rendezvous state synced with other nodes."""

    @property
    @abstractmethod
    def state(self) -> _RendezvousState:
        """Gets the local state."""

    @abstractmethod
    def sync(self) -> Optional[bool]:
        """Reads or writes the latest state.

        Returns:
            A boolean value indicating whether the local state, in case marked
            as dirty, was successfully synced with other nodes.
        """

    @abstractmethod
    def mark_dirty(self) -> None:
        """Marks the local state as dirty."""
3.6.2 _BackendRendezvousStateHolder
_BackendRendezvousStateHolder extends _RendezvousStateHolder. Its sync method calls the internal backend to read from and write to the store.
class _BackendRendezvousStateHolder(_RendezvousStateHolder):
    """Holds the rendezvous state synced with other nodes via a backend.

    Args:
        backend:
            The rendezvous backend to use.
        settings:
            The rendezvous settings.
        cache_duration:
            The amount of time, in seconds, to cache the last rendezvous state
            before requesting it from the backend again.
    """

    _backend: RendezvousBackend
    _state: _RendezvousState
    _settings: RendezvousSettings
    _cache_duration: int
    _token: Token
    _dirty: bool
    _last_sync_time: float
    _dead_nodes: List[_NodeDesc]

    def __init__(
        self, backend: RendezvousBackend, settings: RendezvousSettings, cache_duration: int = 1
    ) -> None:
        self._backend = backend
        self._state = _RendezvousState()
        self._settings = settings
        self._cache_duration = cache_duration
        self._token = None
        self._dirty = False
        self._last_sync_time = -1
        self._dead_nodes = []

    @property
    def state(self) -> _RendezvousState:
        """See base class."""
        return self._state

    def sync(self) -> Optional[bool]:
        """See base class."""
        state_bits: Optional[bytes] = None
        token = None
        has_set: Optional[bool]
        if self._dirty:
            has_set = False
            state_bits = pickle.dumps(self._state)
            # Write the local state to the backend.
            set_response = self._backend.set_state(state_bits, self._token)
            if set_response is not None:
                state_bits, token, has_set = set_response
        else:
            has_set = None
            if self._cache_duration > 0:
                # Avoid overloading the backend if we are asked to retrieve the
                # state repeatedly. Try to serve the cached state.
                if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0):
                    return None
            get_response = self._backend.get_state()
            if get_response is not None:
                state_bits, token = get_response
        if state_bits is not None:
            try:
                self._state = pickle.loads(state_bits)
            except pickle.PickleError as exc:
                raise RendezvousStateError(
                    "The rendezvous state is corrupt. See inner exception for details."
                ) from exc
        else:
            self._state = _RendezvousState()
        if has_set and self._dead_nodes and log.isEnabledFor(logging.DEBUG):
            node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)
            # (The logging of node_list is omitted in this excerpt.)
        self._token = token
        self._dirty = False
        self._last_sync_time = time.monotonic()
        self._sanitize()
        return has_set

    def _sanitize(self) -> None:
        expire_time = datetime.utcnow() - (
            self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
        )
        # Filter out the dead nodes.
        self._dead_nodes = [
            node
            for node, last_heartbeat in self._state.last_heartbeats.items()
            if last_heartbeat < expire_time
        ]
        for dead_node in self._dead_nodes:
            del self._state.last_heartbeats[dead_node]
            try:
                del self._state.participants[dead_node]
            except KeyError:
                pass
            try:
                self._state.wait_list.remove(dead_node)
            except KeyError:
                pass

    def mark_dirty(self) -> None:
        """See base class.

        If the local rendezvous state is dirty, the next sync call will try to
        write the changes back to the backend. However this attempt might fail
        if another node, which had the same state, also made changes and wrote
        them before us.
        """
        self._dirty = True
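The expiry rule in _sanitize is easy to check in isolation. The sketch below reproduces only that rule with a hypothetical helper `find_dead_nodes`; node names, the 5-second interval, and the 3 attempts are example values, and a fixed `now` replaces the wall clock so the result is reproducible.

```python
from datetime import datetime, timedelta
from typing import Dict, List


def find_dead_nodes(
    last_heartbeats: Dict[str, datetime],
    now: datetime,
    keep_alive_interval: timedelta,
    keep_alive_max_attempt: int,
) -> List[str]:
    """Return nodes whose last heartbeat is older than the allowed window."""
    expire_time = now - keep_alive_interval * keep_alive_max_attempt
    return [node for node, beat in last_heartbeats.items() if beat < expire_time]


now = datetime(2021, 1, 1, 12, 0, 0)
heartbeats = {
    "node-a": now - timedelta(seconds=4),   # recent heartbeat
    "node-b": now - timedelta(seconds=20),  # silent for longer than 5 s * 3
}
dead = find_dead_nodes(heartbeats, now, timedelta(seconds=5), 3)
for node in dead:
    del heartbeats[node]  # the real code also drops it from participants/wait_list
```

A node is therefore tolerated for keep_alive_interval * keep_alive_max_attempt of silence before every other node drops it from its local state.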
3.6.3 Usage
How StateHolder is used can be seen in _DistributedRendezvousOpExecutor (the code below is abridged):
- Call _state_holder.sync() to synchronize the state, because the latest state lives in the rendezvous backend.
- Read self._state_holder.state to get the latest state.
- Perform the operation.
- Call _state_holder.mark_dirty() so that the next sync() writes the local changes back for the other nodes to see.
def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None
    while action != _Action.FINISH:
        # Reads or writes the latest rendezvous state shared by all nodes in
        # the rendezvous. Note that our local changes might get overridden
        # by another node if that node synced its changes before us.
        has_set = self._state_holder.sync()  # sync first: the latest state lives in the backend
        self._state = self._state_holder.state  # get the latest state
        ctx = _RendezvousContext(self._node, self._state, self._settings)
        # Determine the next action to take based on the current state of
        # the rendezvous.
        action = state_handler(ctx, deadline)
        # Part of the code is omitted here.
        if action == _Action.SYNC:
            # Delay the execution by one second to avoid overloading the
            # backend if we are asked to poll for state changes.
            _delay(seconds=1)
        else:
            if action == _Action.KEEP_ALIVE:
                self._keep_alive()
            elif action == _Action.ADD_TO_PARTICIPANTS:
                self._add_to_participants()
            elif action == _Action.ADD_TO_WAIT_LIST:
                self._add_to_wait_list()
            elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                self._remove_from_participants()
            elif action == _Action.REMOVE_FROM_WAIT_LIST:
                self._remove_from_wait_list()
            elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                self._mark_rendezvous_complete()
            elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                self._mark_rendezvous_closed()
            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty()  # the next sync() publishes our changes
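The loop above is an optimistic-concurrency pattern: read the shared state, decide, change a local copy, then try to publish and loop again. A minimal sketch of the same shape follows, assuming a hypothetical in-process `VersionedCell` in place of the backend and a `join` function in place of the state handler; none of these names come from TE.

```python
import pickle
from typing import Set, Tuple


class VersionedCell:
    """Hypothetical shared cell with compare-and-set on a version number."""

    def __init__(self) -> None:
        self.blob = pickle.dumps(set())
        self.version = 0

    def read(self) -> Tuple[bytes, int]:
        return self.blob, self.version

    def write(self, blob: bytes, version: int) -> bool:
        if version != self.version:
            return False  # somebody else wrote first; the caller must retry
        self.blob, self.version = blob, self.version + 1
        return True


def join(cell: VersionedCell, node: str, max_attempts: int = 10) -> int:
    """Add ``node`` to the shared participant set; return the iterations used."""
    for attempt in range(1, max_attempts + 1):
        blob, version = cell.read()            # sync: fetch the latest state
        participants: Set[str] = pickle.loads(blob)
        if node in participants:               # decide: nothing left to do
            return attempt
        participants.add(node)                 # act on the local copy
        cell.write(pickle.dumps(participants), version)  # try to publish it
    raise TimeoutError(f"{node} could not join")


cell = VersionedCell()
attempts_a = join(cell, "node-a")
attempts_b = join(cell, "node-b")
final = pickle.loads(cell.read()[0])
```

Each node needs two iterations here: one to publish its change and one to observe, after re-reading, that the shared state now contains it. A failed write simply costs an extra iteration.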
3.7 Summary
The logic so far can be summarized as follows: two _BackendRendezvousStateHolder instances exchange information through the TCPStore.
+
+-------------------------------+ | +-------------------------------+
| _BackendRendezvousStateHolder | | | _BackendRendezvousStateHolder |
| | +-------------------+ | +--------------------+ | |
| _settings +-----------> | RendezvousSettings| | | RendezvousSettings | <----------+ _settings |
| | +-------------------+ | +--------------------+ | |
| | +-------------------+ | +--------------------+ | |
| _state +--------------> | _RendezvousState | | | _RendezvousState | <----------+ _state |
| | | | | | | | |
| | +-------------------+ | +--------------------+ | |
| | | | |
| | +-----------------------+ + +----------------------+ | |
| _backend +------------> | C10dRendezvousBackend | | C10dRendezvousBackend| <-------+ _backend |
| | | | +---------+ | | | |
| | | _store +-----> |TCPStore | <---------+ _store | | |
| | | | | | | | | |
| | +-----------------------+ +---------+ +----------------------+ | |
| | | |
| | ^ + ^ | |
| | | | | | |
| | | | | | |
| sync +----------------------+ | +---------------------+ sync |
| | set_state | set_state | |
+-------------------------------+ + +-------------------------------+
A mobile-friendly image of the same diagram:
0x04 Dynamic Logic
4.1 Entry point
We first look at how Rendezvous is used.
launch_agent starts a LocalElasticAgent and calls its run method. Before calling run, it builds rdzv_handler and sets it on the WorkerSpec.
import torch.distributed.elastic.rendezvous.registry as rdzv_registry

@record
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )
    # Build the rdzv_handler.
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    try:
        # master_addr and master_port come from code omitted in this excerpt.
        spec = WorkerSpec(
            role=config.role,
            local_world_size=config.nproc_per_node,
            entrypoint=entrypoint,
            args=tuple(args),
            rdzv_handler=rdzv_handler,  # the rdzv_handler is set here
            max_restarts=config.max_restarts,
            monitor_interval=config.monitor_interval,
            redirects=config.redirects,
            tee=config.tee,
            master_addr=master_addr,
            master_port=master_port,
        )
        agent = LocalElasticAgent(  # build the agent
            spec=spec, start_method=config.start_method, log_dir=config.log_dir
        )
        result = agent.run()  # start the agent
    except ChildFailedError:
        ...  # subsequent code omitted
The run function eventually calls self._rendezvous(worker_group), and the _rendezvous method calls next_rendezvous() to handle membership changes.
@prof
def _rendezvous(self, worker_group: WorkerGroup) -> None:
    r"""
    Runs rendezvous for the workers specified by worker spec.
    Assigns workers a new global rank and world size.
    Updates the rendezvous store for the worker group.
    """
    spec = worker_group.spec
    store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
    # subsequent code omitted
In this flow, rdzv_registry.get_rendezvous_handler(rdzv_parameters) is where everything starts, so we need to look at get_rendezvous_handler. It returns a RendezvousHandler, which means RendezvousHandler and rendezvous_handler_registry are the real foundations.
from .api import rendezvous_handler_registry as handler_registry

def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    This method is used to obtain a reference to a :py:class:`RendezvousHandler`.
    Custom rendezvous handlers can be registered by

    ::

      from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
      from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler

      def create_my_rdzv(params: RendezvousParameters):
        return MyCustomRdzv(params)

      rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)

      my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters)
    """
    return handler_registry.create_handler(params)
Next we look at RendezvousHandler and rendezvous_handler_registry in turn.
4.2 Base class RendezvousHandler
RendezvousHandler carries out the rendezvous logic. Its main abstract methods are:
- next_rendezvous: the main entry point into the rendezvous barrier. Newly joined nodes wait here until the current rendezvous completes, a timeout occurs, or the current rendezvous is marked closed.
- is_closed: whether the rendezvous has been closed. A closed rendezvous means every later attempt to re-rendezvous will fail.
- num_nodes_waiting: returns the number of nodes currently waiting at the rendezvous barrier, which are not part of the current worker group. Callers should invoke this method periodically to check whether new nodes are waiting to join, and if so call next_rendezvous() to perform a re-rendezvous.
The code of RendezvousHandler is as follows:
class RendezvousHandler(ABC):
    """Main rendezvous interface.

    Note:
        Distributed Torch users normally **do not** need to implement their own
        ``RendezvousHandler``. An implementation based on C10d Store is already
        provided, and is recommended for most users.
    """

    # Gets the name of the rendezvous backend.
    @abstractmethod
    def get_backend(self) -> str:
        """Returns the name of the rendezvous backend."""

    # Main entry point into the rendezvous barrier. Newly joined nodes wait
    # here until the current rendezvous completes, a timeout occurs, or the
    # current rendezvous is marked closed.
    @abstractmethod
    def next_rendezvous(
        self,
    ) -> Tuple[Store, int, int]:
        """Main entry-point into the rendezvous barrier.

        Blocks until the rendezvous is complete and the current process is
        included in the formed worker group, or a timeout occurs, or the
        rendezvous was marked closed.

        Returns:
            A tuple of :py:class:`torch.distributed.Store`, ``rank``, and
            ``world size``.
        """

    # Whether the rendezvous has been closed. If so, every later attempt to
    # re-rendezvous will fail.
    @abstractmethod
    def is_closed(self) -> bool:
        """Checks whether the rendezvous has been closed.

        A closed rendezvous means all future attempts to re-rendezvous within
        same job will fail.
        """

    @abstractmethod
    def set_closed(self):
        """Marks the rendezvous as closed."""

    # Returns the number of nodes waiting at the rendezvous barrier that are
    # not part of the current worker group. Callers should invoke this
    # periodically and call next_rendezvous() to re-rendezvous when new nodes
    # are waiting.
    @abstractmethod
    def num_nodes_waiting(self) -> int:
        """Returns the number of nodes who arrived late at the rendezvous
        barrier, hence were not included in the current worker group.

        Callers should periodically call this method to check whether new
        nodes are waiting to join the job and if so admit them by calling
        :py:meth:`next_rendezvous()` (re-rendezvous).
        """

    @abstractmethod
    def get_run_id(self) -> str:
        """Returns the run id of the rendezvous.

        The run id is a user-defined id that uniquely identifies an instance of
        a distributed application. It typically maps to a job id and is used to
        allow nodes to join the correct distributed application.
        """

    def shutdown(self) -> bool:
        """Closes all resources that were open for the rendezvous.
        """
4.3 Registration
Next we look at rendezvous_handler_registry.
torch/distributed/elastic/rendezvous/api.py contains the following code.
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()
This brings us to RendezvousHandlerRegistry.
4.3.1 RendezvousHandlerRegistry
RendezvousHandlerRegistry is a factory class responsible for creating RendezvousHandler instances.
- register adds the creator for a backend to the internal dictionary.
- create_handler looks up the creator by key and invokes it.
- rendezvous_handler_registry is the global registry.
class RendezvousHandlerRegistry:
    """Represents a registry of :py:class:`RendezvousHandler` backends."""

    _registry: Dict[str, RendezvousHandlerCreator]

    def __init__(self) -> None:
        self._registry = {}

    def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
        """Registers a new rendezvous backend.

        Args:
            backend:
                The name of the backend.
            creator:
                The callback to invoke to construct the
                :py:class:`RendezvousHandler`.
        """
        # Use .get() so that the first registration of a backend does not
        # raise KeyError. This excerpt omits the check that compares
        # current_creator with creator.
        current_creator: Optional[RendezvousHandlerCreator]
        current_creator = self._registry.get(backend)
        self._registry[backend] = creator

    def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
        """Creates a new :py:class:`RendezvousHandler`."""
        creator = self._registry[params.backend]
        handler = creator(params)
        return handler
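The registry is a plain name-to-factory map. The following self-contained sketch shows the pattern with a conflict check on registration; `HandlerRegistry`, the `Creator` alias, the dict-based params, and the string "handlers" are all simplifications invented for the example rather than the TE types.

```python
from typing import Callable, Dict

Creator = Callable[[dict], str]  # hypothetical: params dict -> handler description


class HandlerRegistry:
    def __init__(self) -> None:
        self._registry: Dict[str, Creator] = {}

    def register(self, backend: str, creator: Creator) -> None:
        current = self._registry.get(backend)
        if current is not None and current != creator:
            # Refuse to silently replace a different creator.
            raise ValueError(f"Backend '{backend}' is already registered.")
        self._registry[backend] = creator

    def create_handler(self, params: dict) -> str:
        try:
            creator = self._registry[params["backend"]]
        except KeyError as exc:
            raise ValueError(f"Unknown backend '{params['backend']}'.") from exc
        return creator(params)


def create_static(params: dict) -> str:
    return f"static-handler({params['run_id']})"


registry = HandlerRegistry()
registry.register("static", create_static)
registry.register("static", create_static)  # same creator again: allowed
handler = registry.create_handler({"backend": "static", "run_id": "job-1"})
```

Keeping creation behind a name lookup is what allows the launcher to pick a backend from a command-line string.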
4.3.2 Global registry
The system creates a global registry, the rendezvous_handler_registry we saw earlier.
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()
Several handlers are registered here to supply the creators. rendezvous provides the following implementations: etcd, etcd-v2, c10d, and static.
from .api import rendezvous_handler_registry as handler_registry

def _register_default_handlers() -> None:
    handler_registry.register("etcd", _create_etcd_handler)
    handler_registry.register("etcd-v2", _create_etcd_v2_handler)
    handler_registry.register("c10d", _create_c10d_handler)
    handler_registry.register("static", _create_static_handler)
At runtime the registry looks like this:
rendezvous_handler_registry =
  _registry = {dict: 4}
    'etcd' = {function} <function _create_etcd_handler at 0x7ff657e12d08>
    'etcd-v2' = {function} <function _create_etcd_v2_handler at 0x7ff657e12d90>
    'c10d' = {function} <function _create_c10d_handler at 0x7ff657e12e18>
    'static' = {function} <function _create_static_handler at 0x7ff657b9d2f0>
    __len__ = {int} 4
This means that _create_etcd_handler creates handlers of type etcd, and so on for the others.
4.4 建立
既然有了建立途徑,我們就來看看如何建立。rendezvous
提供瞭如下實現,分別是 etcd
、etcd-v2
、c10d
和 static
,這裡我們以 static
和 c10d
為例進行說明。
4.4.1 靜態 RendezvousHandler
我們使用 _create_static_handler 舉例,看看如何建立 static 型別的 handler。
首先從 _create_static_handler 入手。
4.4.1.1 _create_static_handler
def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
    from . import static_tcp_rendezvous
    return static_tcp_rendezvous.create_rdzv_handler(params)
This takes us to torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py, where create_rdzv_handler builds a StaticTCPRendezvous.
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    endpoint = params.endpoint.strip()
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
    world_size = params.max_nodes
    rank = cast(int, params.config.get("rank"))
    run_id = params.run_id
    if "timeout" in params.config:
        timeout = int(params.config["timeout"])
    else:
        timeout = _default_timeout_seconds
    return StaticTCPRendezvous(
        master_addr, master_port, rank, world_size, run_id, timeout
    )
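To make the parameter handling in create_rdzv_handler concrete, here is a small sketch of the same steps using only the standard library. It is an illustration under assumptions: split_endpoint and static_settings are hypothetical helpers (the real code calls parse_rendezvous_endpoint), IPv6 literals are not handled, and PLACEHOLDER_TIMEOUT_SECONDS is an arbitrary stand-in for _default_timeout_seconds, whose actual value is not shown in this article.

```python
from typing import Tuple

# Assumption: placeholder only, not the real _default_timeout_seconds.
PLACEHOLDER_TIMEOUT_SECONDS = 600


def split_endpoint(endpoint: str, default_port: int) -> Tuple[str, int]:
    """Hypothetical helper: split 'host[:port]' into (host, port)."""
    endpoint = endpoint.strip()
    host, sep, port_text = endpoint.rpartition(":")
    if not sep:
        # No colon at all: the whole string is the host name.
        return endpoint, default_port
    return host, int(port_text)


def static_settings(endpoint: str, max_nodes: int, config: dict) -> dict:
    """Collect the values that the static handler is constructed with."""
    host, port = split_endpoint(endpoint, default_port=-1)
    return {
        "master_addr": host,
        "master_port": port,
        # For a static rendezvous the world size is fixed to max_nodes.
        "world_size": max_nodes,
        # The rank is supplied by the user through the config.
        "rank": int(config["rank"]),
        "timeout": int(config.get("timeout", PLACEHOLDER_TIMEOUT_SECONDS)),
    }


settings = static_settings(" node0:29500 ", max_nodes=4, config={"rank": "2"})
print(settings)
```

Note that nothing here is negotiated between nodes: the world size and the rank are taken directly from the parameters, which is what makes this handler "static".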
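4.4.1.2 The StaticTCPRendezvous subclass
StaticTCPRendezvous extends RendezvousHandler; its definition follows. Its core logic: every agent constructs a TCPStore, with the agent whose group_rank is 0 acting as the listening side (is_master), and the store is then wrapped in a PrefixStore keyed by run_id.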
class StaticTCPRendezvous(RendezvousHandler):
    """
    Static rendezvous that is a wrapper around the TCPStore.
    Creates TCPStore based on the input parameters with the
    listener on the agent with group_rank=0
    """
    def __init__(
        self,
        master_addr: str,
        master_port: int,
        rank: int,
        world_size: int,
        run_id: str,
        timeout: int,
    ):
        self.master_addr = master_addr
        self.master_port = master_port
        self.rank = rank
        self.world_size = world_size
        self.run_id = run_id
        self.timeout = datetime.timedelta(seconds=timeout)
        self._store: Optional[Store] = None
    def get_backend(self) -> str:
        return "static"
    def next_rendezvous(self) -> Tuple[Store, int, int]:
        if not self._store:
            is_master = self.rank == 0
            self._store = TCPStore(
                self.master_addr,
                self.master_port,
                self.world_size,
                is_master,
                self.timeout,
            )
        store = PrefixStore(self.run_id, self._store)
        return store, self.rank, self.world_size
The key function is next_rendezvous, shown here with its logging line:
def next_rendezvous(self) -> Tuple[Store, int, int]:
    log.info("Creating TCPStore as the c10d::Store implementation")
    if not self._store:
        is_master = self.rank == 0
        self._store = TCPStore(
            self.master_addr,
            self.master_port,
            self.world_size,
            is_master,
            self.timeout,
        )
    store = PrefixStore(self.run_id, self._store)
    return store, self.rank, self.world_size
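The reason for wrapping the TCPStore in a PrefixStore is namespacing: keys written by one run should not collide with keys written by another run that shares the same underlying store. The sketch below illustrates the idea with an in-memory dictionary. DictStore and ToyPrefixStore are hypothetical classes, and the "prefix/key" layout is an assumption made for illustration rather than a statement about the real key format.

```python
class DictStore:
    """Toy in-memory key-value store standing in for a TCPStore."""

    def __init__(self) -> None:
        self.data = {}

    def set(self, key: str, value: str) -> None:
        self.data[key] = value

    def get(self, key: str) -> str:
        return self.data[key]


class ToyPrefixStore:
    """Toy wrapper that prepends a fixed prefix to every key."""

    def __init__(self, prefix: str, store: DictStore) -> None:
        self.prefix = prefix
        self.store = store

    def _full_key(self, key: str) -> str:
        # Assumed layout for this sketch: "<prefix>/<key>".
        return f"{self.prefix}/{key}"

    def set(self, key: str, value: str) -> None:
        self.store.set(self._full_key(key), value)

    def get(self, key: str) -> str:
        return self.store.get(self._full_key(key))


shared = DictStore()
run_a = ToyPrefixStore("run-a", shared)
run_b = ToyPrefixStore("run-b", shared)

# Both runs use the same logical key but do not overwrite each other.
run_a.set("rank0_addr", "10.0.0.1")
run_b.set("rank0_addr", "10.0.0.2")

print(run_a.get("rank0_addr"), run_b.get("rank0_addr"))
print(sorted(shared.data))
```

In StaticTCPRendezvous the prefix is the run_id, so workers of one job see only their own keys.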
4.4.2 Dynamic RendezvousHandler
Next we look at how a DynamicRendezvousHandler is built.
4.4.2.1 _create_c10d_handler
_create_c10d_handler returns a DynamicRendezvousHandler.
def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
    from .c10d_rendezvous_backend import create_backend
    backend, store = create_backend(params)
    return create_handler(store, backend, params)
The create_handler function called here is what returns the DynamicRendezvousHandler:
def create_handler(
    store: Store, backend: RendezvousBackend, params: RendezvousParameters
) -> DynamicRendezvousHandler:
    """Creates a new :py:class:`DynamicRendezvousHandler` from the specified
    parameters.
    Args:
        store:
            The C10d store to return as part of the rendezvous.
        backend:
            The backend to use to hold the rendezvous state.
    +-------------------+------------------------------------------------------+
    | Parameter         | Description                                          |
    +===================+======================================================+
    | join_timeout      | The total time, in seconds, within which the         |
    |                   | rendezvous is expected to complete. Defaults to 600  |
    |                   | seconds.                                             |
    +-------------------+------------------------------------------------------+
    | last_call_timeout | An additional wait amount, in seconds, before        |
    |                   | completing the rendezvous once the minimum number of |
    |                   | nodes has been reached. Defaults to 30 seconds.      |
    +-------------------+------------------------------------------------------+
    | close_timeout     | The time, in seconds, within which the rendezvous is |
    |                   | expected to close after a call to                    |
    |                   | :py:meth:`RendezvousHandler.set_closed` or           |
    |                   | :py:meth:`RendezvousHandler.shutdown`. Defaults to   |
    |                   | 30 seconds.                                          |
    +-------------------+------------------------------------------------------+
    """
    timeout = RendezvousTimeout(
        _get_timeout(params, "join"),
        _get_timeout(params, "last_call"),
        _get_timeout(params, "close"),
    )
    return DynamicRendezvousHandler.from_backend(
        params.run_id,
        store,
        backend,
        params.min_nodes,
        params.max_nodes,
        timeout,
    )
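The three timeouts in the table above can be thought of as "user value if given, documented default otherwise". The following sketch shows that resolution step in isolation. It is only an illustration: resolve_timeouts is a hypothetical function, the config is simplified to a plain dict of strings, and only the default values (600, 30, 30 seconds) are taken from the docstring quoted above.

```python
from datetime import timedelta
from typing import Dict

# Defaults as documented in the create_handler docstring (in seconds).
DEFAULT_TIMEOUT_SECONDS = {"join": 600, "last_call": 30, "close": 30}


def resolve_timeouts(config: Dict[str, str]) -> Dict[str, timedelta]:
    """Hypothetical helper: merge user-supplied timeouts with the defaults."""
    resolved = {}
    for name, default_seconds in DEFAULT_TIMEOUT_SECONDS.items():
        raw = config.get(f"{name}_timeout")
        seconds = default_seconds if raw is None else int(raw)
        if seconds <= 0:
            raise ValueError(f"{name}_timeout must be positive, got {seconds}")
        resolved[name] = timedelta(seconds=seconds)
    return resolved


# Only join_timeout is overridden; the other two fall back to the defaults.
timeouts = resolve_timeouts({"join_timeout": "120"})
for name, value in timeouts.items():
    print(name, value.total_seconds())
```

The last_call timeout is the one that implements the "wait a little longer after min nodes have arrived" behaviour described in the Barrier section.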
4.4.2.2 from_backend
from_backend is the method that actually builds a DynamicRendezvousHandler; it acts as a factory.
It creates a RendezvousSettings, a _BackendRendezvousStateHolder, and a node descriptor, and then constructs the DynamicRendezvousHandler from them.
@classmethod
def from_backend(
    cls,
    run_id: str,
    store: Store,
    backend: RendezvousBackend,
    min_nodes: int,
    max_nodes: int,
    timeout: Optional[RendezvousTimeout] = None,
):
    """Creates a new :py:class:`DynamicRendezvousHandler`.
    Args:
        run_id:
            The run id of the rendezvous.
        store:
            The C10d store to return as part of the rendezvous.
        backend:
            The backend to use to hold the rendezvous state.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        timeout:
            The timeout configuration of the rendezvous.
    """
    # We associate each handler instance with a unique node descriptor.
    node = cls._node_desc_generator.generate()
    settings = RendezvousSettings(
        run_id,
        min_nodes,
        max_nodes,
        timeout or RendezvousTimeout(),
        keep_alive_interval=timedelta(seconds=5),
        keep_alive_max_attempt=3,
    )
    state_holder = _BackendRendezvousStateHolder(backend, settings)
    return cls(node, settings, backend.name, store, state_holder)
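The settings above fix keep_alive_interval to 5 seconds and keep_alive_max_attempt to 3. One plausible reading, which this sketch assumes rather than quotes from the source, is that a node is treated as dead once it has been silent for longer than interval × max attempts (15 seconds here). find_dead_nodes and the sample heartbeat table are hypothetical.

```python
from datetime import datetime, timedelta
from typing import Dict, List


def find_dead_nodes(
    last_heartbeats: Dict[str, datetime],
    now: datetime,
    keep_alive_interval: timedelta,
    keep_alive_max_attempt: int,
) -> List[str]:
    """Return nodes whose last heartbeat is older than the allowed window.

    Assumption for this sketch: the window is interval * max_attempt.
    """
    expire_time = now - keep_alive_interval * keep_alive_max_attempt
    return sorted(
        node for node, beat in last_heartbeats.items() if beat <= expire_time
    )


now = datetime(2021, 12, 1, 12, 0, 0)
heartbeats = {
    "node-a": now - timedelta(seconds=4),   # fresh heartbeat
    "node-b": now - timedelta(seconds=20),  # silent for longer than 15 s
    "node-c": now - timedelta(seconds=14),  # late, but still inside the window
}

dead = find_dead_nodes(
    heartbeats,
    now,
    keep_alive_interval=timedelta(seconds=5),
    keep_alive_max_attempt=3,
)
print(dead)
```

Allowing several missed heartbeats before declaring a node dead trades detection speed for robustness against short network hiccups.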
4.4.2.3 DynamicRendezvousHandler
Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler`
class that implements the rendezvous mechanism described above. It is a backend-
agnostic type that expects a particular :py:class:`.RendezvousBackend` instance
to be specified during construction.
Torch distributed users can either implement their own backend type or use one
of the following implementations that come with PyTorch:
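DynamicRendezvousHandler extends RendezvousHandler; its definition follows. Unlike StaticTCPRendezvous, it does not construct a TCPStore itself: the store and the backend are passed in from outside.
Its most important member variables are:
- _BackendRendezvousStateHolder, which coordinates state among the rendezvous participants.
- _DistributedRendezvousOpExecutor, which carries out the actual operations.
- _store, which stores information (distributed).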
class DynamicRendezvousHandler(RendezvousHandler):
    """Represents a handler that sets up a rendezvous among a set of nodes."""
    # Static
    _node_desc_generator = _NodeDescGenerator()
    _this_node: _NodeDesc
    _settings: RendezvousSettings
    _backend_name: str
    _store: Store
    _state_holder: _RendezvousStateHolder
    _op_executor: _RendezvousOpExecutor
    _heartbeat_lock: threading.Lock
    _keep_alive_timer: Optional[_PeriodicTimer]
    # The from_backend classmethod is identical to the listing in 4.4.2.2 and is omitted here.
    def __init__(
        self,
        node: _NodeDesc,
        settings: RendezvousSettings,
        backend_name: str,
        store: Store,
        state_holder: _RendezvousStateHolder,
    ) -> None:
        self._this_node = node
        self._settings = settings
        self._backend_name = backend_name
        self._store = store
        self._state_holder = state_holder
        self._op_executor = _DistributedRendezvousOpExecutor(
            self._this_node, self._state_holder, self._settings
        )
        self._heartbeat_lock = threading.Lock()
        self._keep_alive_timer = None
A DynamicRendezvousHandler can also be constructed directly, as follows.
store = TCPStore("localhost")
backend = C10dRendezvousBackend(store, "my_run_id")
rdzv_handler = DynamicRendezvousHandler.from_backend(
    run_id="my_run_id",
    store=store,
    backend=backend,
    min_nodes=2,
    max_nodes=4
)
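4.4.2.4 next_rendezvous
This call blocks until the number of workers meets the requirement. It is invoked whenever workers are initialized or restarted. When it returns, each worker group has been given a rank as its unique identifier. Its internal logic is:
- First, use _RendezvousExitOp to make this node exit.
- Then, use _RendezvousJoinOp to add this node back.
- Finally, start the heartbeats and return the store, rank, and world size. At this point all participating nodes are in participants.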
def next_rendezvous(self) -> Tuple[Store, int, int]:
    """See base class."""
    self._stop_heartbeats()
    # Delay the execution for a small random amount of time if this is our
    # first run. This will slightly skew the rendezvous attempts across the
    # nodes and reduce the load on the backend.
    if self._state_holder.state.round == 0:
        _delay(seconds=(0, 0.3))
    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()
    deadline = self._get_deadline(self._settings.timeout.join)
    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)
    self._start_heartbeats()
    rank, world_size = self._get_world()
    store = self._get_store()
    return store, rank, world_size  # the rank returned here is the rank of the worker group
4.4.2.5 _get_world
The code above calls _get_world, so let's look at it more closely. rank and world_size are determined dynamically, so they are read from the state. Because participants is synchronized across all nodes, every node sees exactly the same participants.
rank, world_size = self._get_world()
def _get_world(self) -> Tuple[int, int]:
    state = self._state_holder.state
    return state.participants[self._this_node], len(state.participants)
Where does state.participants come from? Ranks are assigned when the rendezvous completes. Every node sorts the participants with the same algorithm, so the rank order is identical on every node, which guarantees that each node receives a rank different from every other node's.
def _mark_rendezvous_complete(self) -> None:
    state = self._state
    state.complete = True
    state.deadline = None
    # Assign the ranks.
    for rank, node in enumerate(sorted(state.participants)):
        state.participants[node] = rank
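The rank assignment above relies on one property: sorting is deterministic, so any node that holds the same participant set derives the same ranks without further communication. The sketch below demonstrates this with plain tuples. The Node alias is a hypothetical stand-in for _NodeDesc, and assign_ranks is an illustrative helper, not a PyTorch function.

```python
from typing import Dict, Iterable, Tuple

# Hypothetical stand-in for _NodeDesc: (host name, process id, local id).
Node = Tuple[str, int, int]


def assign_ranks(participants: Iterable[Node]) -> Dict[Node, int]:
    """Give every participant the index of its position in sorted order."""
    return {node: rank for rank, node in enumerate(sorted(participants))}


# Two nodes hold the same participants but happen to iterate in different order.
view_on_node_1 = [("host-b", 40, 0), ("host-a", 77, 0), ("host-c", 12, 0)]
view_on_node_2 = list(reversed(view_on_node_1))

ranks_1 = assign_ranks(view_on_node_1)
ranks_2 = assign_ranks(view_on_node_2)

# Equivalent of _get_world for "host-b": its own rank plus the world size.
this_node = ("host-b", 40, 0)
rank, world_size = ranks_1[this_node], len(ranks_1)

print(ranks_1 == ranks_2)
print(rank, world_size)
```

Because the result depends only on the set of participants, the only thing that has to be synchronized between nodes is the participant set itself.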
4.5 Fault tolerance
As mentioned earlier, the rendezvous is designed to tolerate failures that occur while it is in progress. The PyTorch documentation states:
Torch Distributed Elastic rendezvous is designed to tolerate node failures during the rendezvous process. Should a process crash (or lose network connectivity, etc), between joining the rendezvous and it being completed, then a re-rendezvous with remaining healthy nodes will happen automatically.
4.5.1 ETCD
This fault-tolerance mechanism is especially visible in EtcdRendezvousHandler.
Its next_rendezvous method calls rendezvous_barrier.
def next_rendezvous(self):
    rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier()
    log.info("Creating EtcdStore as the c10d::Store implementation")
    store = self._rdzv_impl.setup_kv_store(rdzv_version)
    return store, rank, world_size
Inside rendezvous_barrier, retryable exceptions raised by the lower layers are caught and init_phase is called again to run another rendezvous attempt, until the deadline passes. Timeout, closed, and other RendezvousError exceptions are re-raised instead of retried.
def rendezvous_barrier(self):
    """
    Main entry point for next rendezvous.
    This method is blocking until rendezvous succeeds or a timeout occurs.
    Returns:
        ``(rdzv_version, rank, world_size)``
    Raises:
        RendezvousTimeoutError - timeout waiting for rendezvous
        RendezvousClosedError - rendezvous is or was closed while waiting
        RendezvousError - other persistent errors that
            render the rendezvous non-retryable
    """
    self._rendezvous_deadline = time.time() + self._timeout
    while True:
        if time.time() > self._rendezvous_deadline:
            raise RendezvousTimeoutError()
        log.info("Attempting to join next rendezvous")
        try:
            # Dis-own our lease in the previous rendezvous, if exists
            if self._lease_this_rank_stop is not None:
                self._lease_this_rank_stop.set()
            return self.init_phase()
        except EtcdRendezvousRetryImmediately:
            # The type of failure suggests we can retry without delay
            pass
        except EtcdRendezvousRetryableFailure:
            # In case of retryable failure, wait a small delay
            # to avoid spamming etcd
            time.sleep(1)
        except RendezvousTimeoutError:
            log.info("Rendezvous timeout occured in EtcdRendezvousHandler")
            raise
        except RendezvousClosedError:
            log.info(
                f"Rendezvous for run_id={self._run_id} was observed to be closed"
            )
            raise
        except RendezvousError:
            raise
        except Exception as e:
            # In case of a general exception, wait a small delay
            # to avoid spamming etcd
            # FIXME: there are a few things that fall under this like
            # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly.
            log.info("Rendezvous attempt failed, will retry. Reason: " + str(e))
            time.sleep(1)
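Stripped of the etcd specifics, rendezvous_barrier is a retry loop bounded by a deadline, with two kinds of retry: immediate, and after a short pause. The sketch below reproduces only that control flow. All names in it (RetryImmediately, RetryableFailure, BarrierTimeout, barrier_with_retries, flaky_attempt) are hypothetical, and the clock and sleep functions are injectable so the behaviour can be exercised without waiting.

```python
import time
from typing import Callable, Tuple


class RetryImmediately(Exception):
    """The attempt failed, but trying again right away is reasonable."""


class RetryableFailure(Exception):
    """The attempt failed; back off briefly before trying again."""


class BarrierTimeout(Exception):
    """The overall deadline passed before any attempt succeeded."""


def barrier_with_retries(
    attempt: Callable[[], Tuple],
    timeout: float,
    retry_delay: float = 1.0,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> Tuple:
    deadline = clock() + timeout
    while True:
        # The deadline is checked once per iteration, before each attempt.
        if clock() > deadline:
            raise BarrierTimeout()
        try:
            return attempt()
        except RetryImmediately:
            continue
        except RetryableFailure:
            sleep(retry_delay)
        # Any other exception propagates to the caller unchanged.


# Demo: two failures, then a successful (version, rank, world_size) result.
outcomes = iter([RetryImmediately(), RetryableFailure(), ("v1", 0, 2)])


def flaky_attempt() -> Tuple:
    outcome = next(outcomes)
    if isinstance(outcome, Exception):
        raise outcome
    return outcome


slept = []  # records requested pauses instead of actually sleeping
result = barrier_with_retries(flaky_attempt, timeout=5.0, sleep=slept.append)
print(result, slept)
```

Separating "retry now" from "retry later" keeps the loop responsive after benign races while avoiding a busy loop against the backend when it is in a transitional state.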
init_phase starts one round of rendezvous.
def init_phase(self):
    """
    Initially, the rendezvous state is expected to be one of:
    1. empty (non-existent) - in this case we try to create a new one.
    2. joinable - we try to join it.
    3. final - we announce ourselves as waiting, and go into monitoring mode
    Any other state is considered transitional, and will be retried after
    a short delay.
    Returns:
        ``(rdzv_version, rank, world_size)``
    Raises:
        RendezvousClosedError - current rendezvous was/is closed
        EtcdRendezvousRetryableFailure - observed some intermediate
            state, which is best handled by retrying later
    """
    try:
        active_version = self.try_create_rendezvous()  # start a new round of rendezvous
        state = json.loads(active_version.value)
        log.info("New rendezvous state created: " + str(state))
    except etcd.EtcdAlreadyExist:
        active_version, state = self.get_rdzv_state()
        # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound),
        # but this is ok for us - just means we'll restart from beginning.
        log.info("Observed existing rendezvous state: " + str(state))
    if state["status"] == "closed":
        raise RendezvousClosedError()
    if state["status"] == "joinable":
        return self.join_phase(state["version"])
    if state["status"] == "final":
        self.handle_existing_rendezvous(state["version"])
        raise EtcdRendezvousRetryImmediately()
    self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1)
    raise EtcdRendezvousRetryableFailure()
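The second half of init_phase is effectively a dispatch on the observed status. The table-like sketch below restates that decision in isolation. It is an illustration only: Action and decide are hypothetical names, and "setup" is used merely as an example of a status that falls into the transitional branch.

```python
from enum import Enum


class Action(Enum):
    ABORT_CLOSED = "raise RendezvousClosedError"
    JOIN = "enter join_phase"
    ANNOUNCE_AND_RETRY_NOW = "announce as waiting, then retry immediately"
    WATCH_AND_RETRY_LATER = "wait for a state change, then retry after a delay"


def decide(status: str) -> Action:
    """Map an observed rendezvous status to what init_phase does next."""
    if status == "closed":
        return Action.ABORT_CLOSED
    if status == "joinable":
        return Action.JOIN
    if status == "final":
        return Action.ANNOUNCE_AND_RETRY_NOW
    # Anything else is treated as transitional.
    return Action.WATCH_AND_RETRY_LATER


for status in ("closed", "joinable", "final", "setup"):
    print(f"{status:9s} -> {decide(status).value}")
```

Only the joinable branch can return a result; every other branch ends in an exception, which is what hands control back to the retry loop in rendezvous_barrier.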
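4.5.2 DynamicRendezvousHandler
In DynamicRendezvousHandler this is less visible, probably because it was developed after the etcd handler, so some of its functionality is still incomplete and evolving.
This series is mainly based on PyTorch 1.9, where the next_rendezvous code shown earlier has no error handling and exceptions propagate straight to the caller. In the code as of December 2021, error handling has been added, and further improvements are likely to follow.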
def next_rendezvous(self) -> Tuple[Store, int, int]:
    """See base class."""
    msg = (
        f"The node '{self._this_node}' attempts to join the next round of the rendezvous "
        f"'{self._settings.run_id}'."
    )
    self._record(message=msg)
    try:  # error handling was added here
        self._stop_heartbeats()
        # Delay the execution for a small random amount of time if this is our
        # first run. This will slightly skew the rendezvous attempts across the
        # nodes and reduce the load on the backend.
        if self._state_holder.state.round == 0:
            _delay(seconds=(0, 0.3))
        exit_op = _RendezvousExitOp()
        join_op = _RendezvousJoinOp()
        deadline = self._get_deadline(self._settings.timeout.join)
        self._op_executor.run(exit_op, deadline)
        self._op_executor.run(join_op, deadline)
        self._start_heartbeats()
        rank, world_size = self._get_world()
        store = self._get_store()
    except Exception as e:  # the error is recorded and re-raised, but no new rendezvous round is started
        self._record(
            message=f"{type(e).__name__}: {str(e)}",
            node_state=NodeState.FAILED,
        )
        raise
    msg = (
        f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of "
        f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size "
        f"{world_size}."
    )
    self._record(message=msg, rank=rank)
    return store, rank, world_size
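4.6 Summary
The logical relationship between Rendezvous and Agent is summarized below. Every launcher script has such a set of mechanisms, and the mechanisms of the different launcher scripts communicate with one another.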
+-----------------------------+ +------------------------------------------------+
| LocalElasticAgent | | WorkerSpec |
| | | |
| +------------------------+ | | rdzv_handler = {DynamicRendezvousHandler} -------+
| |WorkerGroup | | | | |
| | spec +--------------> | entry = worker_fn | |
| | workers | | | | |
| | store | | | role = {str} 'trainer' | |
| | group_rank | | | | |
| | group_world_size | | +------------------------------------------------+ |
| | | | |
| +------------------------+ | |
| | |
| rdzv_run_id | |
| store | +-----------------------------------------+ |
| | |DynamicRendezvousHandler | |
+-----------------------------+ | | |
| | |
| _settings: RendezvousSettings | <--+
| |
| _store: Store |
| |
| _state_holder: _RendezvousStateHolder |
| |
| _op_executor: _RendezvousOpExecutor |
| |
+-----------------------------------------+
Alternatively, combine it with the static structure shown earlier:
+------------------------+ +----------------------------------------------+ + +------------------------+ +---------------------------------------------+
| LocalElasticAgent | | WorkerSpec | | | LocalElasticAgent | | WorkerSpec |
| | | | | | | | |
| +--------------------+ | | rdzv_handler = {DynamicRendezvousHandler} ----+ | | +--------------------+ | | rdzv_handler = {DynamicRendezvousHandler}-----+
| | WorkerGroup | | | | | | | | WorkerGroup | | | | |
| | spec +----------->+ entry = worker_fn | | | | | spec +----------->+ entry = worker_fn | |
| | workers | | | | | | | | workers | | | | |
| | store | | | role = {str} 'trainer' | | | | | store | | | role = {str} 'trainer' | |
| | group_rank | | | | | | | | group_rank | | | | |
| | group_world_size | | +----------------------------------------------+ | | | | group_world_size | | +---------------------------------------------+ |
| | | | +--------------------------------------------+ | | | | | | +--------------------------------------------+ |
| +--------------------+ | | DynamicRendezvousHandler | | | | +--------------------+ | | DynamicRendezvousHandler | |
| rdzv_run_id | | | | | | rdzv_run_id | | | |
| store | | | | | | store | | | |
+------------------------+ | _settings: RendezvousSettings | | | +------------------------+ | _settings: RendezvousSettings | |
| | <--+ | | +<--+
| _store: Store | | | _store: Store |
| | | | |
+----------+ _state_holder: _RendezvousStateHolder | | | _state_holder: _RendezvousStateHolder +-----+
| | | | | | |
| | _op_executor: _RendezvousOpExecutor | | | _op_executor: _RendezvousOpExecutor | |
| | | | | | |
| +--------------------------------------------+ | +--------------------------------------------+ |
v | |
+---------+---------------------+ | +-------------------------------+ |
| _BackendRendezvousStateHolder | | | _BackendRendezvousStateHolder | |
| | +-------------------+ | +--------------------+ | | |
| _settings +-----------> | RendezvousSettings| | | RendezvousSettings | <----------+ _settings | <-------+
| | +-------------------+ | +--------------------+ | |
| | +-------------------+ | +--------------------+ | |
| _state +--------------> | _RendezvousState | | | _RendezvousState | <----------+ _state |
| | | | | | | | |
| | +-------------------+ | +--------------------+ | |
| | +-----------------------+ + +----------------------+ | |
| _backend +------------> | C10dRendezvousBackend | | C10dRendezvousBackend| <-------+ _backend |
| | | | +---------+ | | | |
| | | _store +-----> |TCPStore | <-------+ _store | | |
| | +---+-------------------+ +---+-----+ +-----------+----------+ | |
| | ^ | ^ | |
| | | | | | |
| | | | | | |
| sync +----------------------+ | +---------------------+ sync |
| | set_state NODE 1 | NODE 2 set_state | |
+-------------------------------+ + +-------------------------------+
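0x05 Summary
So far we have analyzed the static structure and the dynamic logic of Rendezvous, which gives a basic understanding of its mechanism, including the following concepts:
- The node concept, _NodeDesc, which lets the system be described.
- The state concept. _RendezvousState is the state of the rendezvous; it is dynamic information, and every node maintains a local copy.
- The overall holder class _BackendRendezvousStateHolder, which keeps the node, state, backend, and other information together.
- A shared key-value store, such as TCPStore, which can hold the above information centrally and can also be used to exchange information and reach consensus.
- The dynamic server or handler: RendezvousHandler provides a set of APIs for external access.
The next article covers how the internal business logic is implemented, namely the Rendezvous engine.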