[原始碼解析] 深度學習分散式訓練框架 horovod (9) --- 啟動 on spark

羅西的思考發表於2021-07-03

[原始碼解析] 深度學習分散式訓練框架 horovod (9) --- 啟動 on spark

0x00 摘要

Horovod 是Uber於2017年釋出的一個易於使用的高效能的分散式訓練框架,在業界得到了廣泛應用。

本系列將通過原始碼分析來帶領大家瞭解 Horovod。這幾篇介紹 horovod 如何執行在 spark 之上。本文是第九篇,介紹 horovod on spark 如何啟動。

本系列其他文章如下:

[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識

[原始碼解析] 深度學習分散式訓練框架 horovod (2) --- 從使用者角度切入

[原始碼解析] 深度學習分散式訓練框架 horovod (3) --- Horovodrun背後做了什麼

[原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver

[原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架

[原始碼解析] 深度學習分散式訓練框架 horovod (6) --- 後臺執行緒架構

[原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer

[原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark

0x01 總體架構圖

首先,我們還是要祭出架構圖,這樣大家可以按圖索驥。

總體來說,Horovod on Spark 的總體邏輯分為以下階段:

  • 啟動 SparkDriverService 服務,利用 _make_spark_thread 啟動 Spark task,然後 horovod 會等待啟動結束;
  • 多執行緒在 spark executor 之中啟動 spark task,每個task之中執行一個 SparkTaskService,SparkTaskService 會向 hovorod 主程式中的 SparkDriverTask 進行註冊,並且等待下一步執行啟動的指令;
  • Horovod 收到所有 task 結束的資訊之後,通知各個 task,進入下一階段;
  • Horovod 呼叫 mpi_run (又利用到 mpirun_rsh.py)在每一個 spark executor 上啟動 orted 程式,以啟動 MPI cluster;
  • orted 在每一個 executor 之上執行訓練程式碼;

我們下面就具體看看如何啟動。

0x02 第一階段 :Horovod 啟動

本部分主要邏輯是:啟動 SparkDriverService 服務,利用 _make_spark_thread 啟動 Spark task,然後 horovod 會等待啟動結束。

2.1 Driver服務 :SparkDriverService

SparkDriverService 繼承了 driver_service.BasicDriverService,所以其內部啟動了一個 socket server,可以進行網路互動。

Horovod 利用 SparkDriverService 來和 Spark executor(通過其中執行的SparkTaskService)互動,比如收集資訊,讓 spark 啟動訓練job等等。這是一個 RPC 機制

具體 SparkDriverService 的功能可以參見其內部處理的各種 Request,比如

  • CodeRequest :SparkTaskService會用來請求使用者程式碼;
  • TaskHostHashIndicesRequest :獲取 task host 地址;
  • TaskIndexByRankRequest :從 rank 獲取到 task index;
  • SetLocalRankToRankRequest :從 local rank 得到 rank 資訊;
  • WaitForTaskShutdownRequest :等待 shutdown;

和前文介紹的 HorovodRunDriverService 有些類似。

其中,其成員變數 _fn 就是訓練函式,以後當 SparkTaskService 請求程式碼的時候,就通過 CodeResponse 把 _fn 直接傳送回去。這樣就解決了程式碼釋出問題

class SparkDriverService(driver_service.BasicDriverService):
    NAME = 'driver service'

    def __init__(self, initial_np, num_proc, fn, args, kwargs, key, nics):
        super(SparkDriverService, self).__init__(num_proc,
                                                 SparkDriverService.NAME,
                                                 key, nics)
        self._initial_np = initial_np
        self._fn = fn # 儲存使用者程式碼
        self._args = args # 使用者引數
        self._kwargs = kwargs 
        self._key = key
        self._nics = nics # 網路卡資訊
        self._ranks_to_indices = {}
        self._spark_job_failed = False
        self._lock = threading.Lock()
        self._task_shutdown = threading.Event()

    def _handle(self, req, client_address):

        if isinstance(req, TaskHostHashIndicesRequest): # 獲取 task host 地址
            return TaskHostHashIndicesResponse(self._task_host_hash_indices[req.host_hash])

        if isinstance(req, SetLocalRankToRankRequest): # 從 local rank 得到 rank 資訊
            self._lock.acquire()

            try:
                # get index for host and local_rank
                indices = self._task_host_hash_indices[req.host]
                index = indices[req.local_rank]

                values = list(self._ranks_to_indices.values())
                prev_pos = values.index(index) if index in values else None
                if prev_pos is not None:
                    prev_rank = list(self._ranks_to_indices.keys())[prev_pos]
                    del self._ranks_to_indices[prev_rank]

                # memorize rank's index
                self._ranks_to_indices[req.rank] = index
            finally:
                self._lock.release()
            return SetLocalRankToRankResponse(index)

        if isinstance(req, TaskIndexByRankRequest): # 是從 rank 獲取到 task index
            self._lock.acquire()
            try:
                return TaskIndexByRankResponse(self._ranks_to_indices[req.rank])
            finally:
                self._lock.release()

        if isinstance(req, CodeRequest): # SparkTaskService會用來請求使用者程式碼
            return CodeResponse(self._fn, self._args, self._kwargs)

        if isinstance(req, WaitForTaskShutdownRequest): # 等待任務結束
            self._task_shutdown.wait()
            return network.AckResponse()

        return super(SparkDriverService, self)._handle(req, client_address)

2.2 啟動spark task : _make_spark_thread

在 Horovod.spark.run 之中,_make_spark_thread 建立了 thread。這裡關鍵程式碼是:

mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
result = procs.mapPartitionsWithIndex(mapper).collect()

mapPartitionsWithIndex 這句程式碼會促使 Spark 在多個 Executor 之中執行 mapper 函式,並且得到執行結果

即建立 settings.num_procSpark tasks,每個 task 會執行 mapper(_task_fn), 外部的 run 函式會等待這些執行結果。其實如果需要使用RDD,也許可以使用 foreachPartition,這樣每個結點上將會在記憶體中持有RDD的一個分割槽。

def _make_spark_thread(spark_context, spark_job_group, driver, result_queue,
                       settings, use_gloo, is_elastic):
    """Creates `settings.num_proc` Spark tasks in a parallel thread."""
    
    def run_spark():
        """Creates `settings.num_proc` Spark tasks, each executing `_task_fn` and waits for them to terminate."""
        try:
            spark_context.setJobGroup(spark_job_group, "Horovod Spark Run", interruptOnCancel=True)
            procs = spark_context.range(0, numSlices=settings.max_np if settings.elastic else settings.num_proc)
            # We assume that folks caring about security will enable Spark RPC encryption,
            # thus ensuring that key that is passed here remains secret.
            mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
            # 促使 Spark 在多個 Executor 之中執行 mapper 函式,並且得到執行結果
            result = procs.mapPartitionsWithIndex(mapper).collect()
            result_queue.put(result)
        except:
            driver.notify_spark_job_failed()
            raise

    spark_thread = in_thread(target=run_spark, daemon=False)
    return spark_thread

2.3 等待 spark task 啟動結束

啟動了 spark task 之後,horovod 主程式會呼叫如下來等待 task 全部 啟動完成。

# wait for all tasks to register, notify them and initiate task-to-task address registration
_notify_and_register_task_addresses(driver, settings)

即,run 函式中,當 _make_spark_thread 之後,horovod 主程式呼叫 _notify_and_register_task_addresses,從而呼叫 driver.wait_for_initial_registration(settings.start_timeout) ,進行總體等待。

等待的內容是:等待所有 num_proc tasks 來註冊。當所有 spark thread 都ready 之後,主 horovod 程式會繼續執行

2.3.1 _notify_and_register_task_addresses

horovod 主程式之中,會使用 _notify_and_register_task_addresses等待這些 spark task 來註冊,從而呼叫 driver.wait_for_initial_registration(settings.start_timeout) ,進行總體等待。

注意,同時傳送註冊請求之後, spark task 自己也呼叫 task.wait_for_initial_registration 等待 horovod 再通知下一階段的啟動。

而在horovod 主程式的 _notify_and_register_task_addresses 其實也很複雜:

  • 呼叫 driver.wait_for_initial_registration 等待task來註冊,需要等待 num_proc 個task;
  • 利用 notify_and_register 註冊task,並且通知各個 task 開始下一步;

具體程式碼如下:

def _notify_and_register_task_addresses(driver, settings, notify=True):
    # wait for num_proc tasks to register
    # 等待task來註冊,需要等待 num_proc 個task
    driver.wait_for_initial_registration(settings.start_timeout) 

    def notify_and_register(index): # 註冊task,並且通知各個 task 開始下一步
        task_client = task_service.SparkTaskClient(index,
                                                   driver.task_addresses_for_driver(index),
                                                   settings.key, settings.verbose)

        if notify:
            task_client.notify_initial_registration_complete()

        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver.all_task_addresses(next_task_index)
        task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
        driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)

    for index in driver.task_indices():
        in_thread(notify_and_register, (index,)) #在thread之中啟動task

    driver.wait_for_task_to_task_address_updates(settings.start_timeout)

我們目前只能看其第一步 “等待註冊”。

2.3.2 driver.wait_for_initial_registration

在這裡 SparkDriverSerivce 首先等待所有 spark executor 註冊。

在 class BasicDriverService(network.BasicService): 有如下程式碼,可以看到,只有全部 _num_proc 註冊完成,當所有 spark thread 都ready 之後,主 horovod 程式會繼續執行。

這裡關鍵是:while len(self._all_task_addresses) < self._num_proc就是等待 self._all_task_addresses 的數目達到 _num_proc。

class BasicDriverService(network.BasicService):
  
  def wait_for_initial_registration(self, timeout):
      self._wait_cond.acquire()
      try:
          # 等待 self._all_task_addresses 的數目達到 _num_proc
          while len(self._all_task_addresses) < self._num_proc:
              self._wait_cond.wait(timeout.remaining())
              timeout.check_time_out_for('tasks to start')
      finally:
          self._wait_cond.release()

2.4 等待

關於等待程式碼,我們要做一下特殊說明,具體看圖。

這裡有兩套 wait_for_initial_registration。可以認為是兩套 barrier

就是:

  • barrier 1 :SparkDriverSerivce 等待所有 SparkTaskSerivce ready;
  • barrier 2 :所有 SparkTaskSerivce 需要一起執行,所以 SparkTaskSerivce們 都在等待 barrier 2。SparkDriverSerivce 會通知 這些 SparkTaskSerivce 一起發動;

2.3.1 Barrier 1 in Driver

在 run 函式中,當 _make_spark_thread 之後,horovod 主程式呼叫 _notify_and_register_task_addresses,從而呼叫 driver.wait_for_initial_registration(settings.start_timeout) ,進行總體等待。

等待的內容是:等待所有 num_proc tasks 來註冊。當所有 spark thread 都ready 之後,主 horovod 程式會繼續執行。這裡關鍵是:

while len(self._all_task_addresses) < self._num_proc

就是等待 self._all_task_addresses 的數目達到 _num_proc。

def wait_for_initial_registration(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._all_task_addresses) < self._num_proc:
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for('tasks to start')
    finally:
        self._wait_cond.release()

在 BasicDriverService 之中,如果收到了 spark executor 的註冊請求就進行處理,這裡最重要是:

self._all_task_addresses[req.index] = req.task_addresses

當所有的 spark executor 都註冊了,這裡就等待成功

2.3.2 Barrier 2 in task

每個 spark thread 在 _task_fn 之中執行,就是在 spark task 之中執行。這裡也可以看出來是 Spark task 的一個總體流程

  • 首先 呼叫 register_task
  • 其次 呼叫 task.wait_for_initial_registration(settings.start_timeout)
  • 然後 呼叫 wait_for_command_termination 來等待結束;

task.wait_for_initial_registration 會等待 self._initial_registration_complete = True 這個條件,就是等待 register_task 註冊完成。

每個 Spark Executor 都有一個 SparkTaskService,所以 每個spark task 都有自己的 _initial_registration_complete。

hovorod.run 主程式會逐一通知每個 SparkTaskService 的 _initial_registration_complete。

即,哪個 SparkTaskService 好了,就通知哪個 SparkTaskService 的 _initial_registration_complete。這樣,這個 SparkTaskService 就可以正式執行了。

2.3.3 總體等待流程

總體等待流程具體如圖,數字就是執行順序:

  1. SparkDriverSerivce 呼叫 driver.wait_for_initial_registration 來等待 SparkTaskSerivce 的註冊,這是 barrier 1
  2. SparkTaskSerivce 1 進行註冊,然後 SparkTaskSerivce 1 自己也呼叫 task.wait_for_initial_registration 等待 horovod 再通知下一階段的啟動,這是 barrier 2
  3. SparkTaskSerivce 2 進行註冊,然後 SparkTaskSerivce 2 自己也呼叫 task.wait_for_initial_registration 等待 horovod 再通知下一階段的啟動,這是 barrier 2
  4. hovorod.run 主程式在發現所有 task 都註冊之後,barrier 1 等待結束,會逐一通知每個 SparkTaskService 的 _initial_registration_complete。只有 4 完成之後,兩個 SparkTaskSerivce 才能繼續執行 5,6;
  5. SparkTaskSerivce 1 對於 barrier 2 等待結束,繼續執行;
  6. SparkTaskSerivce 2 對於 barrier 2 等待結束,繼續執行;
    SparkTaskSerivce 1          SparkTaskSerivce 2            SparkDriverSerivce

            +                           +                             +
            |                           |                             |
            |                           |                             |
            |                           |                             |
            |                           |                             |   1
            |                           |                             |
            |                           |                             |
            |                           |                             v
            |                           |
            |                           |         +--------------------------------------+
            |                           |         | barrier 1                            |
            |                           |   2     |                                      |
            |          3                +-------> |                                      |
            |                           |         |                                      |
            +-----------------------------------> | driver.wait_for_initial_registration |
            |                           |         |                                      |
            |                           |         |                                      |
            |                           |         |                                      |
            |                           |         +--------------------+-----------------+
            |                           |                              |
            |                           |                              |
+-----------+----------------------+    |                  4           |
|barrier 2                         | <---------------------------------+
|                                  |    |                              |
|task.wait_for_initial_registration|    |                              |
|                                  |    |                              |
+-----------+----------------------+    |                              |
            |                           |                              |
            |             +-------------+----------------------+       |
            |             | barrier 2                          |   4   |
            | 6           |                                    +<------+
            |             | task.wait_for_initial_registration |       |
            |             |                                    |       |
            |             +-------------+----------------------+       |
            |                           |                              |
            |                           |                              |
            |                           |  5                           |
            |                           |                              |
            v                           v                              v

我們接下來詳細介紹 task 啟動內容 和 driver 後續工作。

0x03 第二階段 :Spark Task 啟動

本階段我們詳細介紹下 Spark Task 的啟動過程。

這部分主要功能是:多執行緒在 spark executor 之中啟動 spark task,每個spark task會執行_task_fn函式,_task_fn函式會執行一個 SparkTaskService,SparkTaskSerivce 會向 hovorod 主程式中的 SparkDriverTask 進行註冊,並且等待下一步執行啟動的指令;

此時程式(不是訓練程式,而是 SparkTaskService)已經在 Spark Executor內部執行了。我們看看在 spark Executor 之中,是如何啟動執行 SparkTaskService 的。

3.1 具體spark啟動邏輯 :_task_fn

Horovod 在 thread 裡面通過 _make_mapper 來讓 Spark 執行 _task_fn。

def _make_mapper(driver_addresses, settings, use_gloo, is_elastic):

    def _mapper(index, _):
        yield _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic)

    return _mapper

_task_fn 的作用是為了註冊 horovod 進入到 spark task。即,在每一個 spark task (executor) 之中啟動一個 SparkTaskService。

一定要注意:這些 SparkTaskService 是執行在 spark executor 之中,通過網路與 horovod 之中的 SparkDriverService 互動

可以看到,_task_fn 的總體邏輯是:

  • 啟動 SparkTaskService;
  • 通過 driver_service.SparkDriverClient.register_task 來向 horovod 中的 Driver 註冊;
  • 通過 task.wait_for_initial_registration(settings.start_timeout) 來等待下一步啟動的開始指示;
  • 如果下一步開始啟動了,則呼叫 task.wait_for_command_termination() 等待結束;

具體如下:

def _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic):
    settings.key = key
    hosthash = host_hash(salt='{}-{}'.format(index, time.time()) if is_elastic else None)
    os.environ['HOROVOD_HOSTNAME'] = hosthash
    # 啟動 SparkTaskService,SparkTaskService本身包括一個socket server,可以和driver互動
    task = task_service.SparkTaskService(index, settings.key, settings.nics,...)
    try:
        driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
        # 向 horovod 中的 Driver 註冊
        driver_client.register_task(index, task.addresses(), hosthash)

        # 這裡依然執行在spark task之中,但因為不是SparkTaskService,所以只是做協助工作,最後靜靜等待
        if not is_elastic:
            # 等待下一步啟動的開始指示
            task.wait_for_initial_registration(settings.start_timeout)
            task_indices_on_this_host = driver_client.task_host_hash_indices(hosthash)
            local_rank_zero_index = task_indices_on_this_host[0]
        else:
            local_rank_zero_index = None

        if is_elastic:
						...... # 後續文章會介紹
        elif use_gloo or index == local_rank_zero_index:
            # Either Gloo or first task with MPI.
            # 使用Gloo或者使用MPI的第一個task,讓這個task做操作
            task.wait_for_command_start(settings.start_timeout)
            # 等待結束
            task.wait_for_command_termination()
        else:
            # The other tasks with MPI need to wait for the first task to finish.
            # 讓其他的task等待第一個task結束
            first_task_addresses = driver_client.all_task_addresses(local_rank_zero_index)
            first_task_client = \
                task_service.SparkTaskClient(local_rank_zero_index,
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            # 呼叫 task.wait_for_command_termination() 等待結束  
            first_task_client.wait_for_command_termination()

        return task.fn_result()
    finally:
        task.shutdown()

3.2 SparkTaskService

再次強調如下程式碼:

task = task_service.SparkTaskService(index, settings.key, settings.nics,...)

每一個_task_fn 中都定義了一個 SparkTaskService,即每一個 Spark Executor 都會生成一個(或者多個) SparkTaskService,在 spark task 之中執行並且作用。

3.2.1 SparkTaskService 定義

SparkTaskService 定義如下,因為繼承了BasicTaskService,所以其內部最終也會啟動一個 socket server,以便同 horovod 中的 SparkDriverService 互動:

class SparkTaskService(task_service.BasicTaskService):
    NAME_FORMAT = 'task service #%d'

    def __init__(self, index, key, nics, minimum_command_lifetime_s, verbose=0):
        # on a Spark cluster we need our train function to see the Spark worker environment
        # this includes PYTHONPATH, HADOOP_TOKEN_FILE_LOCATION and _HOROVOD_SECRET_KEY
        env = os.environ.copy()

        # we inject the secret key here
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)

        # we also need to provide the current working dir to mpirun_exec_fn.py
        env['HOROVOD_SPARK_WORK_DIR'] = os.getcwd()

        super(SparkTaskService, self).__init__(SparkTaskService.NAME_FORMAT % index,
                                               index, key, nics, env, verbose)
        self._key = key
        self._minimum_command_lifetime_s = minimum_command_lifetime_s
        self._minimum_command_lifetime = None

3.2.2 基本功能

SparkTaskService 的基本功能如下。

  • _run_command 將會被用來在 spark 之中啟動訓練job;
  • _handle 會處理 GetTaskToTaskAddressesRequest,用來獲取 task 地址,也會處理ResourcesRequest,返回資源;
  • _get_resources 將返回 spark 資源;
  • wait_for_command_termination 會等待命令執行結束;

具體程式碼如下:

def _run_command(self, command, env, event,
                 stdout=None, stderr=None, index=None,
                 prefix_output_with_timestamp=False):
    # 在 spark 之中啟動訓練job
    super(SparkTaskService, self)._run_command(command, env, event,
                                               stdout, stderr, index,
                                               prefix_output_with_timestamp)

    if self._minimum_command_lifetime_s is not None:
        self._minimum_command_lifetime = timeout.Timeout(self._minimum_command_lifetime_s,
                                                         message='Just measuring runtime')

def _handle(self, req, client_address):
    # 返回資源
    if isinstance(req, ResourcesRequest):
        return ResourcesResponse(self._get_resources())

    # 獲取 task 地址  
    if isinstance(req, GetTaskToTaskAddressesRequest):
        next_task_index = req.task_index
        next_task_addresses = req.all_task_addresses
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task_client = \
            SparkTaskClient(next_task_index, next_task_addresses,
                            self._key, self._verbose,
                            match_intf=True)
        return GetTaskToTaskAddressesResponse(next_task_client.addresses())

    return super(SparkTaskService, self)._handle(req, client_address)

def _get_resources(self):
    # 返回 spark 資源
    if LooseVersion(pyspark.__version__) >= LooseVersion('3.0.0'):
        task_context = pyspark.TaskContext.get()
        if task_context:
            return task_context.resources()
        else:
            print("Not running inside Spark worker, no resources available")
    return dict()

def wait_for_command_termination(self):
    """
    Waits for command termination. Ensures this method takes at least
    self._minimum_command_lifetime_s seconds to return after command started.
    """
    try:
        # 等待命令執行結束
        return super(SparkTaskService, self).wait_for_command_termination()
    finally:
        # command terminated, make sure this method takes at least
        # self._minimum_command_lifetime_s seconds after command started
        # the client that started the command needs some time to connect again
        # to wait for the result (see horovod.spark.driver.rsh).
        if self._minimum_command_lifetime is not None:
            time.sleep(self._minimum_command_lifetime.remaining())

3.3 註冊Task

下一步程式碼就是用來向 Driver 註冊 本 task。

driver_client.register_task(index, task.addresses(), hosthash)

3.3.1 傳送註冊請求

註冊具體通過如下完成,這裡呼叫了 network.py 的 _send 函式,就是通過 socket,spark executor 和 horovod driver 進行了網路互動:

class BasicDriverClient(network.BasicClient):

    def register_task(self, index, task_addresses, host_hash):
        self._send(RegisterTaskRequest(index, task_addresses, host_hash))

3.3.2 Driver處理

我們先來到 Horovod 中執行的 Driver來看看(下一節內容,這裡提前看看

在 BasicDriverService 之中,如果收到了RegisterTaskRequest請求就進行處理,這裡最重要是:

self._all_task_addresses[req.index] = req.task_addresses

這樣,self._all_task_addresses 的數目就增加了。

而我們之前提到了,horovod 正在 driver.wait_for_initial_registration 上面等待,其關鍵是:

while len(self._all_task_addresses) < self._num_proc

如果self._all_task_addresses 的數目達到了_num_proc,driver.wait_for_initial_registration 就結束了,就順利執行。

具體處理 RegisterTaskRequest 的程式碼如下,BasicDriverService 之中有各種成員變數,用來維護各種所需資訊,我們在前文 [原創 原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver 中已經詳細講解過,_handle函式的RegisterTaskRequest 處理就是用來更新這些成員變數:

class BasicDriverService(network.BasicService):

    def _handle(self, req, client_address):
        if isinstance(req, RegisterTaskRequest):
            self._wait_cond.acquire()
            try:

                self._all_task_addresses[req.index] = req.task_addresses
                # Just use source address for service for fast probing.
                self._task_addresses_for_driver[req.index] = \
                    self._filter_by_ip(req.task_addresses, client_address[0])
                  
                # Remove host hash earlier registered under this index.
                if req.index in self._task_index_host_hash:
                    earlier_host_hash = self._task_index_host_hash[req.index]
                    if earlier_host_hash != req.host_hash:
                        self._task_host_hash_indices[earlier_host_hash].remove(req.index)

                # Make index -> host hash map.
                self._task_index_host_hash[req.index] = req.host_hash

                # Make host hash -> indices map.
                if req.host_hash not in self._task_host_hash_indices:
                    self._task_host_hash_indices[req.host_hash] = []
                self._task_host_hash_indices[req.host_hash].append(req.index)
                # TODO: this sorting is a problem in elastic horovod
                self._task_host_hash_indices[req.host_hash].sort()
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()
                
            return network.AckResponse()

3.4 Task 等待下一步通知

前面提到了,當 spark task 向 driver 傳送註冊請求之後,Spark task 通過 task.wait_for_initial_registration(settings.start_timeout) 來等待下一步啟動的開始指示。就是 driver 認為你一景註冊完成了,可以開始進入下一步了。

task.wait_for_initial_registration 會等待 self._initial_registration_complete = True 這個條件,就是等待 register_task 註冊完成。

class BasicTaskService(network.BasicService):
  
  def wait_for_initial_registration(self, timeout):
        self._wait_cond.acquire()
        try:
            while not self._initial_registration_complete:
                self._wait_cond.wait(timeout.remaining())
                timeout.check_time_out_for('tasks to start')
        finally:
            self._wait_cond.release()

每個 Spark Executor 都有一個 SparkTaskService,所以 每個spark task 都有自己的 _initial_registration_complete。

hovorod.run 主程式會逐一通知每個 SparkTaskService 的 _initial_registration_complete。即,哪個 SparkTaskService 好了,就通知哪個 SparkTaskService 的 _initial_registration_complete。

hovorod.run 主程式 是通過傳送 NotifyInitialRegistrationCompleteRequest完成這一步的。

def notify_initial_registration_complete(self):
    self._send(NotifyInitialRegistrationCompleteRequest())

BasicTaskService 在等待 NotifyInitialRegistrationCompleteRequest,如果收到了,就設定為 True,這樣wait_for_initial_registration 就等待結束了。

if isinstance(req, NotifyInitialRegistrationCompleteRequest):
    self._wait_cond.acquire()
    try:
        self._initial_registration_complete = True
    finally:
        self._wait_cond.notify_all()
        self._wait_cond.release()
    return network.AckResponse()

就說明當本 thread 註冊在 horovod 之後,就算本 spark thread 啟動成功了。

+-------------------------------------+             +----------------------------------------------------+
| Horovod Main thread                 |             | Spark Executor                                     |
|                                     |             |                     _task_fn                       |
|                                     |             |                        +                           |
|                                     |             |                        |                           |
|                                     |             |                        |                           |
|                                     |             |                        v                           |
| +-------------------------------+   |             |  +---------------------+------------------------+  |
| | SparkDriverService            |   |             |  | SparkTaskService                             |  |
| |                               |   |             |  |               +                              |  |
| |                               |   |  1 register |  |               |                              |  |
| |  self._all_task_addresses <----------------------------------------+                              |  |
| |                               |   |             |  |               |                              |  |
| |              +                |   |             |  |               |                              |  |
| |              |                |   |             |  |               |                              |  |
| |              | 3              |   |             |  |               |                              |  |
| |              |                |   |             |  |               | 2                            |  |
| |              v                |   |             |  |               |                              |  |
| |  self._wait_cond.notify_all() |   |             |  |               |                              |  |
| |              +                |   |             |  |               v                              |  |
| |              |                |   |             |  |     +---------+---------------------------+  |  |
| |              |                |   |             |  |     |                                     |  |  |
| |              |                |   |             |  |     | task.wait_for_initial_registration  |  |  |
| |              |                |   |             |  |     |                                     |  |  |
| |              |                |   |             |  |     +-------------------------------------+  |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              v                |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| +-------------------------------+   |             |  +----------------------------------------------+  |
+-------------------------------------+             +----------------------------------------------------+

手機如下:

0x04 第三階段:Driver 通知 task 註冊成功

本階段的作用是:Horovod 收到所有 task 結束的資訊之後,通知各個 task,進入下一階段。

4.1 _notify_and_register_task_addresses

前面提到。在 horovod 主程式之中,會使用 _notify_and_register_task_addresses 來等待這些 spark task 來註冊,從而呼叫 driver.wait_for_initial_registration(settings.start_timeout) ,進行總體等待。

注意,同時傳送註冊請求之後, spark task 自己也呼叫 task.wait_for_initial_registration 等待horovod 再通知下一階段的啟動。

而 _notify_and_register_task_addresses 中其實也很複雜:

  • 呼叫 driver.wait_for_initial_registration 等待task來註冊;(目前這一步已經完成
  • 利用 notify_and_register 註冊task,並且通知各個 task 開始下一步;(我們這裡進入後面這兩步
  • 利用 driver.wait_for_task_to_task_address_updates 再次確認下所有 task 都OK;
def _notify_and_register_task_addresses(driver, settings, notify=True):
    # wait for num_proc tasks to register
    driver.wait_for_initial_registration(settings.start_timeout)

    def notify_and_register(index):
        # 註冊task,並且通知各個 task 開始下一步
        task_client = task_service.SparkTaskClient(index,
                                                   driver.task_addresses_for_driver(index),
                                                   settings.key, settings.verbose)

        if notify:
            task_client.notify_initial_registration_complete()

        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver.all_task_addresses(next_task_index)
        task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
        driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)

    for index in driver.task_indices():
        in_thread(notify_and_register, (index,)) # 註冊task,並且通知各個 task 開始下一步

    # 再次確認下所有 task 都OK    
    driver.wait_for_task_to_task_address_updates(settings.start_timeout)

4.2 notify_and_register

可以看到 notify_and_register 的作用就是:

  • 呼叫 task_client.notify_initial_registration_complete() 通知 spark task 註冊成功了,這樣就讓所有等待 task.wait_for_initial_registration 的 spark executor 一起執行下一階段。
  • 呼叫 driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses) 來讓 Driver 完成註冊。
def wait_for_task_to_task_address_updates(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._task_addresses_for_tasks) < self._initial_np:
            self.check_for_spark_job_failure()
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
    finally:
        self._wait_cond.release()

4.3 wait_for_task_to_task_address_updates

這裡會再次確認所有 spark task 都OK。

def wait_for_task_to_task_address_updates(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._task_addresses_for_tasks) < self._initial_np:
            self.check_for_spark_job_failure()
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
    finally:
        self._wait_cond.release()

4.4 等待 In Task

在 Spark task 之中,如果收到了下一步啟動指示之後,會呼叫 wait_for_command_termination 進行等待。

其實,這一步也就意味著 spark exector 自己本身的邏輯任務結束了,因為以後都是 SparkTaskService 自己獨立完成的動作,它來負責訓練程式碼的啟動。既然 _task_fn 的邏輯任務已經結束,那麼靜靜地等待即可。

4.4.1 wait_for_command_termination

在 horovod-master/horovod/spark/task/task_service.py

def wait_for_command_termination(self):
    """
    Waits for command termination. Ensures this method takes at least
    self._minimum_command_lifetime_s seconds to return after command started.
    """
    try:
        return super(SparkTaskService, self).wait_for_command_termination()
    finally:
        # command terminated, make sure this method takes at least
        # self._minimum_command_lifetime_s seconds after command started
        # the client that started the command needs some time to connect again
        # to wait for the result (see horovod.spark.driver.rsh).
        if self._minimum_command_lifetime is not None:
            time.sleep(self._minimum_command_lifetime.remaining())

在 horovod-master/horovod/runner/common/service/task_service.py 中可以看到,就是等待訓練程式碼所在的 thread 結束。

def wait_for_command_termination(self):
    self._command_thread.join() # 馬上會說明

4.4.2 _command_thread

這裡對 _command_thread 略作說明。

在 SparkTaskService 處理 RunCommandRequest 時候,執行 Command 的 thread 就是被賦值為 _command_thread。

class BasicTaskService(network.BasicService):
    def _handle(self, req, client_address):
      
        if isinstance(req, RunCommandRequest): # 執行命令請求
            self._wait_cond.acquire()
            try:
                if self._command_thread is None:

                    if self._command_env:
                        env = self._command_env.copy()
                        self._add_envs(env, req.env)
                        req.env = env

                    self._command_abort = threading.Event()
                    self._command_stdout = Pipe() if req.capture_stdout else None
                    self._command_stderr = Pipe() if req.capture_stderr else None
                    # 配置各種引數資訊
                    args = (req.command, req.env, self._command_abort,
                            self._command_stdout, self._command_stderr,
                            self._index,
                            req.prefix_output_with_timestamp)
                    # 啟動一個新執行緒來執行命令
                    self._command_thread = in_thread(self._run_command, args)
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()
            return network.AckResponse()  

邏輯如下:

+-------------------------------------+             +----------------------------------------------------+
| Horovod Main thread                 |             | Spark Executor                                     |
|                                     |             |                     _task_fn                       |
|                                     |             |                        +                           |
|                                     |             |                        |                           |
|                                     |             |                        |                           |
|                                     |             |                        v                           |
| +-------------------------------+   |             |  +---------------------+------------------------+  |
| | SparkDriverService            |   |             |  | SparkTaskService                             |  |
| |                               |   |             |  |               +                              |  |
| |                               |   |  1 register |  |               |                              |  |
| |  self._all_task_addresses <----------------------------------------+                              |  |
| |                               |   |             |  |               |                              |  |
| |              +                |   |             |  |               |                              |  |
| |              |                |   |             |  |               |                              |  |
| |              | 3              |   |             |  |               |                              |  |
| |              |                |   |             |  |               | 2                            |  |
| |              v                |   |             |  |               |                              |  |
| |  self._wait_cond.notify_all() |   |             |  |               |                              |  |
| |              +                |   |             |  |               v                              |  |
| |              |                |   +             +  +     +---------+---------------------------+  |  |
| |              |            4   |  RegistrationComplete    |                                     |  |  |
| |              |  +-----------------+-------------+--+---> | task.wait_for_initial_registration  |  |  |
| |              |                |   |             |  |     |                                     |  |  |
| |              |                |   |             |  |     +---------+---------------------------+  |  |
| |              |                |   |             |  |               |                              |  |
| |              |                |   |             |  |               |                              |  |
| |              |                |   |             |  |               | 5                            |  |
| |              |                |   |             |  |               |                              |  |
| |              |                |   |             |  |               |                              |  |
| |              |                |   |             |  |               v                              |  |
| |              |                |   |             |  |        wait_for_command_termination          |  |
| |              |                | 6 |  RunCommand |  |               +                              |  |
| |              |                |   |             |  |               |                              |  |
| |              +----------------------------------------------->     | 7                            |  |
| |              |                |   |             |  |               v                              |  |
| |              v                |   |             |  |        self._command_thread.join()           |  |
| |                               |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| +-------------------------------+   |             |  +----------------------------------------------+  |
+-------------------------------------+             +----------------------------------------------------+

手機如下:

至此,第一階段完成,我們下一篇繼續,敬請期待。

0x05 總結

總體來說,Horovod on Spark 的總體邏輯分為以下階段:

  • 啟動 SparkDriverService 服務,利用 _make_spark_thread 啟動 Spark task,然後 horovod 會等待啟動結束;
  • 多執行緒在 spark executor 之中啟動 spark task,每個task之中執行一個 SparkTaskService,SparkTaskService 會向 hovorod 主程式中的 SparkDriverTask 進行註冊,並且等待下一步執行啟動的指令;
  • Horovod 收到所有 task 結束的資訊之後,通知各個 task,進入下一階段;
  • Horovod 呼叫 mpi_run (又利用到 mpirun_rsh.py)在每一個 spark executor 上啟動 orted,以啟動 MPI cluster;
  • orted 在每一個 executor 之上執行訓練程式碼;

本文介紹了前三個階段,即啟動階段。下文介紹後續兩個階段,敬請期待。

0xEE 個人資訊

★★★★★★關於生活和技術的思考★★★★★★

微信公眾賬號:羅西的思考

如果您想及時得到個人撰寫文章的訊息推送,或者想看看個人推薦的技術資料,敬請關注。

在這裡插入圖片描述

相關文章