[原始碼解析] PyTorch 分散式之彈性訓練(2)---啟動&單節點流程
0x00 摘要
在前面的文章之中,我們已經學習了PyTorch 分散式的基本模組,介紹了官方的幾個例子,我們接下來會介紹PyTorch的彈性訓練,本文是第二篇,重點關注的是如何啟動彈性訓練,並且可以對系統總體架構有所瞭解。
彈性訓練系列文章如下:
[原始碼解析] PyTorch 分散式之彈性訓練(1) --- 總體思路
0x01 重要概念
為了更好的說明(這個說明可能在後面文章也會出現,因為太重要了),我們先總述一下TE 最重要的 Agent 和 Rendezvous 兩個概念。
- Agent :Agent是執行在單節點上的獨立後臺程式,可以認為是 worker manager 或者 process supervisor,其負責啟動worker,監控 worker 執行,捕獲woker異常,通過
rendezvous
實現 worker 間的相互發現(比如把狀態上報到KVStore),成員變動時候基於rendezvous
進行變更同步等等。 - Rendezvous :為了實現彈性訓練,需要有一個節點/程式之間彼此發現的機制。Rendezvous就是這個發現機制或者說同步元件。當系統啟動或者成員變更時候,所有worker會(重新)集合(rendezvous)以建立一個新的程式組。
我們從原始碼中取出示意圖看看,大家先有一個總體概念。
0x02 分散式執行
2.1 方式改變
2.1.1 原有方式
我們知道,PET是從 PyTorch v1.9 合併進來的,因為合併了彈性訓練,所以分散式啟動的方式有了很大的改變。
V1.9 之前是使用 torch/distributed/launch.py 進行啟動,比如:
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
--nnodes=2 --node_rank=0 --master_addr="192.168.1.1"
--master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
and all other arguments of your training script)
此處引數含義是:
nnodes
:是參與訓練的節點數目。nproc_per_node
:每個節點上執行的程式數目。node_rank
:當前節點識別符號。master_addr
和master_port
是 master 監聽的地址和埠。
當執行時,torch.distributed.launch
會設定一些環境變數,包括 world_size
,master_addr
和 master_port
等等。然後在當前機器上建立 nproc_per_node
個程式,這些程式構成了一個本地組。如果一共有 NODE_SIZE
個機器參與訓練,則一共有 NODE_SIZE * TRAINERS_PER_NODE
個程式。如果想啟動一個分散式訓練任務,則需要在所有的機器上執行相關命令。
2.1.2 目前方式
PyTorch 1.9 使用 torch/distributed/run.py 進行啟動。如果依然採用 torch/distributed/launch.py,其實其內部已經透傳給 run.py,具體參見程式碼:
def main(args=None):
logger.warn(
"The module torch.distributed.launch is deprecated "
"and going to be removed in future."
"Migrate to torch.distributed.run"
)
args = parse_args(args)
run(args)
torch.distributed.run
是之前torch.distributed.launch
的一個超集,提供如下新功能:
- 容錯:通過重新啟動所有workers,可以優雅地處理worker故障。
- 自動:Worker 的
RANK
和WORLD_SIZE
是自動分配的。 - 彈性:允許在最小值和最大值(彈性)之間更改節點數。
為了使用彈性訓練,使用者程式碼也需要做一些修改,如果使用者的訓練指令碼已經支援 torch.distributed.launch ,則只需要修改幾處就可以使用torch.distributed.run
:
- 無需手動傳遞RANK , WORLD_SIZE , MASTER_ADDR 和 MASTER_PORT。
- 必須提供
rdzv_backend
和rdzv_endpoint
。對於大多數使用者來說,這其實就是“c10d”(參見“rendezvous“)。其實這就替代了之前的MASTER_ADDR 和 MASTER_PORT。 use_env
引數已被刪除。請從 LOCAL_RANK 環境變數中獲取local_rank (例如,os.environ["LOCAL_RANK"]
)。- 使用者需要確保指令碼中有
load_checkpoint(path)
和save_checkpoint(path)
邏輯,即手動處理Checkpoint。因為當worker失敗時,我們將使用最近的checkpoint來恢復現場,重啟所有worker。
下面是一個訓練指令碼的示例,該指令碼在每個epoch上設定檢查點,因此在失敗時最差也只是會丟失一個epoch的訓練成果。
def main():
args = parse_args(sys.argv[1:])
state = load_checkpoint(args.checkpoint_path)
initialize(state)
# torch.distributed.run ensure that this will work
# by exporting all the env vars needed to initialize the process group
torch.distributed.init_process_group(backend=args.backend)
for i in range(state.epoch, state.total_num_epochs)
for batch in iter(state.dataset)
train(batch, state.model)
state.epoch += 1
save_checkpoint(state)
所以,我們接下來看看在新模式之下,如何分散式啟動。
2.2 部署
部署一般按照如下方式。
- (C10d後端不需要)啟動 rendezvous 後端伺服器,並獲取端點(作為
--rdzv_endpoint
傳遞給啟動程式指令碼) - 單節點多 worker:在主機上啟動 launcher 以啟動代理程式,代理會建立並監視本地工作組。
- 多節點多 worker:在所有節點上使用相同的引數啟動 launcher 參加訓練。
當使用作業/群集管理器時,多節點作業的入口點命令應為 launcher。
2.3 示例
我們首先通過幾個例子來看看如何啟動分散式訓練。
2.3.1 單節點多worker啟動
單節點多worker的啟動方式如下,其實就是Standalone 模式,這是分散式模式的一種特例,具體就是針對單機多 Worker 提供了一些便利設定。
python -m torch.distributed.run
--standalone
--nnodes=1
--nproc_per_node=$NUM_TRAINERS
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
2.3.2 容錯方式啟動
如下是容錯方式啟動,固定數目workers,沒有彈性訓練。 --nproc_per_node=$NUM_TRAINERS 一般是 單節點上GPU 個數。
python -m torch.distributed.run
--nnodes=$NUM_NODES
--nproc_per_node=$NUM_TRAINERS
--rdzv_id=$JOB_ID
--rdzv_backend=c10d
--rdzv_endpoint=$HOST_NODE_ADDR
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
HOST_NODE_ADDR
, 的格式是:
2.3.3 彈性方式啟動
下面是彈性訓練,彈性區間為 (min=1
, max=4
)。通過指定rdzv引數,可以實現多機訓練,具備容錯與彈效能力。
在多臺機器上分別執行以下命令啟動:最小節點數為MIN_SIZE,最大為MAX_SIZE,利用etcd服務實現一致性和資訊同步。
python -m torch.distributed.run
--nnodes=1:4
--nproc_per_node=$NUM_TRAINERS
--rdzv_id=$JOB_ID
--rdzv_backend=c10d
--rdzv_endpoint=$HOST_NODE_ADDR
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
HOST_NODE_ADDR
, 的格式是:
關於 rendezvous backend,有幾點說明:
對於多節點訓練,需要指定:
--rdzv_id
: 一個唯一的 job id,在參與job的所有節點之間共享。--rdzv_backend
:torch.distributed.elastic.rendezvous.RendezvousHandler
的一個實現。 (--rdzv_backend
預設是static模式,不支援容錯和彈性伸縮)--rdzv_endpoint
: rendezvous backend 所執行的 endpoint,通常格式為:host:port
。就是取代了之前的 master address / port 設定。
目前,以下幾種後端可以直接使用,c10d
(推薦), etcd-v2
, and etcd
(legacy) 。為了使用 etcd-v2
或者 etcd
,需要搭建一個 v2
api開啟的 etcd server (即. --enable-v2
)。
0x03 啟動指令碼
既然以上啟動都是用 torch/distributed/run.py,所以我們仔細分析一下這個指令碼,該指令碼提供三個功能:
-
依靠"重啟所有 workers"來處理 worker 失敗;
-
自動分配 worker 的
RANK
andWORLD_SIZE
; -
彈性訓練,即 node 數目允許在minimum和maximum之間改變;
3.1 引數定義
啟動指令碼中,一些引數定義如下:
Node
- 物理例項或容器;對映到與 job manager 所協調的單元。Worker
- 分散式訓練環境中的worker。WorkerGroup
- 執行相同功能的一組worker(例如trainers)。LocalWorkerGroup
- 在同一節點上執行的工作組中的workers子集。- 一個
節點
執行LOCAL_WORLD_SIZE
個workers,這些 workers 組成LocalWorkerGroup
。 - 節點上所有
LocalWorkerGroups
組成WorkerGroups
。
- 一個
RANK
- 工作組中worker的rank,是全域性rank,可以認為是一個全域性GPU資源列表。- Rank是不穩定的,在重啟之間,本地Workers 會被分配到不同的ranks,所以不要在程式碼中對
RANK
和LOCAL_RANK
的穩定性做任何假設和依賴編碼。 - rendezvous完成後,其所有成員將對工作成員資格以及每個人在其中的角色(role)達成共識。此角色(role)使用一個介於 0 ~ world size 之間的整型來表示,被稱之為rank。
- Rank是不穩定的,在重啟之間,本地Workers 會被分配到不同的ranks,所以不要在程式碼中對
LOCAL_RANK
- 本地工作組中,某個worker 的 rank,可以認為是當前節點上的GPU資源列表。GROUP_RANK
- worker group的rank。介於0和“最大節點數”之間的數字。如果每個節點執行一個單一工作組,那GROUP_RANK
就是這個節點的rank。ROLE_RANK
- 對於具有相同角色worker來說,他們之間共享的rank,角色在“WorkerSpec”中被指定。WORLD_SIZE
- 工作組中worker的總數。因為節點會加入/離開,所以WORLD_SIZE
會變化,不能依賴WORLD_SIZE
的穩定性進行編碼。LOCAL_WORLD_SIZE
- 本地工作組的大小,即本地執行的worker數目,等於在torch.distributed.run
執行時候指定的--nproc_per_node
。目前,torch/distributed/run.py 僅支援同構的LOCAL_WORLD_SIZE
。也就是說,假設所有節點執行相同數量的本地工作者(每個角色)。ROLE_WORLD_SIZE
- 具有同樣角色的workers總數,在WorkerSpec
之中被指定。rdzv_id
- 使用者定義的id,用於唯一標識作業的工作組。這個id在每個節點加入特定工作組時候使用。rdzv_backend
-rendezvous 的後端(例如“c10d”)。這通常是一個強一致性的鍵值儲存。rdzv_endpoint
- rendezvous 後端端點;通常以“<host>:<port>
”的形式出現。run_id
: 使用者定義的id,它唯一地標識分散式應用程式的一個例項。它通常對映到作業id並用於允許節點加入正確的分散式應用程式。TORCHELASTIC_RUN_ID
- 與 rendezvousrun_id
相等,即唯一的job id。TORCHELASTIC_RESTART_COUNT
- 迄今為止,工作組重啟的次數。TORCHELASTIC_MAX_RESTARTS
- 配置的最大重啟數目。
3.2 相關函式/變數
為了更好的理解上面的引數,我們選取部分相關函式/變數看看。
world_size,rank
這兩個變數是動態生成的,所以從 state 之中取出。
rank, world_size = self._get_world()
def _get_world(self) -> Tuple[int, int]:
state = self._state_holder.state
return state.participants[self._this_node], len(state.participants)
_pg_group_ranks
該全域性變數儲存了每個 group 的 global rank 到 local rank 對映資訊。
# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
其賦值舉例如下:
# Create the global rank to group rank mapping
_pg_group_ranks[pg] = {
global_rank: group_rank
for group_rank, global_rank in enumerate(ranks)
}
group_rank
我們可以利用 global rank 從 _pg_group_ranks 之中提取對應的 local rank。
def _get_group_rank(group: ProcessGroup, rank):
"""
Helper that gets a given group's local rank in the group from a given global
rank.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
"rank mapping")
if group not in _pg_group_ranks:
raise RuntimeError("The given group does not exist")
try:
group_rank = _pg_group_ranks[group][rank]
except KeyError:
raise RuntimeError(f"The global rank {rank} is not part of the group {group}") from None
return group_rank
global_rank
我們可以利用一個 group 的 local rank 獲取到其 gloabl rank。
def _get_global_rank(group, group_rank):
"""
Helper that gets a given group's global rank from a given local rank in the
group.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
"rank mapping")
group_rank_map = _pg_group_ranks[group]
for rank, grp_rank in group_rank_map.items():
if grp_rank == group_rank:
return rank
raise RuntimeError("The group rank is not part of the group")
group_size
我們可以 _get_group_size 獲取到某一個group 的大小。
def _get_group_size(group):
"""
Helper that gets a given group's world size.
"""
if group is GroupMember.WORLD or group is None:
default_pg = _get_default_group()
return default_pg.size()
if group not in _pg_group_ranks:
raise RuntimeError("The given group does not exist")
return len(_pg_group_ranks[group])
nproc_per_node
這個變數可以得到每個node之上支援多少個程式。
def determine_local_world_size(nproc_per_node: str):
try:
logging.info(f"Using nproc_per_node={nproc_per_node}.")
return int(nproc_per_node)
except ValueError:
if nproc_per_node == "cpu":
num_proc = os.cpu_count()
device_type = "cpu"
elif nproc_per_node == "gpu":
if not torch.cuda.is_available():
raise ValueError("Cuda is not available.")
device_type = "gpu"
num_proc = torch.cuda.device_count()
elif nproc_per_node == "auto":
if torch.cuda.is_available():
num_proc = torch.cuda.device_count()
device_type = "gpu"
else:
num_proc = os.cpu_count()
device_type = "cpu"
else:
raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}")
)
return num_proc
3.3 指令碼入口
指令碼入口主要程式碼如下,可以看到,其呼叫到了 elastic_launch 來完成功能,所以我們下一節就要順藤摸瓜來看看這個函式。
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
def run(args):
if args.standalone: # 有兩種模式:Standalone 模式和分散式模式,這裡要判斷一下
args.rdzv_backend = "c10d"
args.rdzv_endpoint = "localhost:29400"
args.rdzv_id = str(uuid.uuid4())
log.info(
f"\n**************************************\n"
f"Rendezvous info:\n"
f"--rdzv_backend={args.rdzv_backend} "
f"--rdzv_endpoint={args.rdzv_endpoint} "
f"--rdzv_id={args.rdzv_id}\n"
f"**************************************\n"
)
config, cmd, cmd_args = config_from_args(args)
elastic_launch(
config=config,
entrypoint=cmd,
)(*cmd_args)
def main(args=None):
args = parse_args(args)
run(args)
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO, format="[%(levelname)s] %(asctime)s %(module)s: %(message)s"
)
main()
0x04 單體總體流程
我們下面就從 elastic_launch 開始,看看在單節點上如何啟動執行。我們首先給出一個總體示意圖,圖上是兩個節點,每個節點有一個 agent,agent下面是一個 worker group,組下面是4個worker。
4.1 小例子
我們再從原始碼中找一個例子來看看,這裡只是設定了兩個workers。
import uuid
import torch
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
def worker_fn(t1, t2):
return torch.add(t1, t2)
def main():
t1 = torch.rand((3,3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
config = LaunchConfig(
min_nodes=2,
max_nodes=4,
nproc_per_node=1,
run_id=str(uuid.uuid4()),
role="trainer",
rdzv_endpoint="localhost:29400",
rdzv_backend="c10d",
max_restarts=1,
monitor_interval=1,
start_method="spawn",
)
outputs = elastic_launch(config, worker_fn)(t1, t2)
if __name__ == '__main__':
main()
輸出如下,可以看到有兩個 worker 程式 和一個 agent 程式。
{"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 0, "group_rank": 0, "worker_id": "12172", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [0], \"role_rank\": [0], \"role_world_size\": [2]}", "agent_restarts": 0}}
{"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 1, "group_rank": 0, "worker_id": "3276", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [1], \"role_rank\": [1], \"role_world_size\": [2]}", "agent_restarts": 0}}
{"name": "torchelastic.worker.status.SUCCEEDED", "source": "AGENT", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": null, "group_rank": 0, "worker_id": null, "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\"}", "agent_restarts": 0}}
4.2 入口
順著程式碼我們深入挖掘一下。elastic_launch 的作用就是啟動一個 torchelastic agent,然後通過這個 agent來呼叫使用者程式入口,agent 會啟動 worker 進行訓練,並且管理 worker 生命週期。
class elastic_launch:
"""
Launches an torchelastic agent on the container that invoked the entrypoint.
1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/
``entrypoint`` can be a function or a command.
2. The return value is a map of each worker's output mapped
by their respective global rank.
"""
def __init__(
self,
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
):
self._config = config
self._entrypoint = entrypoint
def __call__(self, *args, **kwargs):
return launch_agent(self._config, self._entrypoint, list(args)) # 內部會呼叫使用者程式
4.3 啟動代理
launch_agent 啟動了一個 LocalElasticAgent,呼叫了其 run 方法。
@record
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
args: List[Any],
) -> Dict[int, Any]:
if not config.run_id:
run_id = str(uuid.uuid4().int)
config.run_id = run_id
entrypoint_name = _get_entrypoint_name(entrypoint, args)
rdzv_parameters = RendezvousParameters(
backend=config.rdzv_backend,
endpoint=config.rdzv_endpoint,
run_id=config.run_id,
min_nodes=config.min_nodes,
max_nodes=config.max_nodes,
**config.rdzv_configs,
)
agent = None
rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
master_addr, master_port = _get_addr_and_port(rdzv_parameters)
try:
spec = WorkerSpec( # 1. 得到spec
role=config.role,
local_world_size=config.nproc_per_node,
entrypoint=entrypoint,
args=tuple(args),
rdzv_handler=rdzv_handler, # RendezvousHandler
max_restarts=config.max_restarts,
monitor_interval=config.monitor_interval,
redirects=config.redirects,
tee=config.tee,
master_addr=master_addr,
master_port=master_port,
)
cfg = metrics.MetricsConfig(config.metrics_cfg) if config.metrics_cfg else None
metrics.initialize_metrics(cfg)
agent = LocalElasticAgent( # 2. 構建代理
spec=spec, start_method=config.start_method, log_dir=config.log_dir
)
result = agent.run() # 3. 啟動代理
events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
if result.is_failed():
# ChildFailedError is treated specially by @record
# if the error files for the failed children exist
# @record will copy the first error (root cause)
# to the error file of the launcher process.
raise ChildFailedError(
name=entrypoint_name,
failures=result.failures,
)
else:
return result.return_values
except ChildFailedError:
raise
except Exception:
if agent:
events.record(agent.get_agent_status_event(WorkerState.FAILED))
else:
events.record(_construct_event(config))
raise
finally:
rdzv_handler.shutdown()
這裡有幾個關鍵點:
4.3.1 WorkerSpec
WorkerSpec :這是配置資訊,裡面包含了代理所需要的某些全域性資訊,比如 RendezvousHandler,role,entry(使用者函式)。
spec = {WorkerSpec}
args = {tuple: 2} (tensor, tensor)
fn = {NoneType} None
local_world_size = {int} 1
master_addr = {NoneType} None
master_port = {NoneType} None
max_restarts = {int} 1
monitor_interval = {int} 1
rdzv_handler = {DynamicRendezvousHandler}
redirects = {Std} Std.NONE
role = {str} 'trainer'
tee = {Std} Std.NONE
entry = worker_fn
代理會從這裡提取各種所需資訊。比如_start_workers 會從中獲取 store。
use_agent_store = spec.rdzv_handler.get_backend() == "static"
此時邏輯為:
+--------------------------+ +---------------------------------------------------+
|LocalElasticAgent | | WorkerSpec |
| | | |
| WorkerSpec +--------------> | rdzv_handler = {DynamicRendezvousHandler} --------+
| | | | |
| rdzv_run_id | | entry = worker_fn | |
| | | | |
| store | | role = {str} 'trainer' | |
| | | | |
| | +---------------------------------------------------+ |
| | |
| | |
| | |
| | |
| | +-----------------------------------------+ |
+--------------------------+ |DynamicRendezvousHandler | |
| | |
| | |
| _settings: RendezvousSettings | <---+
| |
| _store: Store |
| |
| _state_holder: _RendezvousStateHolder |
| |
| _op_executor: _RendezvousOpExecutor |
| |
+-----------------------------------------+
4.3.2 WorkerGroup
WorkerGroup 代表了一個工作組。WorkerGroup 作為一個整體來管理多個 workers,進行批量處理。
class WorkerGroup:
"""
Represents the set of ``Worker`` instances for the given ``WorkerSpec``
managed by ``ElasticAgent``. Whether the worker group contains cross
instance workers or not depends on the implementation of the agent.
"""
__slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state"]
def __init__(self, spec: WorkerSpec):
self.spec = spec
self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]
# assigned after rdzv
self.store = None
self.group_rank = None
self.group_world_size = None
self.state = WorkerState.INIT
在SimpleElasticAgent 初始化之中,會建立一個 WorkerGroup。
class SimpleElasticAgent(ElasticAgent):
"""
An ``ElasticAgent`` that manages workers (``WorkerGroup``)
for a single ``WorkerSpec`` (e.g. one particular type of worker role).
"""
def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
self._worker_group = WorkerGroup(spec)
self._remaining_restarts = self._worker_group.spec.max_restarts
self._store = None
self._exit_barrier_timeout = exit_barrier_timeout
self._total_execution_time = 0
具體如下:
+-----------------------------+ +------------------------------------------------+
| LocalElasticAgent | | WorkerSpec |
| | | |
| +------------------------+ | | rdzv_handler = {DynamicRendezvousHandler} -------+
| |WorkerGroup | | | | |
| | spec +--------------> | entry = worker_fn | |
| | workers | | | | |
| | store | | | role = {str} 'trainer' | |
| | group_rank | | | | |
| | group_world_size | | +------------------------------------------------+ |
| | | | |
| +------------------------+ | |
| | |
| rdzv_run_id | |
| store | +-----------------------------------------+ |
| | |DynamicRendezvousHandler | |
+-----------------------------+ | | |
| | |
| _settings: RendezvousSettings | <--+
| |
| _store: Store |
| |
| _state_holder: _RendezvousStateHolder |
| |
| _op_executor: _RendezvousOpExecutor |
| |
+-----------------------------------------+
4.4 代理執行
SimpleElasticAgent 是 LocalElasticAgent 的基類,所以會先執行到WorkerSpec.run 方法這裡,run方法則呼叫了 _invoke_run。
@prof
def run(self, role: str = DEFAULT_ROLE) -> RunResult:
start_time = time.monotonic()
try:
result = self._invoke_run(role) # 呼叫
self._total_execution_time = int(time.monotonic() - start_time)
self._record_metrics(result)
self._record_worker_events(result)
return result
finally:
# record the execution time in case there were any exceptions during run.
self._total_execution_time = int(time.monotonic() - start_time)
self._shutdown()
4.5 代理主迴圈
代理在 invoke_run 之中做如下操作:
- 啟動 _initialize_workers,這裡會使用 _rendezvous 構建一個 rendezvous,然後呼叫 _start_workers 啟動 workers。
- 進入 while True 迴圈,在迴圈之中:
- 通過 _monitor_workers 定期輪訓使用者程式執行情況,得到客戶程式執行結果,然後依據情況作出判斷。
- 如果程式正常結束,則返回。
- 如果程式出錯,則重試,即重啟所有 workers,如果重試次數達到依然有問題,就結束所有workers。
- 如果節點成員關係有變化,比如scale up就會有新的節點在waiting,這時候就重啟所有workers。
- 通過 _monitor_workers 定期輪訓使用者程式執行情況,得到客戶程式執行結果,然後依據情況作出判斷。
def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
# NOTE: currently only works for a single role
spec = self._worker_group.spec
role = spec.role
self._initialize_workers(self._worker_group) # 啟動worker
monitor_interval = spec.monitor_interval
rdzv_handler = spec.rdzv_handler
while True:
assert self._worker_group.state != WorkerState.INIT
# 定期監控
time.sleep(monitor_interval)
# 監控客戶程式執行情況
run_result = self._monitor_workers(self._worker_group) # 得到程式執行結果
state = run_result.state
self._worker_group.state = state
put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
put_metric(f"workers.{role}.{state.name.lower()}", 1)
if state == WorkerState.SUCCEEDED:
# 程式正常結束
self._exit_barrier()
return run_result
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
# 程式出錯
if self._remaining_restarts > 0: # 重試
self._remaining_restarts -= 1
self._restart_workers(self._worker_group)
else:
self._stop_workers(self._worker_group) # 重試次數達到,結束workers
self._worker_group.state = WorkerState.FAILED
self._exit_barrier()
return run_result
elif state == WorkerState.HEALTHY:
# 節點成員關係有變化,比如scale up,就會有新節點waiting
# membership changes do not count as retries
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
# 如果有新的節點在waiting,就重啟所有workers
if num_nodes_waiting > 0:
self._restart_workers(self._worker_group)
else:
raise Exception(f"[{role}] Worker group in {state.name} state")
於是最終邏輯如下:
+----------------------------------------------+
| LocalElasticAgent |
| | +---------------------------------------------------+
| rdzv_run_id | | WorkerSpec |
| | | |
| store +------------------------+ | | rdzv_handler = {DynamicRendezvousHandler} +-------+
| |WorkerGroup | | | | |
| _pcontext | spec +------------> | entry = worker_fn | |
| | workers | | | | |
| | store | | | role = {str} 'trainer' | |
| | group_rank | | | | |
| | group_world_size | | +---------------------------------------------------+ |
| | | | |
| +------------------------+ | |
| +----------------------------------------+ | |
| | _invoke_run | | |
| | | | +-----------------------------------------+ |
| | _initialize_workers +------------------------+ |DynamicRendezvousHandler | |
| | | | | | | |
| | | | | | | |
| | while True: | | | | _settings: RendezvousSettings | <---+
| | _monitor_workers(_worker_group) | | | | |
| | + | | | | _store: Store |
| | | _pcontext.wait | | | | |
| | | | | | | _state_holder: _RendezvousStateHolder |
| +----------------------------------------+ | | | |
| | | | | _op_executor: _RendezvousOpExecutor |
+----------------------------------------------+ | | |
| | +-----------------------------------------+
| |
v v
+-------------------------------------------------+
| +------------+ +------------+ +------------+ |
| |Process | |Process | |Process | |
| | | | | | | |
| | work_fn | | work_fn | | work_fn | |
| | | | | | | |
| +------------+ +------------+ +------------+ |
+-------------------------------------------------+
手機如下:
至此,指令碼如何啟動和單體流程我們分析完畢,下一篇我們來具體分析代理。
0xFF 參考
[PyTorch Elastic原始碼閱讀](