[原始碼解析] 深度學習分散式訓練框架 horovod (16) --- 彈性訓練之Worker生命週期

羅西的思考發表於2021-07-19

[原始碼解析] 深度學習分散式訓練框架 horovod (16) --- 彈性訓練之Worker生命週期

0x00 摘要

Horovod 是Uber於2017年釋出的一個易於使用的高效能的分散式訓練框架,在業界得到了廣泛應用。

本系列將通過原始碼分析來帶領大家瞭解 Horovod。本文是第十六篇,看看 horovod 彈性訓練中 worker 的生命週期。

我們先給出一個邏輯圖,大家先有一個粗略的瞭解,本圖左側是 Driver 部分,右側是一個 Worker。

本系列其他文章連結如下:

[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識

[原始碼解析] 深度學習分散式訓練框架 horovod (2) --- 從使用者角度切入

[原始碼解析] 深度學習分散式訓練框架 horovod (3) --- Horovodrun背後做了什麼

[原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver

[原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架

[原始碼解析] 深度學習分散式訓練框架 horovod (6) --- 後臺執行緒架構

[原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer

[原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark

[原始碼解析] 深度學習分散式訓練框架 horovod (9) --- 啟動 on spark

[原始碼解析] 深度學習分散式訓練框架 horovod (10) --- run on spark

[原始碼解析] 深度學習分散式訓練框架 horovod (11) --- on spark --- GLOO 方案

[原始碼解析] 深度學習分散式訓練框架 horovod (12) --- 彈性訓練總體架構

[原始碼解析] 深度學習分散式訓練框架 horovod (13) --- 彈性訓練之 Driver

[原始碼解析] 深度學習分散式訓練框架 horovod (14) --- 彈性訓練發現節點 & State

[原始碼解析] 深度學習分散式訓練框架 horovod (15) --- 廣播 & 通知

0x01 Worker 是什麼

首先,我們要看看 worker 是什麼。為了可以單獨成文,本章節回憶了很多之前的知識,看過之前文章的同學可以跳過。

1.1 角色

”訓練“ 是通過計算梯度下降的方式利用資料來迭代地優化神經網路引數,最終輸出網路模型的過程。

我們首先要看看彈性訓練中的角色設定。

Horovod 的彈性訓練包含兩個角色,driver 程式和 worker 程式。driver 程式執行在 CPU 節點上,worker 程式可執行在 CPU 或者 GPU 節點上。在 Horovod 中,訓練程式是平等的參與者,每個 worker 程式既負責梯度的分發,也負責具體的梯度計算。

這兩個角色和 Spark 的 Driver -- Executor 依然很類似。Driver 程式就可以認為是 Spark 的 Driver,或者說是 master 節點。Worker 就類似於 Spark 的 Executor。

具體如圖:

                 +------------------------------+
                 |                              |
                 |            Driver            |
                 |                              |
                 |                              |
                 +-----+-------+--------+-------+
                       ^       ^        ^
                       |       |        |
                       |       |        |
         +-------------+       |        +--------------+
         |                     |                       |
         |                     |                       |
         |                     |                       |
         v                     v                       v
+--------+----+        +-------+------+           +----+--------+
|  Worker     |        |  Worker      |           |  Worker     |
|             |        |              |           |             |
|      host1  |        |      host2   |           |     host3   |
+-------------+        +--------------+           +-------------+

1.2 職責

角色的職責如下:

master(主節點)職責:

  • 負責實時探活 worker(工作節點)是否有變化,掉線情況;
  • 負責實時監控 host 是否有變化;
  • 負責分配任務到存活的worker(工作節點);
  • 在有程式失敗導致 AllReduce 呼叫失敗的 情況下,master 通過 blacklist 機制 組織剩下的活著的程式構造一個新的環。
  • 如果有新 host 加入,則生成新的 worker,新 worker 和 舊 worker 一起構造成一個新的環。

worker(工作節點)職責:

  • 負責彙報(其實是被動的,沒有主動機制)當前worker(工作節點)的狀態(就是訓練完成情況);

  • 負責在該worker(工作節點)負責的資料上執行訓練。

1.3 組網機制

Horovod 在單機的多個 GPU 上採用 NCCL 來通訊,在多機之間通過 ring-based AllReduce 演算法進行通訊。

Horovod 的彈性訓練是指多機的彈性訓練。在多機的 ring-based 通訊中的每個 worker 節點有一個左鄰和一個右鄰,每個 worker 只會向它的右鄰居傳送資料,並從左鄰居接受資料。

1.3.1 通訊環

Driver 程式用於幫助 worker 呼叫 gloo 構造 AllReduce 通訊環。

當 Horovod 在呼叫 Gloo 來構造通訊域時,Horovod 需要給 Gloo 建立一個帶有 KVStore 的 RendezvousServer,其中 KVStore 用於儲存 通訊域內每個節點的 host 地址給其在邏輯通訊環分配的序號 rank 等資訊。

構建過程如下:

  • Driver 程式建立帶有 KVStore 的 RendezvousServer,即這個 RendezvousServer 執行在 Horovod 的 driver 程式裡

  • Driver 程式拿到所有 worker 程式節點的 IP 地址和 GPU 卡數資訊後,會將其寫入RendezvousServer 的 KVStore 中。

  • 每個 worker 節點會通過呼叫 gloo 從而 請求 RendezvousServer 獲取自己的鄰居節點資訊(ip,port...),從而構造通訊域。

1.3.2 彈性構建

當有 worker 失敗或者新的 worker 加入訓練時,每個 worker 會停止當前的訓練,記錄當前模型迭代的步數,並嘗試重新初始化 AllReduce 的通訊域

1.3.2.1 Driver 監控

因為 driver 程式一直在監控 worker 的狀態 和 host 節點情況,所以

  • 當 host 變化時候,當驅動程式通過節點發現指令碼發現一個節點被標記為新增或者移除時,它將傳送一個通知到 所有workers,在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被呼叫時,會丟擲一個 HostsUpdateInterrupt 異常。
  • 當有 worker 失敗時,driver 會重新獲取存活的 worker 的host,
1.3.2.2 Driver 重新構建

為了不讓其他 worker 程式退出,Horovod 會捕獲 gloo 丟擲的異常,並將異常傳遞給封裝的 Python API。進而 driver 會重新配置 RendezvousServer,從而讓 worker 節點能重新構造通訊域。所以 Horovod 是可容錯的。

如果有 host 變化,worker 程式可以通過這個 rendezvous 來構造新的通訊域。當新的通訊域構造成功後,rank=0 的 worker 會將自身的模型廣播給其他 worker,然後接著上次停止的迭代步數開始訓練。

組網機制如下:

                         +-------------------------------+
                         | Driver                        |
                         |                               |
                         |   +------------------------+  |
                         |   | RendezvousServer       |  |
                         |   |                        |  |
                         |   |                        |  |
                         |   |   host1, host2, host3  |  |
                         |   +------------------------+  |
                         +-------------------------------+
                                ^       ^        ^
                                |       |        |
                                |       |        |
                  +-------------+       |        +--------------+
                  |                     |                       |
                  |                     |                       |
                  |                     |                       |
                  v                     v                       v
         +--------+----+        +-------+------+           +----+--------+
         |  Worker     |        |  Worker      |           |  Worker     |
+------> |             +------> |              +---------> |             | +------+
|        |      host1  |        |      host2   |           |     host3   |        |
|        +-------------+        +--------------+           +-------------+        |
|                                                                                 |
|                                                                                 |
|                                                                                 v
<--------------------------------------------------------------------------------+

所以,本文就看看 Worker 的一個整體生命流程。

0x02 總體生命流程

在 Driver 的 launch_gloo_elastic 之中,如下程式碼負責啟動 worker。

  • command 就是傳入的可執行命令,比如 python train.py。經過 get_run_command 之後就得到了 env python train.py 之類的樣子,就是加上環境變數,可以執行了。
  • exec_command 類似如下: exec_command = _exec_command_fn(settings),就是基於各種配置來生成可以執行命令環境
run_command = get_run_command(command, server_ip, nics, global_rendezv_port, elastic=True)

create_worker = _create_elastic_worker_fn(exec_command, run_command, env, event)

driver.start(settings.num_proc, create_worker)

res = driver.get_results()

可以清晰看出來三個詳細的過程(因為 訓練過程是在 worker 內部,所以 driver 分析沒有深入此部分):

  • _create_elastic_worker_fn 是配置過程
  • driver.start 是啟動過程
  • driver.get_results 就是得到 & 註冊 執行結果

我們接下來就按照這三個過程詳細分析一下。

0x03 配置過程

配置過程 是由 _create_elastic_worker_fn 完成,就是提供一個在某個環境下執行某個命令的能力

_create_elastic_worker_fn 分為兩部分:

  • _slot_info_to_command_fn 會建立 slot_info_to_command,套路和之前文章中類似,就是把各種 horovod 環境變數和執行命令 run_command 糅合起來,得到一個可以在“某個 host and slot” 之上執行的命令文字
  • 返回 create_worker。
    • create_worker 是利用 exec_command 和 命令文字 構建的函式。
    • exec_command 我們在之前介紹過,就是提供了一種執行命令的能力,或者說是執行環境;
    • 所以 create_worker 就是提供一個在某個環境下執行某個命令的能力;
# 得到一個 可以在“某個 host and slot” 之上執行的命令文字
def _slot_info_to_command_fn(run_command, env):
     def slot_info_to_command(slot_info):
        """
        Given a slot_info, creates a command used by gloo to launch a single job.
        :param slot_info: host and slot to execute the run command on
        """
        env_vars = create_slot_env_vars(slot_info)
        horovod_rendez_env = " ".join(
            [f"{k}={str(v)}" for k, v in env_vars.items()])

        return '{horovod_env} {env} {run_command}' .format(
            horovod_env=horovod_rendez_env,
            env=' '.join(['%s=%s' % (key, quote(value)) for key, value in env.items()
                          if env_util.is_exportable(key)]),
            run_command=run_command)

    return slot_info_to_command

def _create_elastic_worker_fn(exec_command, run_command, env, event):
    get_command_with_env = _slot_info_to_command_fn(run_command, env)

    # 提供一個在某個環境下執行某個命令的能力
    def create_worker(slot_info, events):
        command = get_command_with_env(slot_info)
        events = [event] + (events or [])
        return exec_command(command, slot_info, events)
    return create_worker

所以,最終得到的 create_worker 是:

               command (python train.py)
                          +
                          |
                          |
          get_run_command |
                          |
                          v

               run_command(env python train.py) #得到命令類似 env python train.py
                          +
                          |
                          |
 _slot_info_to_command_fn |
                          |
                          v

             {horovod_env} {env} {run_command} #得到命令類似 horovod_env env python train.py
                          +
                          |
                          |
             exec_command |
                          |
                          |
                          v

create_worker = exec_command({horovod_env} {env} {run_command})#得到在某個環境下執行某個命令的能力

這樣,create_worker 就直接可以執行訓練程式碼了。

0x04 啟動過程

create_worker = _create_elastic_worker_fn 提供了一個在某個環境下執行某個命令的能力,因為 create_worker 方法內部已經包括了執行命令和執行環境,就是說,只要執行create_worker,就可以自動訓練。下面我們就利用這個能力來啟動 worker。

啟動過程基本都是在 ElasticDriver 類 的 start 方法中完成。

4.1 總體邏輯

以下邏輯都執行在 ElasticDriver 之中。

  • 首先,會把 上面生成的 create_worker 賦值給 self._create_worker_fn。
  • 其次,會呼叫 _activate_workers 啟動多個 worker,其中包括:
    • 先使用 wait_for_available_slots 等待 min_np 數目的可用的 hosts。之前分析過此函式,就是 無限迴圈等待,如果 avail_slots >= min_np and avail_hosts >= min_hosts 才會返回。
    • 使用 _update_host_assignments 來得到 slots;
    • 使用 _start_worker_processes 來啟動多個 worker;
def _activate_workers(self, min_np):
    current_hosts = self.wait_for_available_slots(min_np)
    pending_slots = self._update_host_assignments(current_hosts)
    self._worker_registry.reset(self.world_size())
    self._start_worker_processes(pending_slots)

def start(self, np, create_worker_fn):
    self._create_worker_fn = create_worker_fn
    self._activate_workers(np)

下面逐一看看。

4.2 賦值

第一步就是上面生成的 create_worker 賦值給 self._create_worker_fn。

                 command (python train.py)
                            +
                            |
                            |
            get_run_command |
                            |
                            v

                 run_command(env python train.py)
                            +
                            |
                            |
   _slot_info_to_command_fn |
                            |                                          +-----------------------------+
                            v                                          |                             |
                                                                       |                             |
               {horovod_env} {env} {run_command}                       |     ElasticDriver           |
                            +                                          |                             |
                            |                                          |                             |
                            |                                          |                             |
               exec_command |                                          |                             |
                            |                                          |                             |
                            |                                          |                             |
                            v                                       1  |                             |
+---------------------------+------------------------------------+     |                             |
| create_worker = exec_command({horovod_env} {env} {run_command})| <------------+  _create_worker_fn |
+----------------------------------------------------------------+     |                             |
                                                                       |                             |
                                                                       +-----------------------------+

手機如下:

4.3 獲取 host 資訊

接下來要使用 _update_host_assignments 來得到 slots,具體分為兩步:

首先構建 host 和 rank 之間的分配狀況。

其次

4.3.1 更新 host 和 rank

_update_host_assignments 函式中會根據 最新的 host 資訊,重新構建 rendezvous,比如:

self._rendezvous.init(host_assignments_list)

具體邏輯是:

  • 獲取 活躍的slot :active_slots
  • 獲取 host 分配情況;
  • 確保每個 worker 有前驅者,就是可以傳遞狀態,構成環;
  • 呼叫 self._rendezvous.init 重新構造 rendezvous;
  • 分配 rank 和 slot 的關係;
  • 返回 pending_slots,就是分配的slot之中,不在 活躍slot列表 active_slots 中的。不活躍的就是接下來可以啟動 新 worker 的
def _update_host_assignments(self, current_hosts):
    # Determine the slots that are already filled so we do not respawn these processes
    # 獲取 活躍的slot
    active_slots = set([(host, slot_info.local_rank)
                        for host, slots in self._host_assignments.items()
                        for slot_info in slots])

    # Adjust the host assignments to account for added / removed hosts
    host_assignments, host_assignments_list = self._get_host_assignments(current_hosts)

    if len(self._host_assignments) > 0:
        # Ensure that at least one previously active host is still assigned, otherwise there is no
        # way to sync the state to the new workers
        prev_hosts = self._host_assignments.keys()
        next_hosts = host_assignments.keys()
        if not prev_hosts & next_hosts:
            raise RuntimeError('No hosts from previous set remaining, unable to broadcast state.')

    self._host_assignments = host_assignments
    self._world_size = len(host_assignments_list)
    
    self._rendezvous.init(host_assignments_list) # 重新構造 rendezvous

    # Rank assignments map from world rank to slot info
    rank_assignments = {}
    for slot_info in host_assignments_list:
        rank_assignments[slot_info.rank] = slot_info
    self._rank_assignments = rank_assignments

    # Get the newly assigned slots that need to be started
    pending_slots = [slot_info
                     for host, slots in self._host_assignments.items()
                     for slot_info in slots
                     if (host, slot_info.local_rank) not in active_slots]
    return pending_slots

4.3.2 獲取 host 和 rank

其中,host資訊是通過 _get_host_assignments 來完成

def _get_host_assignments(self, current_hosts):
    # Adjust the host assignments to account for added / removed hosts
    host_list = [hosts.HostInfo(host, current_hosts.get_slots(host))
                 for host in current_hosts.host_assignment_order]
    host_assignments_list = hosts.get_host_assignments(host_list, self._min_np, self._max_np)
    host_assignments = defaultdict(list)
    for slot_info in host_assignments_list:
        host_assignments[slot_info.hostname].append(slot_info)
    return host_assignments, host_assignments_list

_get_host_assignments 呼叫get_host_assignments具體完成業務。

get_host_assignments 會依據 host 和 process capacities (slots) 來給 Horovod 之中的程式分配,即給出一個 horovod rank 和 slot 的對應關係。設定了幾個 np,就有幾個 slot。

給出的分配方案類似如下,這樣就知道了哪個rank對應於哪個host上的哪個slot

[
  SlotInfo(hostname='h1', rank=0, local_rank=0, cross_rank=0, size=2, local_size=2, coress_size=1),
	SlotInfo(hostname='h2', rank=1, local_rank=0, cross_rank=0, size=2, local_size=2, coress_size=1),
]

程式碼如下:

def get_host_assignments(hosts, min_np, max_np=None):
    """Assign hosts with process capacities (slots) to ranks in the Horovod process.
    This function will try to allocate as many as possible processes on the same host to leverage local network.

    :param hosts: list of HostInfo objects describing host and slot capacity
    :type hosts: list[HostInfo]
    :param min_np: minimum number of processes to be allocated
    :param max_np: (optional) maximum number of processes to be allocated
    :return: a list of the allocation of process on hosts in a `SlotInfo` object.
    :rtype: list[SlotInfo]
    """
    host_ranks = []
    cross_ranks = collections.defaultdict(dict)
    rank = 0
    # 依據 hosts 資訊構建 rank, local rank, cross rank(hierarchical allreduce所需要)
    for host_info in hosts:
        ranks = []
        for local_rank in range(host_info.slots):
            if rank == max_np:
                break

            ranks.append(rank)
            rank += 1

            cross_ranks_at_local = cross_ranks[local_rank]
            cross_ranks_at_local[host_info.hostname] = len(cross_ranks_at_local)

        host_ranks.append((host_info, ranks))

    world_size = rank

    # 給出一個 horovod rank 和 slot 的對應關係。返回一個alloc_list,每個SlotInfo包括各種rank資訊
    alloc_list = []
    for host_info, ranks in host_ranks:
        local_size = len(ranks)
        for local_rank, rank in enumerate(ranks):
            cross_ranks_at_local = cross_ranks[local_rank]
            cross_rank = cross_ranks_at_local[host_info.hostname]
            cross_size = len(cross_ranks_at_local)

            alloc_list.append(
                SlotInfo(
                    hostname=host_info.hostname,
                    rank=rank,
                    local_rank=local_rank,
                    cross_rank=cross_rank,
                    size=world_size,
                    local_size=local_size,
                    cross_size=cross_size))

    return alloc_list

4.3.3 擴充邏輯

目前擴充邏輯如下,經過 4.2,4.3 兩步之後,ElasticDriver 中的一些變數被賦值了,我們簡化了上圖中的 exec_command 如下(第一步在上圖中):

                                                           From Host discovery
                                                                   +
+------------------------------+                                   |
| ElasticDriver                |                                   |
|                              |                                   |  2
|                              |       wait_for_available_slots    |
|    wait_for_available_slots  |                                   |
|                              |                                   v
|   _update_host_assignments   |                              current_hosts
|                              |                                   +
|          _rendezvous         |                                   |
|                              |       _update_host_assignments    |
|         _host_assignments    |                                   |  3
|                              |          self._rendezvous.init    |
|         _world_size          |                                   |
|                              |                                   |
|         _rank_assignments    |                                   |
|                              |                                   v
|         _create_worker_fn+---------+                        pending_slots
|                              |     |
|                              |     |
+------------------------------+     |
                                     |
                                     |
                                     v
 +-----------------------------------+------------+
 | exec_command({horovod_env} {env} {run_command})|
 +------------------------------------------------+

4.4 啟動

_start_worker_processes 完成了啟動過程,邏輯遞增如下。

  • create_worker_fn 就是使用 之前賦值的 create_worker = exec_command({horovod_env} {env} {run_command}) ;
  • run_worker() 中 執行 create_worker_fn(slot_info, [shutdown_event, host_event]) 就是 執行訓練程式碼;
  • threading.Thread(target=run_worker) 就是在一個 thread 之中執行訓練程式碼;
  • _start_worker_processes 就是在多個 thread 之中執行多份訓練程式碼;
def _start_worker_processes(self, pending_slots):
    for slot_info in pending_slots:
        self._start_worker_process(slot_info)

def _start_worker_process(self, slot_info):
    create_worker_fn = self._create_worker_fn
    shutdown_event = self._shutdown
    host_event = self._host_manager.get_host_event(slot_info.hostname)

    def run_worker():
        res = create_worker_fn(slot_info, [shutdown_event, host_event])
        exit_code, timestamp = res
        self._handle_worker_exit(slot_info, exit_code, timestamp)

    thread = threading.Thread(target=run_worker)
    thread.daemon = True
    thread.start()
    self._results.expect(thread)

啟動之後,邏輯如下,可以看到,經過 4 步之後,啟動了 count(slot_info)這麼多的 Thread,每個 Thread 之中,有一個 _create_worker_fn 執行在一個 Slot 之上:

+------------------------------+
| ElasticDriver                |                           From Host discovery
|                              |                                   +
|    _start_worker_processes   |                                   |
|                              |                                   |
|    wait_for_available_slots  |                                   |  2
|                              |       wait_for_available_slots    |
|   _update_host_assignments   |                                   |
|                              |                                   v
|          _rendezvous         |                              current_hosts
|                              |                                   +
|         _host_assignments    |                                   |
|                              |       _update_host_assignments    |
|         _world_size          |                                   |  3
|                              |          self._rendezvous.init    |
|         _rank_assignments    |                                   |
|                              |                                   |
|         _create_worker_fn+---------+                             |
|                              |     |                             v
|         _worker_registry     |     |                        pending_slots
|                              |     |                             +
|                              |     |                             |
+------------------------------+     |    _start_worker_processes  | 4
                                     |                             |
                                     |                             |
                                     v                             |
 +-----------------------------------+------------+                |
 | exec_command({horovod_env} {env} {run_command})|                |
 +---+--------------------------------------------+                |
     ^                                                             v
     |                               +-----------------+-----------+------------------+
     |                               |                 |                              |
     |                               |                 |                              |
     |                               |                 v                              v
 +-----------------------------------+---+     +-------+---------+         +----------+-+
 |   |       Thread 1                    |     | Thread 2        |         | Thread n   |
 | +----------------------------------+  |     |                 |         |            |
 | | |               run_worker       |  |     |                 |         |            |
 | | |                                |  |     |   run_worker    | ......  | run_worker |
 | | |                                |  |     |                 |         |            |
 | | +----+ create_worker_fn( slot 1 )|  |     |     slot 2      |         |  slot n    |
 | |                                  |  |     |                 |         |            |
 | +----------------------------------+  |     |                 |         |            |
 +---------------------------------------+     +-----------------+         +------------+

手機如下:

0x05 執行過程

執行過程其實在上面已經提到了,就是在 Thread 之中執行 exec_command({horovod_env} {env} {run_command})這就是呼叫使用者程式碼進行訓練

thread = threading.Thread(target=run_worker)
thread.daemon = True
thread.start()
self._results.expect(thread)

這裡要說明的是 self._results = ResultsRecorder() 。具體在其中記錄每一個執行的 Thread。

class ResultsRecorder(object):
    def __init__(self):
        self._error_message = None
        self._worker_results = {}
        self._worker_threads = queue.Queue()

    def expect(self, worker_thread):
        self._worker_threads.put(worker_thread)

於是我們的邏輯變成如下,self._results 裡面記錄了所有的 Threads:

   +------------------------------+
   | ElasticDriver                |                           From Host discovery
   |                              |                                   +
   |                              |                                   |
   |                              |                                   |
   |    _start_worker_processes   |                                   |  2
   |                              |       wait_for_available_slots    |
   |    wait_for_available_slots  |                                   |
   |                              |                                   v
   |   _update_host_assignments   |                              current_hosts
   |                              |                                   +
   |          _rendezvous         |                                   |
   |                              |       _update_host_assignments    |
   |         _host_assignments    |                                   |  3
   |                              |          self._rendezvous.init    |
   |         _world_size          |                                   |
   |                              |                                   |
   |         _rank_assignments    |                                   |
   |                              |                                   v
   |         _create_worker_fn+---------+                        pending_slots
   |                              |     |                             +
+---------+  _results             |     |                             |
|  |                              |     |    _start_worker_processes  | 4
|  +------------------------------+     |                             |
|                                       |                             |
|                                       v                             |
|   +-----------------------------------+------------+                |
|   | exec_command({horovod_env} {env} {run_command})|                |
|   +---+--------------------------------------------+                |
|       ^                                                             v
|       |                               +-----------------+-----------+------------------+
|       |                               |                 |                              |
|       |                               |                 |                              |
|       |                               |                 v                              v
|   +-----------------------------------+---+     +-------+---------+         +----------+-+
|   |   |       Thread 1                    |     | Thread 2        |         | Thread n   |
|   | +----------------------------------+  |     |                 |         |            |
|   | | |               run_worker       |  |     |                 |         |            |
|   | | |                                |  |     |   run_worker    | ......  | run_worker |
|   | | |                                |  |     |                 |         |            |
|   | | +----+ create_worker_fn( slot 1 )|  |     |     slot 2      |         |  slot n    |
|   | |                                  |  |     |                 |         |            |
|   | +----------------------------------+  |     |                 |         |            |
|   +---------------------------------------+     +--------+--------+         +-------+----+
|                  ^                                       ^                          ^
|                  |                                       |                          |
+------------------+---------------------------------------+--------------------------+

手機如下:

0x06 註冊,結果 & 協調

本節主要對應架構圖中的下面三部分。

此部分的邏輯層次在後續介紹的 容錯機制之上。容錯機制是在 Worker 內部,此部分是在 Worker 之上

6.1 Worker 的邏輯層次

worker 是在 訓練指令碼之上的階段,即 Worker 來執行使用 "python train.py" 來執行訓練指令碼。

所以,上圖我們簡化如下:

   +------------------------------+
   | ElasticDriver                |                           From Host discovery
   |                              |                                   +
   |     _start_worker_processes  |                                   |
   |                              |                                   |
   |     wait_for_available_slots |                                   |  2
   |                              |       wait_for_available_slots    |
   |    _update_host_assignments  |                                   |
   |                              |                                   v
   |           _rendezvous        |                              current_hosts
   |                              |                                   +
   |          _host_assignments   |                                   |
   |                              |       _update_host_assignments    |
   |          _world_size         |                                   |  3
   |                              |          self._rendezvous.init    |
   |          _rank_assignments   |                                   |
   |                              |                                   |
   |          _create_worker_fn   |                                   |
   |                              |                                   v
   |          _worker_registry    |                              pending_slots
   |                              |                                   +
+---------+  _results             |                                   |
|  |                              |          _start_worker_processes  | 4
|  +------------------------------+                                   |
|                                                                     v
|                                       +-----------------+-----------+------------------+
|                                       |                 |                              |
|                                       |                 |                              |
|                                       |                 v                              v
|   +-----------------------------------+---+     +-------+---------+       +------------+-+
|   | Thread 1                       slot 1 |     | Thread 2        |       | Thread n     |
|   |                                       |     |                 |       |              |
|   |  +---------------------------------+  |     | slot 2          |       | slot n       |
|   |  | worker 1                        |  |     |  +-----------+  | ......|  +---------+ |
|   |  |                                 |  |     |  |worker 2   |  |       |  |worker n | |
|   |  |          Python train.py        |  |     |  |           |  |       |  |         | |
|   |  |                                 |  |     |  |           |  |       |  |         | |
|   |  +---------------------------------+  |     |  +-----------+  |       |  +---------+ |
|   +---------------------------------------+     +-----------------+       +--------------+
|                  ^                                       ^                          ^
|                  |                                       |                          |
+------------------+---------------------------------------+--------------------------+

手機如下:

6.2 worker 執行階段

於是我們提出了新問題如下:

  • worker 的執行,怎麼才算一個階段?一共有幾個階段(狀態)?
  • Driver 根據什麼特徵來記錄 Worker 的執行結果?

從原始碼中我們可以看到,Worker 有三個狀態如下:

READY = 'READY'
SUCCESS = 'SUCCESS'
FAILURE = 'FAILURE'

所以,Worker 可以分為四個階段,RUNNING 是我自己加上去的,就是執行訓練指令碼這個過程,官方沒有這個狀態,但是我覺得這樣應該更清晰。而 SUCCESS 和 FAILURE 就是指令碼執行成功與否。

          Worker
            +
            |
            |
            |
            v
          READY
            +
            |
            |
            v
   +--------+---------+
   | RUNNINZG         |
   |                  |
   |  Python train.py |
   |                  |
   +--------+---------+
            |
            |
            |
            v
   +--------+--------+
   |                 |
   |                 |
   v                 v

SUCCESS           FAILURE

我們接下來看看執行階段。

6.2.1 進入 C++ 世界

當 Driver 初始化 / resume(比如收到了HostsUpdateInterrupt ) 時候,就會呼叫到 hvd.init。

進入 hvd.init 有幾個呼叫途徑(按照下面1,2,3 順序邏輯進行):

  1. 依靠 WorkerStateRegistry . _barrier : 作用是當所有 worker 完成之後,會進一步處理。有三個途徑會觸發這個barrier:
    • start 一個 worker,worker會 hvd.init,進而呼叫了 gloo in c++,進而 聯絡 rendezvous,rendezvous 通知 driver,進而在 WorkerStateRegistry 設定自己的狀態是 READY,如果達到了 min_np,則會觸發了 _barrier途徑 1);
    • 新發現一個host,從而導致觸發一個 HostsUpdateInterrupt,worker 捕獲這個異常之後,進而會 reset,reset 時候會呼叫 hvd.init,進而和上述一樣,最終觸發_barrier途徑 2);
    • worker失敗,會呼叫會 _handle_worker_exit,進而在 WorkerStateRegistry 設定自己的狀態是 FAILURE,會觸發了 _barrier途徑 3);
  2. _barrier 繼續執行時候,會呼叫構建時候設定的 handler,即 _action 函式,進而呼叫到了 _on_workers_recorded,最終呼叫到了 self._driver.resume()
  3. resume 函式 會self._activate_workers(self._min_np),最終就是重新生成(也許是部分,根據 pending_slots決定)worker

6.2.2 構建 Gloo

前文我們提到過,在 python 世界呼叫 hvd.init 的時候,會進入到 C++世界,這時候如果編譯了 GLOO,就建立了一個 GlooContext。

namespace {

// All the Horovod state that must be stored globally per-process.
HorovodGlobalState horovod_global;

#if HAVE_GLOO
GlooContext gloo_context;
#endif

GlooContext 就得到了一個與 RendezvousServer 通訊的介面。

6.2.2.1 去 Rendezvous 獲取資訊

從 GlooContext::Initialize 可以知道,需要獲取大量配置資訊,其中最重要的就是 rank 等資訊。

這些資訊是在 RendezvousServer 設定儲存的。所以 GlooContext 會去 RendezvousServer 進行 http 互動,下面程式碼還是比較清晰易懂的。

#define HOROVOD_GLOO_GET_RANK_AND_SIZE "rank_and_size"

void GlooContext::Initialize(const std::string& gloo_iface) {
  // Create a device for communication
  // TODO(sihan): Add support for multiple interfaces:
  //  https://github.com/facebookincubator/gloo/issues/190
  attr device_attr;
  device_attr.iface = gloo_iface;

  device_attr.ai_family = AF_UNSPEC;
  auto dev = CreateDevice(device_attr);
  auto timeout = GetTimeoutFromEnv();

  auto host_env = std::getenv(HOROVOD_HOSTNAME);
  std::string hostname = host_env != nullptr ? std::string(host_env) : std::string("localhost");

  int rank = GetIntEnvOrDefault(HOROVOD_RANK, 0);
  int size = GetIntEnvOrDefault(HOROVOD_SIZE, 1);
  int local_rank = GetIntEnvOrDefault(HOROVOD_LOCAL_RANK, 0);
  int local_size = GetIntEnvOrDefault(HOROVOD_LOCAL_SIZE, 1);
  int cross_rank = GetIntEnvOrDefault(HOROVOD_CROSS_RANK, 0);
  int cross_size = GetIntEnvOrDefault(HOROVOD_CROSS_SIZE, 1);

  auto rendezvous_addr_env = std::getenv(HOROVOD_GLOO_RENDEZVOUS_ADDR);
  auto rendezvous_port = GetIntEnvOrDefault(HOROVOD_GLOO_RENDEZVOUS_PORT, -1);

  bool elastic = GetBoolEnvOrDefault(HOROVOD_ELASTIC, false);
  if (elastic && reset_) {
    std::string server_addr = rendezvous_addr_env;
    std::string scope = HOROVOD_GLOO_GET_RANK_AND_SIZE;
    HTTPStore init_store(server_addr, rendezvous_port, scope, rank);

    auto key = hostname + ":" + std::to_string(local_rank);
    std::vector<char> result = init_store.get(key);
    std::string s(result.begin(), result.end());
    std::stringstream ss(s);

    int last_rank = rank;
    int last_size = size;
    int last_local_rank = local_rank;
    int last_local_size = local_size;
    int last_cross_rank = cross_rank;
    int last_cross_size = cross_size;

    rank = ParseNextInt(ss);

    size = ParseNextInt(ss);
    local_rank = ParseNextInt(ss);
    local_size = ParseNextInt(ss);
    cross_rank = ParseNextInt(ss);
    cross_size = ParseNextInt(ss);

    SetEnv(HOROVOD_RANK, std::to_string(rank).c_str());
    SetEnv(HOROVOD_SIZE, std::to_string(size).c_str());
    SetEnv(HOROVOD_LOCAL_RANK, std::to_string(local_rank).c_str());
    SetEnv(HOROVOD_LOCAL_SIZE, std::to_string(local_size).c_str());
    SetEnv(HOROVOD_CROSS_RANK, std::to_string(cross_rank).c_str());
    SetEnv(HOROVOD_CROSS_SIZE, std::to_string(cross_size).c_str());
  }

  ctx = Rendezvous(HOROVOD_GLOO_GLOBAL_PREFIX,
                   rendezvous_addr_env, rendezvous_port,
                   rank, size, dev, timeout);
  local_ctx = Rendezvous(HOROVOD_GLOO_LOCAL_PREFIX + hostname,
                         rendezvous_addr_env, rendezvous_port,
                         local_rank, local_size, dev, timeout);
  cross_ctx = Rendezvous(HOROVOD_GLOO_CROSS_PREFIX + std::to_string(local_rank),
                         rendezvous_addr_env, rendezvous_port,
                         cross_rank, cross_size, dev, timeout);
}
6.2.2.2 RendezvousServer

RendezvousServer 對外提供了 GET 方法。

# GET methods
GET_RANK_AND_SIZE = 'rank_and_size'

def _get_value(self, scope, key):
    if scope == GET_RANK_AND_SIZE:
        host, local_rank = key.split(':')
        return self._get_rank_and_size(host, int(local_rank))

    return super(RendezvousHandler, self)._get_value(scope, key)

ElasticRendezvousHandler 是 RendezvousServer 的響應handler,其中 ElasticRendezvousHandler._get_rank_and_size 函式是:

def _get_rank_and_size(self, host, local_rank):
    driver.record_ready(host, local_rank)
    slot_info = driver.get_slot_info(host, local_rank)
    return slot_info.to_response_string().encode('ascii')

這裡會呼叫到 driver.record_ready,就是通知 Driver,現在有一個 worker 是 READY 狀態。

def record_ready(self, host, slot):
    self._worker_registry.record_ready(host, slot)

6.2.3 進入READY 狀態

當呼叫 hvd.init -----> GlooContext 建立 之後,會與RendezvousServer 通訊,這之後 Worker 就進入到 READY 狀態。

我們需要繼續深化下,看看一個 worker 從開始執行到 READY 狀態 之間都發生了什麼。

  1. Worker 開始呼叫 python train.py;

  2. 在 train.py 之中,呼叫 hvd.init(),此方法會深入到 C++ 世界,從而生成了 GlooContext;

  3. GlooContext 之中,會從環境變數之中得到 Rendezvous Server 的ip, port,進而呼叫 init_store 生成一個 HTTPStore;

  4. 呼叫 init_store.get(hostname + ":" + std::to_string(local_rank)) 向 Rendezvous Server 傳送請求,要求獲得本worker 的 rank 對應的 各種配置(local_rank, cross_rank...,因為 Rendezvous Server 可能會重新初始化從而重新分配);

  5. ElasticRendezvousHandler 是 響應函式,其中會 呼叫 driver.record_ready(host, local_rank) 從而在 WorkerStateRegistry 的 READY 字典中記錄下來,worker 2 已經是 READY 了。

  6. 會呼叫 driver.get_slot_info(host, local_rank) 從 driver 獲得 slot info;

  7. 此時,Worker 的狀態就是 READY(其實 Worker 本身沒有這個狀態,只是 WorkerStateRegistry 有這個狀態);

  8. ElasticRendezvousHandler 會返回 slot info 到 worker 的 C++ 世界;

  9. 在 worker 的 C++ 世界 之中繼續執行,把 slot info 返回給 GlooContext,進行各種設定;

具體邏輯圖如下:

Python                                                                        +                                                C++
                                                                              |
                +--------------------+      +----------------------+          |
                | ElasticDriver      |      | RendezvousServer     |          |
+------------>  |                    |      |                      |          |
|               |   _rendezvous +---------> |       handler_cls    |          |
|               |                    |      |             +        |          |
|               |                    |      +----------------------+          |
|         +-------+ _worker_registry |                    |                   |
|         |     |                    |                    v                   |
|         |     +--------+-----------+    +---------------+---------------+   |
|         |              ^                | ElasticRendezvousHandler      |
|         |              |                |  +---------------------+      |   |     4     get(rank)              +---------------+
|         |              |                |  |  _get_rank_and_size +<------------------------------------------+ | HTTPStore     |
|         |            6 | get_slot_info  |  |                     |      |   |                                  |               |
|         |              +-------------------+  driver  +------------------------------------------------------> |               |
|         |                               |  |                     |      |   | 8 slot_info (rank,local_rank...) +--+------------+
|         |        +-------------------------+                     |      |   |                                     |
|         |        |  5  record_ready     |  +---------------------+      |   |                                     |     ^
|         |        |                      +-------------------------------+   |                                     |     |
|         v        v                    ----------------------------------+   |                                     |     |
|   +-----+--------+---------------+    |               Worker            |   |                                     |     |
|   | WorkerStateRegistry          |    |                 +               |   |                                     |     |
|   |                              |    |                 |               |   |                                   9 |     | 3
|   |                              |    |                 |  1            |   |                                     |     |
|   |      _host_manager           |    |                 v               |   |                                     |     |
|   |                              |    |           Python train.py       |   |                                     |     |
+--------+ _driver                 |    |                 +               |   |                       +----------------------------+
    |                              |    |                 |               |   |                       | GlooContext |     |        |
    |      _barrier                |    |                 |               |   |                       |             |     |        |
    |                              |    |                 v               |   |  create an instance   |             |     +        |
    |   +---------------------+    |    |             hvd.init() +----------------------------------> |             | init_store   |
    |   |  _workers           |    |    |                                 |   |          2            |             v              |
    |   |                     |    | 7  |                                 |   |                       |  gloo::Context(rank)       |
    |   |   READY[]+----------------------------------> READY             |   |                       |                            |
    |   |   SUCCESS[]         |    |    |                 +               |   |                       |  gloo::Context(local_rank) |
    |   |   FAILURE[]         |    |    |                 |               |   |                       |                            |
    |   |                     |    |              run_fn(train_function)  |   |                       |  gloo::Context(cross_rank) |
    |   +---------------------+    |    |                 |               |   |                       |                            |
    +------------------------------+    |                 |               |   |                       +----------------------------+
                                        |                 v               |   |
                                        |        +--------+--------+      |   |
                                        |        |                 |      |   |
                                        |        |                 |      |   |
                                        |        v                 v      |   |
                                        |                                 |   |
                                        |     SUCCESS           FAILURE   |   |
                                        |                                 |   |
                                        +---------------------------------+   +

手機如下:

至此,Worker 就可以開始執行了。

6.3 WorkerStateRegistry

WorkerStateRegistry 的作用是 註冊執行結果,然後進行對應協調

其主要成員變數是:

  • _driver :用來聯絡 Driver,因為會呼叫 driver 來做處理;

  • _host_manager :用來發現 host;

  • _workers : 紀錄每種狀態的worker,狀態包括:'READY','SUCCESS','FAILURE';

    • def count(self, state):
          return len(self._workers[state])
      
  • _states : 紀錄 worker 的狀態;

  • _barrier : 作用是當所有 worker 完成之後,會進一步處理;

具體定義如下:

class WorkerStateRegistry(object):
    def __init__(self, driver, host_manager, reset_limit=None, verbose=False):
        self._driver = driver
        self._host_manager = host_manager
        self._reset_limit = reset_limit
        self._reset_count = 0
        self._lock = threading.Lock()
        self._states = {}
        self._workers = defaultdict(set)
        self._barrier = None
        self._rendezvous_id = 0
        self._verbose = verbose
        self._size = 0

6.3.1 初始化

WorkerStateRegistry 在 Driver 之中 進行初始化,並且把自己設定為 Driver 的一個成員變數,這樣 Driver 就可以方便呼叫:

self._worker_registry = WorkerStateRegistry(self, self._host_manager, reset_limit=reset_limit)

6.3.2 啟動

在 master 啟動所有 worker 之前,會呼叫 reset。

def _activate_workers(self, min_np):
    current_hosts = self.wait_for_available_slots(min_np)
    pending_slots = self._update_host_assignments(current_hosts)
    self._worker_registry.reset(self.world_size()) # 這裡reset
    self._start_worker_processes(pending_slots)

reset 函式中有複雜邏輯。

有兩個問題:

  • 為什麼要有 _barrier?原因是:大部分機器學習演算法機制是需要當所有 worker(或者若干worker) 完成之後,才會進一步處理,所以需要等待
    • 這裡 barrier 的引數 parties 具體數值是 self.world_size(),就是說,只有等到barrier 內部計數達到 self.world_size() 時候,就會激發 self._action 函式。
    • 每個worker 結束時候,都會呼叫到 _handle_worker_exit,最終會 self._barrier.wait()
    • 這樣,當所有 worker 都結束時候,barrier 會激發 self._action 函式。
  • 設定的 _action 起到什麼作用?其作用是:根據本次訓練結果,進一步控制,決定下一步動作

程式碼如下:

def reset(self, size):
    with self._lock:
        self._states.clear()
        self._workers.clear()
        self._barrier = threading.Barrier(parties=size, action=self._action)
        self._rendezvous_id += 1
        self._size = size

6.3.3 worker 結束

當 worker 結束時候,會回到 Driver 設定的 _handle_worker_exit。根據 exit_code 來決定是呼叫 success 函式還是 failure 函式。

def  _handle_worker_exit(self, slot_info, exit_code, timestamp):
    if not self.has_rank_assignment(slot_info.hostname, slot_info.local_rank):
        return

    if exit_code == 0:
        rendezvous_id = self._worker_registry.record_success(slot_info.hostname, slot_info.local_rank)
    else:
        rendezvous_id = self._worker_registry.record_failure(slot_info.hostname, slot_info.local_rank)

    if self.finished() and self._worker_registry.last_rendezvous() == rendezvous_id:
        name = '{}[{}]'.format(slot_info.hostname, slot_info.local_rank)
        self._results.add_result(name, (exit_code, timestamp))

從而呼叫到 WorkerStateRegistry 之中。

def record_ready(self, host, slot):
    return self._record_state(host, slot, READY)

def record_success(self, host, slot):
    return self._record_state(host, slot, SUCCESS)

def record_failure(self, host, slot):
    return self._record_state(host, slot, FAILURE)

_record_state 函式會使用 self._workers[state].add(key) 來紀錄狀態,並且呼叫 _wait。

def _record_state(self, host, slot, state):
    if self._driver.finished():
        return self._rendezvous_id
    if self._host_manager.is_blacklisted(host):
        return self._rendezvous_id

    key = (host, slot)
    with self._lock:
        if key in self._states:
            if state == FAILURE:
                self._barrier.reset()

        if key not in self._states or state == FAILURE:
            self._states[key] = state
            self._workers[state].add(key)

        rendezvous_id = self._rendezvous_id

    rendezvous_id = self._wait(key, state, rendezvous_id)
    return rendezvous_id

_wait 會並且呼叫 self._barrier.wait() 來等待,這是為了等待其他 worker 的資訊,最後一起處理

def _wait(self, key, state, rendezvous_id):
    while True:
        try:
            self._barrier.wait()
            return rendezvous_id
        except threading.BrokenBarrierError:
            if self._barrier.broken:
                # Timeout or other non-recoverable error, so exit
                raise

            # Barrier has been reset
            with self._lock:
                # Check to make sure the reset was not caused by a change of state for this key
                rendezvous_id = self._rendezvous_id
                saved_state = self._states.get(key, state)
                if saved_state != state:
                    # This worker changed its state, so do not attempt to wait again to avoid double-counting
                    raise RuntimeError('State {} overridden by {}'.format(state, saved_state))

6.3.4 進一步控制

_action 函式會在所有worker 結束之後,進行判斷,控制。

def _action(self):
    self._on_workers_recorded()

_on_workers_recorded 函式會完成控制邏輯。

  • 判斷是否有一個 worker 成功,如果有一個 worker 成功了,就關閉其他process,結束訓練;因為此時所有 worker 都已經執行結束,所以只要有一個 worker 成功,就可以跳出迴圈;
  • 如果所有的 worker 都失敗了,就結束訓練;
  • 把失敗的 worker 紀錄到 黑名單;
  • 如果所有的 host 都在黑名單,則結束訓練;
  • 如果已經到了最大重試數目,則結束訓練;
  • 否則呼叫 _driver.resume() 重啟訓練,因為已經 commit 了,所以會自動恢復訓練;

具體程式碼如下:

def _on_workers_recorded(self):
    # Check for success state, if any process succeeded, shutdown all other processes
    if self.count(SUCCESS) > 0:
        self._driver.stop()
        return

    # Check that all processes failed, indicating that processing should stop
    if self.count(FAILURE) == self._size:
        self._driver.stop()
        return

    # Check for failures, and add them to the blacklisted hosts list
    failures = self.get(FAILURE)
    for host, slot in failures:
        self._host_manager.blacklist(host)

    # If every active host is blacklisted, then treat this as job failure
    if all([self._host_manager.is_blacklisted(host) for host, slot in self.get_recorded_slots()]):
        self._driver.stop()
        return

    # Check that we have already reset the maximum number of allowed times
    if self._reset_limit is not None and self._reset_count >= self._reset_limit:
      
self._driver.stop(error_message=constants.RESET_LIMIT_EXCEEDED_MESSAGE.format(self._reset_limit))
        return

    try:
        self._reset_count += 1
        self._driver.resume()
    except Exception:
        self._driver.stop()

6.4 Driver.resume 場景

resume 的作用 就是 所有都重新來過。

def resume(self):
    self._activate_workers(self._min_np)

我們之前分析的場景是:一個 worker 從開始執行到 READY 狀態 之間都發生了什麼。

現在,我們加上一個情形:就是當 Driver 在 resume 的時候,發現居然有新節點,隨即啟動了一個新 worker 3

  1. Worker 2 開始呼叫 python train.py;

  2. 在 train.py 之中,呼叫 hvd.init(),此方法會深入到 C++ 世界,從而生成了 GlooContext;

  3. GlooContext 之中,會從環境變數之中得到 Rendezvous Server 的ip, port,進而呼叫 init_store 生成一個 HTTPStore;

  4. 呼叫 init_store.get(hostname + ":" + std::to_string(local_rank)) 向 Rendezvous Server 傳送請求,要求獲得本worker 的 rank 對應的 各種配置(local_rank, cross_rank...,因為 Rendezvous Server 可能會重新初始化從而重新分配);

  5. ElasticRendezvousHandler 是 響應函式,其中會 呼叫 driver.record_ready(host, local_rank) 從而在 WorkerStateRegistry 的 READY 字典中記錄下來,worker 2 已經是 READY 了。

  6. 會呼叫 driver.get_slot_info(host, local_rank) 從 driver 獲得 slot info;

  7. 把 slot info 返回給 worker 中的 Http_store;

  8. 在 worker 2 之中繼續執行,把 slot info 返回給 GlooContext,進行各種設定;

  9. 我們接著 第 5 項繼續進行;record_ready 之中 會呼叫 rendezvous_id = self._wait(key, state, rendezvous_id) 來在 WorkerStateRegistry . _barrier 之上等待; _barrier 的型別是threading.Barrier(parties=size, action=self._action)

  10. 如果 READY 的 worker 數目達到了 Horovod 設定的 min-np,就是可以啟動的最小 worker 數目, _barrier 就結束使命,就 broken,繼續執行;

  11. _barrier 繼續執行時候,會呼叫構建時候設定的 handler,即 _action 函式,進而呼叫到了 _on_workers_recorded,最終呼叫到了 self._driver.resume()

  12. ElasticDriver 的 resume 函式呼叫到了 _activate_workers,其定義如下,可以看到,如果此時 discovery 指令碼已經發現了新節點,進而返回了 pending_slotspending_slots 就是可以在這些 slot 之上啟動新 worker 的,於是 就會 呼叫 _start_worker_processes

    1. def _activate_workers(self, min_np):
          current_hosts = self.wait_for_available_slots(min_np)
          pending_slots = self._update_host_assignments(current_hosts)
          self._worker_registry.reset(self.world_size())
          self._start_worker_processes(pending_slots)
      
  13. _start_worker_processes 會開啟一個新的 worker : Worker 3;

  14. Worker 3 也執行 python train.py,至此,新worker啟動完畢;

  15. 回到worker 2,如果訓練結束,則會依據訓練結果,返回 SUCCESS 或者 FAILURE 到 Driver ;

  16. Driver 會呼叫 _handle_worker_exit 對訓練結果進行處理;

至此,新的邏輯完成。

       +--------------------------------------------+
       |      +------------------------------+      |        +----------------------+
       |      | WorkerStateRegistry          |      | 15     |  ElasticDriver       |          +----------------------+
       |      |                              |      +----->  |     _rendezvous +-------------> | RendezvousServer     |
       |      |     _driver +-----------+    |               |                      |          |                      |
       |      |        ^                |    +<----------------+ _worker_registry   | 12 init  |       handler_cls    |
       |      |        |                |    |               |                      +--------> |             +        |
       |      |        | 10 broken      |    |   11  resume  |                      |          +----------------------+
       |      |        |                +----------------------> _activate_workers  |                        |
       |      |        |                     |      +----------+_handle_worker_exit |                        v
       |      |        +          9 wait     |      |        |                      |        +---------------+---------------+
       |      |     _barrier <---------------------------+   +--+-----------+-------+        | ElasticRendezvousHandler      |
       |      |                              |      |    |      |           ^                |  +---------------------+      |         4     get(rank)
       |      | +--------------------------+ |      |    |      |           |                |  |  _get_rank_and_size +<-----------------------------------------------------+
       |      | | _workers                 | +<-----+    |      |         6 | get_slot_info  |  |                     |      |                                               |
       |      | |                          | |  16       |      |           +-------------------+  driver  +-----------------------------------------------------------+     |
       |      | |SUCCESS[]                 | |           |      |                            |  |                     |      |     7 slot_info (rank,local_rank...)    |     |
       |      | |READY[(host 1, slot1)] <----------------+--------------------------------------+                     |      |                                         |     |
       |      | |FAILURE[]                 | |                  |     5 record_ready         |  +---------------------+      |                                         |     |
       |      | +--------------------------+ |                  |                            +-------------------------------+                                         |     |
       +      +------------------------------+                  |                                                                                                      |     |
_handle_worker_exit                                             +          Host 1                                                                                      +     +
 +-----+-----------------------------------------------------+ rsh +---------+----------------------------------------------------------------------------------------+ Socket +------+
       |                                                        +   Host 3   |    Host 2                                                              Host 2           +     +
       |                                                        |            |             +-----------------+---------------+  Python +   C++                         |     |
       |                                                        +            |             |  Worker 2       | 1     slot 2  |         |                               v     |
       |                          13  _start_worker_processes(pending_slots) |             |                 v       host 2  |         |                                     |
       |                                                        +            |             |           Python train.py       |         |                            +--------+---+
       |                                                        |            |             |                 +               |         |                            | HTTPStore  |
       |                                                        v            |             |                 |               |         |                            +---+-----+--+
       |                                                                     |             |                 v               |         |                                |     ^
       |                                 +---------------------------+       |             |             hvd.init()          |         |   2                          8 |     | 3
       |                                 | Worker 3                  |       |             |                 +      +------------------+-------+                        |     |
       |                                 |           +       slot 3  |       |             |                 |               |                 |          +----------------------------+
       |                                 |           |       host 3  |       |             |                 +               |         +       |          | GlooContext |     |        |
       |                                 |       14  |               |       |             |               READY             |         |       |          |             |     |        |
       |                                 |           v               |       |             |                 +               |         |       |          |             |     +        |
       |                                 |                           |       |             |                 |               |         |       |          |             | init_store   |
       |                                 |      Python train.py      |       |             |                 v               |         |       |          |             v              |
       |                                 |                           |       |             |        hvd.elastic.run(state)   |         |       |          |  gloo::Context(rank)       |
       |                                 |                           |       |             |                 +               |         |       |          |                            |
       |                                 +---------------------------+       |             |                 +               |         |       +--------> |  gloo::Context(local_rank) |
       |                                                                     |             |         run_fn(train_function)  |         |                  |                            |
       |                                                                     |             |                 +               |         |                  |  gloo::Context(cross_rank) |
       |                                                                     |             |                 v               |         |                  |                            |
       |                                                                     |             |        +--------+--------+      |         |                  +----------------------------+
       |                                                                     |             |        |                 |      |         |
       |                                                                     |             |        v                 v      |         +
       |                                                                     |             |     SUCCESS           FAILURE   |
       |                                                                     +             |         +                 +     |
       |                                                                                   |         |                 |     |
       |                                                                                   +---------------------------------+
       |                                                                                             |                 |
       |                                                                                             |                 |
       +--------------------------------------------------------------------------<------------------+-----------------+

手機如下:

我們可以擴大部分細節看看,就是Driver + 一個 Worker,這樣看得更清楚:

                                                                                                    +
+------------------------------+           +----------------------+                                 |
| WorkerStateRegistry          |           |  ElasticDriver       |                        Host 1   |  Host 2
|                              |10 resume  |                      |                                 |                  1
|      _driver +-------------------------->+      start/resume    +-----------------------------------------------------------+
|                              |           |                      |               _activate_workers |                         |
|        ^                     |           |                      | <-------------+                 |                         |
|        |                     |           +---+---+--------------+ get_slot_info |                 |       +---------------------------------+
|      9 | broken              |               |   |                              |                 |       |  Worker 2       |       slot 2  |
|        |                     |               |   |      ^                       |                 |       |                 v       host 2  |
|        |                     |               |   |      |                       |                 |       |           Python train.py       |
|        +                     |               |   |      |            +----------+-----------+     |       |                 +               |
|     _barrier                 |               |   |      |            | RendezvousServer     |     |       |                 |               |
|        ^                     |               |   |      |            |                      |     |   2   |                 v               |
|      8 | wait                |               |   |      |            |   _get_rank_and_size | <---------------------+   hvd.init()          |
|        |                     |               |   |      |            |                      |     |       |                 +               |
|        |                     |               |   |      |            +-----------+----------+     |       |                 |               |
| +------+-------------------+ |               |   |      |                        |                |       |                 +               |
| |_workers                  | |               |   |      |                        |                |       |               READY             |
| |                          | |               |   |      |                        |                |       |                 +               |
| |SUCCESS[]    <------------------------------+   |      |                        |                |       |                 |               |
| |                          | | 6 record_success  |      |                        |                |       |                 v               |
| |                          | |                   |      |                        |                |       |        hvd.elastic.run(state)   |
| |READY[(host 1, slot1)] <--------------------------------------------------------+                |       |                 +               |
| |                          | |                   |      |          3 record_ready                 |       |                 +               |
| |FAILURE[]   <-----------------------------------+      |                                         |       |         run_fn(train_function)  |
| |                          | | 7 record_failure         |                                         |       |                 +               |
| +--------------------------+ |                          |                                         |       |                 v               |
+------------------------------+                          |                                         |       |        +--------+--------+      |
                                                          |                                         |       |        |                 |      |
                                                          |                                         |       |        v                 v      |
                                                          |                                         |       |     SUCCESS           FAILURE   |
                                                          |                                         |       |         +                 +     |
                                                          |                                         |       |         |                 |     |
                                                          |                                         |       +---------------------------------+
                                                          |                                         |                 |                 |
                                                          | _handle_worker_exit                     |               4 |                5|
                                                          |                                         |                 |                 |
                                                          +-----------------------------------------------------<-----+-----------------+
                                                                                                    |
                                                                                                    +

手機如下:

至此,worker部分分析完畢,下一篇我們看看如何處理錯誤。

0xEE 個人資訊

★★★★★★關於生活和技術的思考★★★★★★

微信公眾賬號:羅西的思考

如果您想及時得到個人撰寫文章的訊息推送,或者想看看個人推薦的技術資料,敬請關注。

在這裡插入圖片描述

0xFF 參考

PaddlePaddle Fluid:彈性深度學習在Kubernetes中的實踐

Horovod 彈性訓練

ElasticDL呼叫 Horovod 在Kubernetes上實現彈性 AllReduce(一)

Kubernetes-native 彈性分散式深度學習系統

雲原生的彈性 AI 訓練系列之一:基於 AllReduce 的彈性分散式訓練實踐

相關文章