[原始碼解析] 深度學習分散式訓練框架 horovod (10) --- run on spark

羅西的思考發表於2021-07-05

[原始碼解析] 深度學習分散式訓練框架 horovod (10) --- run on spark

0x00 摘要

Horovod 是Uber於2017年釋出的一個易於使用的高效能的分散式訓練框架,在業界得到了廣泛應用。

本系列將通過原始碼分析來帶領大家瞭解 Horovod。本文是系列第十篇,看看horovod 如何執行在 spark 之上。

Horovod on Spark 具體有兩種底層實現:MPI,GLOO。因為篇幅所限,本文介紹 MPI 實現,下一篇介紹GLOO實現。

本系列其他文章如下:

[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識

[原始碼解析] 深度學習分散式訓練框架 horovod (2) --- 從使用者角度切入

[原始碼解析] 深度學習分散式訓練框架 horovod (3) --- Horovodrun背後做了什麼

[原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver

[原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架

[原始碼解析] 深度學習分散式訓練框架 horovod (6) --- 後臺執行緒架構

[原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer

[原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark

[原始碼解析] 深度學習分散式訓練框架 horovod (9) --- 啟動 on spark

0x01 回顧

1.1 總體序列圖

接上文,我們首先要回顧下 Horovod on Spark 的總體序列圖,這樣腦子裡有一個全景,溫故而知新。

img

1.2 總體邏輯

總體來說,Horovod on Spark 的總體邏輯分為以下階段:

  • 啟動 SparkDriverService 服務,利用 _make_spark_thread 啟動 Spark task,然後 horovod 會等待啟動結束;
  • 多執行緒在 spark executor 之中啟動 spark task,每個task之中執行一個 SparkTaskService,SparkTaskService 會向 hovorod 主程式中的 SparkDriverTask 進行註冊,並且等待下一步執行啟動的指令;
  • Horovod 收到所有 task 結束的資訊之後,通知各個 task,進入下一階段;
  • Horovod 呼叫 mpi_run (又利用到 mpirun_rsh.py)在每一個 spark executor 上啟動 orted(這裡是通過 SparkTaskService 來啟動 orted),以啟動 MPI cluster;
  • orted 在每一個 executor 之上執行訓練程式碼;

前文已經分析了前面三個階段,本文繼續後面兩個階段的分析。

1.3 問題

結合上面的流程圖,這裡就有一個問題會令人疑惑。

Horovod 按說可以直接呼叫 mpirun 來在遠端啟動 orted(orted 就是 mpi 可執行程式。mpirun 是 orterun 的別名,而 ortedrun 會最終呼叫到 orted)。但是為什麼流程圖上不是直接呼叫,而是通過 mpirun_rsh.py,進而通過 SparkTaskService 來啟動 orted?

原因應該是:

  • 通常 MPI 會通過 SSH 來連線 hosts,但是這種方式無法在 Spark Executor 之中啟動 Python function。
  • Orted 需要執行在 Spark Executor 之中,但是 mpirun 在啟動時候,沒辦法知道 Spark Executor 的 IP : PORT 這個組合,所以沒法直接啟動。
  • 因此 MPI 使用RPC 來啟動使用者程式碼:
    • 通過 SparkDriverService 和 SparkTaskService 等互動才可以知道這個 IP : PORT 組合資訊。
    • 使用 horovod.spark.driver.mpirun_rsh 來連線每個 Executor,然後 "remote shell" 到這些 executors 之中。
    • 直接使用 SparkTaskService 來啟動 orted。

0x02 第四階段 : 啟動 Job

下面我們看看第四階段,就是如何執行 訓練 job。

2.1 _launch_job

_launch_job 很簡單:

  • 首先 driver.get_common_interfaces 獲取網路路由資訊;
  • 其次 呼叫 run_contoller 來啟動 job;
def _launch_job(use_mpi, use_gloo, settings, driver, env, stdout=None, stderr=None):
    nics = driver.get_common_interfaces()
    run_controller(use_gloo, lambda: gloo_run(settings, nics, driver, env, stdout, stderr),
                   use_mpi, lambda: mpi_run(settings, nics, driver, env, stdout, stderr),
                   False, lambda: None,
                   settings.verbose)

2.2 獲取路由資訊

get_common_interfaces 與普通模式下的 get_common_interfaces 不同。因為此時,Spark Executor 之中的 SparkTaskService 的資訊已經儲存在 Driver 之中,直接獲取即可。

def get_common_interfaces(self):
    if self._nics is not None:
        return self._nics

    nics = None
    if len(self._task_addresses_for_tasks) > 0:
        # in Elastic Horovod on Spark with auto-scaling
        # keys in task_addresses are in range(max_np or proc_num)
        # but not all keys may exist, so we don't do for index in range(proc_num)
        indices = list(self._task_addresses_for_tasks.keys())
        nics = set(self._task_addresses_for_tasks[indices[0]].keys()) # 直接獲取
        for index in indices[1:]:
            nics.intersection_update(self._task_addresses_for_tasks[index].keys())

    return nics

2.3 run_controller

就是依據配置和編譯情況來進行處理,選擇 gloo,js,還是 mpi。

def run_controller(use_gloo, gloo_run, use_mpi, mpi_run, use_jsrun, js_run, verbosity):
    if use_gloo:
        gloo_run()
    elif use_mpi:
        mpi_run()
    elif use_jsrun:
        js_run()
    else:
        if mpi_built(verbose=verbose):
            if lsf.LSFUtils.using_lsf() and is_jsrun_installed():
                js_run()
            else:
                mpi_run()
        elif gloo_built(verbose=verbose):
            gloo_run()

所以我們開始啟動 job,具體我們分為 MPI,GLOO兩種進行分析。

0x03 MPI 實驗

我們首先要做一些 MPI 相關實驗,其原因是因為:

  • MPI 的呼叫之中有些看起來很奇怪的行為,或者說是一些 trick。
  • 這些 trick 對於 "horovod on spark" 基於 MPI 的實現是很有幫助,但是對於理解程式碼卻是一個極大的干擾
  • 我們暫時沒有時間和精力去研究 MPI 的原始碼是如何實現的,因為已經超出了本文範疇。

所以我們只能針對某些奇怪的行為,對 MPI 的相關實現機制做一些假設和估計。然後通過一個簡單的實驗來驗證我們的假設。

3.1 問題點

我們執行的 mpi 命令格式如下,這個命令格式就是為了模擬Horovod的 MPI 命令:

mpirun --allow-run-as-root -n 4 --hostfile ./remote_hosts -mca plm_rsh_agent "python rsh.py" python user_function.py

問題點就是:

  • plm_rsh_agent "python rsh.py" 的作用是什麼?
  • rsh.py 之中,有哪些 trick?如何呼叫遠端 mpi 程式?
  • python user_function.py 是在 rsh.py 之後執行嗎?

3.2 名詞解釋

3.2.1 orterun & orted

最開始在看到這個命令時候,容易讓人很暈。因為程式碼中沒有任何提及。

其實,outed 就是 mpi 可執行程式。

mpirun 是 orterun 的別名,而 ortedrun 會最終呼叫到 orted

具體解釋如下,資訊來源為 http://cn.voidcc.com/question/p-wkloammx-bha.html:

mpirunmpiexec基本上是相同的 - 許多MPI實現中的程式啟動器的名稱。 MPI標準沒有提到如何啟動和控制等級,但它建議(儘管不要求),如果有任何型別的啟動器,它應該被命名為mpiexec。一些MPI實現以mpirun開始,然後採用mpiexec以實現相容性。其他實現則相反。最後,大多數實現都使用兩個名稱來提供它們的啟動器。在實踐中,mpirunmpiexec所做的事情應該沒有什麼不同。

不同的MPI實現有不同的啟動和控制過程的方法。 MPICH從一個名為MPD(多用途守護程式或其他)的基礎架構開始。然後切換到新的Hydra流程管理器。由於Hydra的功能與MPD不同,因此基於Hydra的mpiexec採用的命令列引數不同於基於MPD的命令列引數,並且使使用者可以明確選擇基於Hydra的命令列引數,因此它可用作mpiexec.hydra。舊的稱為mpiexec.mpd。可能有一個基於MPICH的MPI庫只提供Hydra啟動程式,然後mpiexecmpiexec.hydra將是相同的可執行檔案。英特爾MPI基於MPICH,其新版本使用Hydra程式管理器。

Open MPI建立在開放執行環境(ORTE)的基礎上,其自身的程式啟動器被稱為orterun。為了相容,orterun也符號連結為mpirunmpiexec

總結:

  • mpiexec.something是MPI程式啟動的給定實現的特定版本
  • mpiexecmpirun是通用名稱的符號連結到實際發射通常副本或
  • mpiexecmpirun應該這樣做
  • 某些實現命名他們的發射器mpiexec,有些人命名它mpirun,有人將其命名為兩者,當系統路徑中同時有多個MPI實現可用時,這通常是混淆的來源(例如,當從發行版安裝時)

3.2.2 mpi orterun 原始碼

mpi之中 orterun 對應的原始碼如下,最主要是呼叫了 orte_submit_job 提交 job。

int orterun(int argc, char *argv[])
{
    orte_submit_status_t launchst, completest;

    /* orte_submit_init() will also check if the user is running as
       root (and may issue a warning/exit). */
    if (ORTE_SUCCESS != orte_submit_init(argc, argv, NULL)) {
        exit(1);
    }

    /* setup to listen for commands sent specifically to me, even though I would probably
     * be the one sending them! Unfortunately, since I am a participating daemon,
     * there are times I need to send a command to "all daemons", and that means *I* have
     * to receive it too
     */
    orte_rml.recv_buffer_nb(ORTE_NAME_WILDCARD, ORTE_RML_TAG_DAEMON,
                            ORTE_RML_PERSISTENT, orte_daemon_recv, NULL);

    /* if the user just wants us to terminate a DVM, then do so */
    if (orte_cmd_options.terminate_dvm) {
        // 省略部分程式碼
    } else {
        /* spawn the job and its daemons */
        memset(&launchst, 0, sizeof(launchst));
        memset(&completest, 0, sizeof(completest));
        launchst.active = true;
        completest.active = true;
      
        // 在這裡進行提交 job
        if (ORTE_SUCCESS != orte_submit_job(argv, NULL,
                                            launched, &launchst,
                                            completed, &completest)) {
            ORTE_UPDATE_EXIT_STATUS(1);
            goto DONE;
        }
    }

    // wait for response and unpack the status, jobid
    // 省略部分程式碼
}

3.3 實驗設計

3.3.1 元件

有如下幾個元件,其作用分別如下:

  • host 檔案。作用是指定本次執行有哪些host,以及host之上執行幾個MPI程式。
  • rsh.py。作用是作為 rsh agent 來給遠端機器下達命令。
    • MPI 使用者也可以通過其他方式給遠端機器下發命令。
    • 使用者可以對每個主機使用遠端 shell(sshrsh)而無需登入主機。預設情況下,mpirun 使用 ssh
    • 如果 mpirun 使用 ssh 出現問題,可以嘗試在 mpirun 命令中使用 --mca plm_rsh_agent rsh 選項,以使用 rsh 命令進行連線。
  • user_function.py。就是使用者希望執行的函式。

3.3.2 host 檔案 remote_hosts

remote_hosts 檔案內容如下:

1.1.1.1:2
2.2.2.2:2

其意義是:

  • 1.1.1.1 這個 ip 執行 2 個 slot,即兩個 MPI 程式。
  • 2.2.2.2 這個 ip 執行 2 個 slot,即兩個 MPI 程式。

3.3.3 rsh.py

rsh.py 內容如下,作用就是列印 MPI 傳入的 command,然後在遠端host之上啟動的 MPI 程式中執行新命令:

import os
import sys
import subprocess

if __name__ == '__main__':
  command = " ".join(sys.argv[0:])
  print(command)
  new_command = " ".join(sys.argv[2:])
  print(new_command)
  subprocess.Popen(new_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

3.3.4 user_function.py

內容如下,就是為了測試,列印一條語句。

print('hello world')

3.4 實驗結果

我們在 1.1.1.1 之上執行 mpi 命令。

mpirun --allow-run-as-root -n 4 --hostfile ./remote_hosts -mca plm_rsh_agent "python rsh.py" python user_function.py

結果如下:

# 以下是 command 內容,就是 MPI 傳遞給 rsh.py 的內容,這裡居然有 plm_rsh_agent "python rsh.py" 
rsh.py 1.1.1.1 orted -mca ess "env" -mca ess_base_jobid "41114481152" -mca ess_base_vpid 1 -mca ess_base_num_procs "4" -mca ored_node_regex "ip-[2:1]-1-1-1,[2:1]2.2.2@0(2)" -mca orted_hnp_uri "41114481152.0,tcp://1.1.1.1:53405" -mca plm "rsh" --tree-spawn -mca orte_parent_uri "41114481152.0,tcp://1.1.1.1:53405"  -mca plm_rsh_agent "python rsh.py" -mca pmix "^s1,s2,cray,isolated"

# 以下是 new_command 內容,就是 在遠端host上執行 使用者程式碼 的方法,這裡居然有 plm_rsh_agent "python rsh.py" 
orted -mca ess "env" -mca ess_base_jobid "41114481152" -mca ess_base_vpid 1 -mca ess_base_num_procs "4" -mca ored_node_regex "ip-[2:1]-1-1-1,[2:1]2.2.2@0(2)" -mca orted_hnp_uri "41114481152.0,tcp://1.1.1.1:53405" -mca plm "rsh" --tree-spawn -mca orte_parent_uri "41114481152.0,tcp://1.1.1.1:53405"  -mca plm_rsh_agent "python rsh.py" -mca pmix "^s1,s2,cray,isolated"

# 以下是 user_function.py 的執行內容
hello world

因此我們知道

  • plm_rsh_agent "python rsh.py" 的作用是在遠端執行 MPI orted。
  • python user_function.py 是在 rsh 之後執行的,而且是在遠端的 orted 之中執行。
  • 在 rsh.py 執行過程中,其接受到的命令內容有些奇怪

3.5 執行過程

執行過程如下:

  1. mpirun 執行 mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py,此時在遠端會執行一個 MPI daemon,用來響應處理;
  2. mpirun 呼叫了 rsh.py;
  3. rsh.py 使用 subprocess(orted -mca plm_rsh_agent "python rsh.py") 在遠端啟動 orted(會與 daemon 溝通),執行使用者程式碼;

具體如下圖:

                                                         1.1.1.1        +          2.2.2.2
                                                                        |
                                                                        |
                                                                        |  1      +---------------+
mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py  +----------->  |  MPI deamon   |
                                                                        |         +-------+-------+
             +                                                          |                 |
             |                                                          |                 |
             | 2                                                        |                 |
             |                                                          |                 |  3
             |                                                          |                 |
             |  rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py" |                 |
             |                                                          |                 v
             |                                                          |         +-------+--------------------------+
             |                                                          |         | orted                            |
             |                                                          |         |                                  |
             v                                                          |         |                                  |
+------------+------------------------------------------------------+   |         |   +---------------------------+  |
| rsh.py                                                            |   |         |   | user_function.py          |  |
|                                                                   |   |         |   |                           |  |
|    rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py"        |   |         |   |                           |  |
|                                                                   |   |   3     |   |      print('hello world') |  |
|    subprocess(orted -mca plm_rsh_agent "python rsh.py") +-------------------->  |   |                           |  |
|                                                                   |   |         |   +---------------------------+  |
+-------------------------------------------------------------------+   +         +----------------------------------+

手機如下:

3.6 Trick 分析

我們發現有幾個奇怪的點:

  • mpirun 執行 mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py
  • mpirun 呼叫了 rsh.py,但是在 rsh.py 收到的 argv 中,居然也有 plm_rsh_agent "python rsh.py" 。按說這時候不應該有這個引數了,因為 rsh.py 已經呼叫了,就不應該再有這個引數
  • rsh.py 執行遠端 MPI,使用的是 orted -mca plm_rsh_agent "python rsh.py",這裡居然還有 plm_rsh_agent "python rsh.py" 這個引數。這時候也不應該,因為 orted 已經執行在遠端了,這時候也傳入一個用來遠端控制的 rsh agent 引數,太奇怪了

就是說plm_rsh_agent "python rsh.py" 這個引數居然被 MPI 傳遞到各個階段,無論是 rsh agent 或者 遠端 mpi

rsh agent 就是 trick。不知道 MPI 為什麼要把 plm_rsh_agent "python rsh.py" 在各個階段傳遞的意圖,可能是為了更好的控制。

因為沒有精力來分析 MPI 原始碼,所以初步判斷,遠端 MPI daemon 在執行 orted -mca plm_rsh_agent "python rsh.py"時候,會判斷是否已經是遠端,如果是遠端,就不再執行 rsh agent 了。

所以,我們在後面分析中,在 Spark task 之中 發現 類似 plm_rsh_agent "python rsh.py" ,就不用再疑惑了

0x04 MPI 實現

一般來說,Horovod on Spark 是以 MPI 模式來執行,所以我們重點看這裡。

4.1 mpi_run in spark

mpi_run 程式碼位於:horovod/spark/mpi_run.py,作用是:

  • 依據各種配置生成remote shell的agent;
  • 依據各種配置生成可執行命令;
  • 呼叫hr_mpi_run(horovod.runner.mpi_run 就是普通模式下的 mpi_run)執行命令;

比如得到 rsh_agent 大致如下:

("/usr/bin/python", "-m", "horovod.spark.driver.mpirun_rsh", "xxxxx", "yyy")

得到 command 大致如下:

("/usr/bin/python", "-m", "horovod.spark.task.mpirun_exec_fn", "xxxxx", "yyy")

具體程式碼如下:

from horovod.runner.mpi_run import mpi_run as hr_mpi_run

def mpi_run(settings, nics, driver, env, stdout=None, stderr=None):
    """
    Runs mpirun.

    :param settings: Settings for running MPI.
                     Note: settings.num_proc and settings.hosts must not be None.
    :param nics: Interfaces to include by MPI.
    :param driver: The Spark driver service that tasks are connected to.
    :param env: Environment dictionary to use for running MPI.  Can be None.
    :param stdout: Stdout of the mpi process.
                   Only used when settings.run_func_mode is True.
    :param stderr: Stderr of the mpi process.
                   Only used when settings.run_func_mode is True.
    """
    env = {} if env is None else copy.copy(env)  # copy env so we do not leak env modifications

    # Pass secret key through the environment variables.
    env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)
    # we don't want the key to be serialized along with settings from here on
    settings.key = None

    # 拼接出rsh_agent
    rsh_agent = (sys.executable,
                 '-m', 'horovod.spark.driver.mpirun_rsh',
                 codec.dumps_base64(driver.addresses()),
                 codec.dumps_base64(settings))
    settings.extra_mpi_args = ('{extra_mpi_args} -x NCCL_DEBUG=INFO -mca plm_rsh_agent "{rsh_agent}"'
                               .format(extra_mpi_args=settings.extra_mpi_args if settings.extra_mpi_args else '',
                                       rsh_agent=' '.join(rsh_agent)))
    # 拼接出command
    command = (sys.executable,
               '-m', 'horovod.spark.task.mpirun_exec_fn',
               codec.dumps_base64(driver.addresses()),
               codec.dumps_base64(settings))
    hr_mpi_run(settings, nics, env, command, stdout=stdout, stderr=stderr)

4.2 mpi_run in normal

上面程式碼最後是執行 hr_mpi_run,其實 hr_mpi_run 是 horovod.runner.mpi_run,就是普通模式下的 mpi_run。

horovod.runner.mpi_run 首先 就是依據各種配置以及引數來構建 mpirun 命令的所有引數,比如 ssh 的引數,mpi 引數,nccl 引數等等。

得到了 command 大致如下:

mpirun --allow-run-as-root --map-by slot -x SSH_CONNECITION -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" /usr/bin/python -m horovod.spark.task.mpurun_exec_fn xxxxx

具體程式碼如下:

def mpi_run(settings, nics, env, command, stdout=None, stderr=None):
    """
    Runs mpi_run.

    Args:
        settings: Settings for running MPI.
                  Note: settings.num_proc and settings.hosts must not be None.
        nics: Interfaces to include by MPI.
        env: Environment dictionary to use for running command.
        command: Command and arguments to run as a list of string.
        stdout: Stdout of the mpi process.
                Only used when settings.run_func_mode is True.
        stderr: Stderr of the mpi process.
                Only used when settings.run_func_mode is True.
    """

    # 獲取mpi相關配置
    mpi_impl_flags, impl_binding_args, mpi = _get_mpi_implementation_flags(settings.tcp_flag, env=env)
    impi = _IMPI_IMPL == mpi

    # 獲取ssh配置
    ssh_args = []
    if settings.ssh_port:
        ssh_args += [f'-p {settings.ssh_port}']
    if settings.ssh_identity_file:
        ssh_args += [f'-i {settings.ssh_identity_file}']

    mpi_ssh_args = ''
    if ssh_args:
        joined_ssh_args = ' '.join(ssh_args)
        mpi_ssh_args = f'-bootstrap=ssh -bootstrap-exec-args \"{joined_ssh_args}\"' if impi else f'-mca plm_rsh_args \"{joined_ssh_args}\"'

    # 網路卡相關資訊
    tcp_intf_arg = '-mca btl_tcp_if_include {nics}'.format(
        nics=','.join(nics)) if nics and not impi else ''
    nccl_socket_intf_arg = '-{opt} NCCL_SOCKET_IFNAME={nics}'.format(
        opt='genv' if impi else 'x',
        nics=','.join(nics)) if nics else ''

    # On large cluster runs (e.g. Summit), we need extra settings to work around OpenMPI issues
    host_names, host_to_slots = hosts.parse_hosts_and_slots(settings.hosts)
    if not impi and host_names and len(host_names) >= _LARGE_CLUSTER_THRESHOLD:
        mpi_impl_flags.append('-mca plm_rsh_no_tree_spawn true')
        mpi_impl_flags.append('-mca plm_rsh_num_concurrent {}'.format(len(host_names)))

    # if user does not specify any hosts, mpirun by default uses local host.
    # There is no need to specify localhost.
    hosts_arg = '-{opt} {hosts}'.format(opt='hosts' if impi else 'H',
                hosts=','.join(host_names) if host_names and impi else settings.hosts)

    ppn_arg = ' '
    if host_to_slots and impi:
        ppn = host_to_slots[host_names[0]]
        for h_name in host_names[1:]:
            if ppn != host_to_slots[h_name]:
                raise Exception('''Different slots in -hosts parameter are not supported in Intel(R) MPI.
                                 Use -machinefile <machine_file> for this purpose.''')
        ppn_arg = ' -ppn {} '.format(ppn)

    if settings.prefix_output_with_timestamp and not impi:
        mpi_impl_flags.append('--timestamp-output')

    binding_args = settings.binding_args if settings.binding_args and not impi else ' '.join(impl_binding_args)

    basic_args = '-l' if impi else '--allow-run-as-root --tag-output'

    output = []
    if settings.output_filename:
        output.append('-outfile-pattern' if impi else '--output-filename')
        output.append(settings.output_filename)

    env_list = '' if impi else ' '.join(
                    '-x %s' % key for key in sorted(env.keys()) if env_util.is_exportable(key))

    # Pass all the env variables to the mpirun command.
    mpirun_command = (
        'mpirun {basic_args} '
        '-np {num_proc}{ppn_arg}{hosts_arg} '
        '{binding_args} '
        '{mpi_args} '
        '{mpi_ssh_args} '
        '{tcp_intf_arg} '
        '{nccl_socket_intf_arg} '
        '{output_filename_arg} '
        '{env} {extra_mpi_args} {command}'  # expect a lot of environment variables
        .format(basic_args=basic_args,
                num_proc=settings.num_proc,
                ppn_arg=ppn_arg,
                hosts_arg=hosts_arg,
                binding_args=binding_args,
                mpi_args=' '.join(mpi_impl_flags),
                tcp_intf_arg=tcp_intf_arg,
                nccl_socket_intf_arg=nccl_socket_intf_arg,
                mpi_ssh_args=mpi_ssh_args,
                output_filename_arg=' '.join(output),
                env=env_list,
                extra_mpi_args=settings.extra_mpi_args if settings.extra_mpi_args else '',
                command=' '.join(quote(par) for par in command))
    )

    # we need the driver's PATH and PYTHONPATH in env to run mpirun,
    # env for mpirun is different to env encoded in mpirun_command
    for var in ['PATH', 'PYTHONPATH']:
        if var not in env and var in os.environ:
            # copy env so we do not leak env modifications
            env = copy.copy(env)
            # copy var over from os.environ
            env[var] = os.environ[var]

    # Execute the mpirun command.
    if settings.run_func_mode:
        exit_code = safe_shell_exec.execute(mpirun_command, env=env, stdout=stdout, stderr=stderr)
    else:
        os.execve('/bin/sh', ['/bin/sh', '-c', mpirun_command], env)

4.3 執行命令

目前得到的命令是:

mpirun --allow-run-as-root --map-by slot -x SSH_CONNECITION -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" /usr/bin/python -m horovod.spark.task.mpurun_exec_fn xxxxx

所以我們接著分析。

當 mpi_run 準備好命令之後,他呼叫 safe_shell_exec.execute 或者 bin/sh 執行命令。對於 safe_shell_exec.execute 來說,它需要執行的命令是:

mpirun --allow-run-as-root --map-by slot -x SSH_CONNECITION -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" /usr/bin/python -m horovod.spark.task.mpurun_exec_fn xxxxx

這樣,就是先呼叫 safe_shell_exec.execute 或者 bin/sh 執行 "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx",然後執行 horovod.spark.task.mpurun_exec_fn xxxxx。

4.3.1 mpi 引數

對於 mpirun 來說,引數 --mca pls_rsh_agent rsh 告訴節點間通訊用rsh。

這樣我們就知道 horovod.spark.driver.mpirun_rsh 就是在節點通訊時候,首先執行的指令碼。

就是說,當 mpirun 想在異地節點執行一個 程式(horovod.spark.task.mpurun_exec_fn) 時候,首先執行 horovod.spark.driver.mpirun_rsh 從而在異地節點上啟動一個 orted,其次在這個 異地 orted 之上執行 horovod.spark.task.mpurun_exec_fn

4.3.3 mpirun_rsh.py

所以,horovod.spark.driver.mpirun_rsh 會最先執行,我們需要首先看看,就是下圖中最下面部分

mpirun_rsh.py 的作用如其註釋所言,目的是被 MPI 呼叫以便連線到一個 host,並且執行指定的命令

命令通常是 orted ,用來建立 MPI cluster。orted 程式然後被用來啟動遠端程式(Horovod 使用者的 Python方法)。 orted 程式將執行在最低index的 task上,同一個host 的其他task將執行 no-op 並且等待 orted task 結束

Method run by MPI to connect to a host hash and execute the given command.

The command is usually `orted` to setup the MPI cluster. That `orted` process
is then used to spin-up the actual remote process, the Horovod user's Python method.
The `orted` process will run on the lowest task index and all other tasks with the
same host hash are expected to no-op (see `horovod.spark._task_fn`)
and wait for the first task to terminate.

但是實際上程式碼其實很簡單,就是直接呼叫了 rsh,所以我們還得接著看。

if len(sys.argv) < 5:
    print('Usage: %s <service addresses> <settings> <host hash> '
          '<command...>' % sys.argv[0])
    sys.exit(1)

addresses = codec.loads_base64(sys.argv[1])
key = codec.loads_base64(os.environ.get(secret.HOROVOD_SECRET_KEY))
settings = codec.loads_base64(sys.argv[2])
host_hash = sys.argv[3]
command = " ".join(sys.argv[4:])
env = {}  # orted does not need any env vars, the target training code gets env from mpirun

# Since tasks with the same host hash have shared memory,
# we will run only one orted process on the first task.
rsh(addresses, key, host_hash, command, env, 0, settings.verbose) # 直接呼叫

4.3.4 rsh

這裡才是上述邏輯的具體實現,所以rsh 的作用就是:

  • 與在 Spark Driver 上執行的 SparkDriverService 進行互動,從 SparkDriverService 獲取需要執行 task 的所需資訊;
  • 與 Spark Executor 中的 SparkTaskService 互動,執行 command;

具體到程式碼就是:

  • 利用 driver_client.task_host_hash_indices(host_hash) 從在 Spark Driver 上執行的 SparkDriverService 獲取某一個 host 上的所有 task;
  • 利用 task_indices[local_rank] 獲取到對應的 task;
  • 利用 driver_client.all_task_addresses(task_index) 獲取 task 的地址;
  • 利用 task_service.SparkTaskClient.run_command 來執行 command;

command 舉例如下,此時 command 已經被 mpirun 處理轉義

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

具體程式碼是:

def rsh(driver_addresses, key, host_hash, command, env, local_rank, verbose,
        stdout=None, stderr=None, prefix_output_with_timestamp=False,
        background=True, events=None):
    """
    Method to run a command remotely given a host hash, local rank and driver addresses.

    This method connects to the SparkDriverService running on the Spark driver,
    retrieves all information required to connect to the task with given local rank
    of that host hash and invoke the command there.

    The method returns immediately after launching the command if background is True (default).
    When background is set to False, this method waits for command termination and returns
    command's result. If there is an exception while waiting for the result (i.e. connection reset)
    it returns -1.

    :param driver_addresses: driver's addresses
    :param key: used for encryption of parameters passed across the hosts
    :param host_hash: host hash to connect to
    :param command: command and arguments to invoke
    :param env: environment to use
    :param local_rank: local rank on the host of task to run the command in
    :param verbose: verbosity level
    :param stdout: Task stdout is redirected to this stream.
    :param stderr: Task stderr is redirected to this stream.
    :param prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver if True
    :param background: run command in background if True, returns command result otherwise
    :param events: events to abort the command, only if background is True
    :return exit code if background is False
    """
    driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
    task_client.stream_command_output(stdout, stderr)
    task_client.run_command(command, env,
                            capture_stdout=stdout is not None,
                            capture_stderr=stderr is not None,
                            prefix_output_with_timestamp=prefix_output_with_timestamp)

    if not background:
        events = events or []
        stop = threading.Event()
        for event in events:
            on_event(event, task_client.abort_command, stop=stop)

        try:
            exit_code = task_client.wait_for_command_exit_code()
            return exit_code
        except:
            traceback.print_exc()
            return -1
        finally:
            stop.set()

4.3.5 傳送命令

具體 run_command 就是向 SparkTaskService 傳送 RunCommandRequest。

class BasicTaskClient(network.BasicClient):

  def run_command(self, command, env,
                    capture_stdout=False, capture_stderr=False,
                    prefix_output_with_timestamp=False):
        self._send(RunCommandRequest(command, env,
                                     capture_stdout, capture_stderr,
                                     prefix_output_with_timestamp))

具體如下圖邏輯所示:

與之前的測試程式碼對比如下:

                                                   Our test code     +    Horovod on spark
                                                                     |
                                                                     |
 mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py   |    mpirun pls_rsh_agent "python mpirun_rsh" python -m mpurun_exec_fn
                                                                     |
          +                                                          |           +
          |                                                          |           |
          |  rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py" |           |    orted -mca pls_rsh_agent "python -m mpirun_rsh"
          |                                                          |           |
          v                                                                      v
+----------------------------------------------------------------+   |    +------+---------------------------------------------------+
| rsh.py (via SSH)                                               |   |    | mpirun_rsh                                               |
|                                                                |   |    |                                                          |
|    rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py"     |   |    +------+---------------------------------------------------+
|                                                                |   |           |
|                                                                |   |           |
|                                                                |   |           v
|                                                                |   |    +----------------------------------------------------------+
|                                                                |   |    | rsh (via RPC)                                            |
|                                                                |   |    |                                                          |
|    subprocess(orted -mca plm_rsh_agent "python rsh.py")        |   |    |                                                          |
|                                                                |   |    |  task_client = task_service.SparkTaskClient              |
|                                                                |   |    |                                                          |
|                                                                |   |    |  task_client.run_command(                                |
|                                                                |   |    |       orted -mca pls_rsh_agent "python -m mpirun_rsh"    |
|                                                                |   |    |  )                                                       |
+---------+------------------------------------------------------+   |    +------+---------------------------------------------------+
          |                                                          |           |
          |                                                          |           |
          v                                                          |           v
+---------+------------------------------------------------------+   |    +------+---------------------------------------------------+
| user_function.py                                               |   |    | mpirun_exec_fn.py                                        |
|                                                                |   |    |                                                          |
|    print('hello world')                                        |   |    |              task_exec +--------> user_function          |
|                                                                |   |    |                                                          |
+----------------------------------------------------------------+   |    +----------------------------------------------------------+
                                                                     +

手機如下:

因此,下面就會進入到 spark executor 去執行

4.4 Run in Spark Executor

再次注意,這裡已經是 遠端的 Spark Executor 了

上節提到,系統會利用 task_service.SparkTaskClient.run_command 來執行command;

command 舉例如下,此時 command 已經被 mpirun 處理轉義:

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

需要依據上圖留意一點:系統在 Spark Executor 上執行 command 之後,會接著執行 mpirun_exec_fn

我們接下來就看看如何處理 RunCommandRequest。具體都是在 BasicTaskService 之中完成。

4.4.1 RunCommandRequest

可以看到,接受到訊息之後,是呼叫 _run_command 完成。

def _handle(self, req, client_address):
    if isinstance(req, RunCommandRequest):
        self._wait_cond.acquire()
        try:
            if self._command_thread is None:
                # we add req.env to _command_env and make this available to the executed command
                if self._command_env:
                    env = self._command_env.copy()
                    self._add_envs(env, req.env)
                    req.env = env

                # We only permit executing exactly one command, so this is idempotent.
                self._command_abort = threading.Event()
                self._command_stdout = Pipe() if req.capture_stdout else None
                self._command_stderr = Pipe() if req.capture_stderr else None
                args = (req.command, req.env, self._command_abort,
                        self._command_stdout, self._command_stderr,
                        self._index,
                        req.prefix_output_with_timestamp)
                self._command_thread = in_thread(self._run_command, args)
        finally:
            self._wait_cond.notify_all()
            self._wait_cond.release()
        return network.AckResponse()

4.4.2 _run_command

_run_command 就是呼叫 safe_shell_exec.execute 直接執行。

def _run_command(self, command, env, event,
                 stdout=None, stderr=None, index=None,
                 prefix_output_with_timestamp=False):
    self._command_exit_code = safe_shell_exec.execute(
        command,
        env=env,
        stdout=stdout, stderr=stderr,
        index=index,
        prefix_output_with_timestamp=prefix_output_with_timestamp,
        events=[event])
    if stdout:
        stdout.close()
    if stderr:
        stderr.close()

因此,接下來就是在 Spark Executor 之中,開始執行

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

注意,此時是在 Spark Executor 之中,所以接下來行為會和之前不同。

4.4.3 mpirun_rsh

mpirun_rsh 依然是呼叫 rsh。

addresses = codec.loads_base64(sys.argv[1])
key = codec.loads_base64(os.environ.get(secret.HOROVOD_SECRET_KEY))
settings = codec.loads_base64(sys.argv[2])
host_hash = sys.argv[3]
command = " ".join(sys.argv[4:])
env = {}  # orted does not need any env vars, the target training code gets env from mpirun

# Since tasks with the same host hash have shared memory,
# we will run only one orted process on the first task.
rsh(addresses, key, host_hash, command, env, 0, settings.verbose)

4.4.4 rsh

程式碼如下:

def rsh(driver_addresses, key, host_hash, command, env, local_rank, verbose,
        stdout=None, stderr=None, prefix_output_with_timestamp=False,
        background=True, events=None):
    driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
    task_client.stream_command_output(stdout, stderr)
    task_client.run_command(command, env,
                            capture_stdout=stdout is not None,
                            capture_stderr=stderr is not None,
                            prefix_output_with_timestamp=prefix_output_with_timestamp)

但是此時執行就出現了與之前的不同之處

此時在 Spark Executor 再次呼叫

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

回憶一下 0x03 MPI 實驗 的結果,我們知道,pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" 這部分在遠端上其實不會有實際效果遠端 orted 轉而會繼續執行傳過來的 mpirun_exec_fn

如果哪位朋友對 MPI 有深入瞭解,還望賜教

4.4.5 mpirun_exec_fn

程式碼位於:horovod/spark/task/mpirun_exec_fn.py。

就是呼叫到了task_exec。

def main(driver_addresses, settings):
    # prepend HOROVOD_SPARK_PYTHONPATH to PYTHONPATH
    if 'HOROVOD_SPARK_PYTHONPATH' in os.environ:
        ppath = os.environ['HOROVOD_SPARK_PYTHONPATH']

        # add injected HOROVOD_SPARK_PYTHONPATH to sys.path
        for p in reversed(ppath.split(os.pathsep)):
            sys.path.insert(1, p)  # don't put it in front which is usually .

        if 'PYTHONPATH' in os.environ:
            ppath = os.pathsep.join([ppath, os.environ['PYTHONPATH']])
        os.environ['PYTHONPATH'] = ppath

    # change current working dir to where the Spark worker runs
    # because orted runs this script where mpirun was executed
    # this env var is injected by the Spark task service
    work_dir = os.environ.get('HOROVOD_SPARK_WORK_DIR')
    if work_dir:
        os.chdir(work_dir)

    task_exec(driver_addresses, settings, 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_LOCAL_RANK')

0x05 第五階段 : 執行使用者程式碼

5.1 task_exec

task_exec 就是用來執行使用者程式碼。

可以看到,是從 Driver 之中取出之前儲存的程式碼,然後執行。

def task_exec(driver_addresses, settings, rank_env, local_rank_env):
    # Die if parent process terminates
    in_thread(target=_parent_process_monitor, args=(os.getppid(),))

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ[rank_env])
    local_rank = int(os.environ[local_rank_env])
    driver_client = driver_service.SparkDriverClient(driver_addresses, key,
                                                     verbose=settings.verbose)

    # tell driver about local rank and rank
    # in elastic mode the driver already knows this mapping
    # for simplicity we keep code paths the same for elastic and static mode
    host_hash = os.environ['HOROVOD_HOSTNAME']
    task_index = driver_client.set_local_rank_to_rank(host_hash, local_rank, rank)

    # gather available resources from task service
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key,
                                               verbose=settings.verbose)
    task_info.set_resources(task_client.resources())

    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)

5.2 獲取訓練程式碼

在MapReduce之中,則是把Jar包(二進位制庫)分發到各個節點,然後各個節點執行jar包之中相應的程式碼。其實這樣很不方便。

Spark提出了函式序列化功能,可以很好的解決這個問題,這是Spark對分散式程式設計的一個貢獻。Spark系統會把你寫的那些自定義函式(你的業務功能)自動序列化到各個節點去執行。函式序列化傳送功能給Spark帶來的另外好處是:使用者可以使用spark-shell在命令列直接寫分散式程式碼,實時操作,實時得到結果。

比如,初始化/協調等工作是在Driver程式中進行,但是程式碼實際執行是在Worker節點中的Executor中進行。當Executor端執行時需要用到Driver端封裝的class物件時,Driver端就需要把Driver端的class物件通過序列化傳輸到Executor端,這個class物件則需要實現Serializable序列化方法。

Horovod on spark 這裡就是直接傳輸 python 訓練程式碼原始文字,因為 pyton 的指令碼特性,所以可以直接執行程式碼原始文字

獲取訓練程式碼的函式如下,在 SparkDriverClient 類之中就是給 Driver 傳送 CodeRequest 請求:

def code(self):
    resp = self._send(CodeRequest())
    return resp.fn, resp.args, resp.kwargs

在 SparkDriverService 之中,收到 CodeRequest 請求之後,會進行處理。

if isinstance(req, CodeRequest):
    return CodeResponse(self._fn, self._args, self._kwargs)

就是把 SparkDriverService 之前儲存的訓練程式碼 _fn 以及其引數一起發給 SparkTaskService。

class CodeResponse(object):
    def __init__(self, fn, args, kwargs):
        self.fn = fn
        """Function."""

        self.args = args
        """Function args."""

        self.kwargs = kwargs
        """Function kwargs."""

最終邏輯大致如下:

+---------------------------------+                     +---------------------------------+
| Horovod Main thread             |                     | Spark Executor                  |
|                                 |                     |                                 |
|                                 |                     |                                 |
|  +-------------------------+    |       1 register    |        +----------------------+ |
|  |     SparkDriverService  | <---------------------------------+  SparkTaskService    | |
|  |                         |    |                     |        |                      | |
|  |                         |    |      2 notify start |        |                      | |
|  |                         | +-------------------------------> |                      | |
|  |                         |    |                     |        |                      | |
|  |                         |    |                     |        |                      | |
|  |                         |    | 3 RunCommandRequest |        |                      | |
|  |                         | +--------------------------------------> orted mpirun_rsh| |
|  |                         |    |                     |        |        +             | |
|  |                         |    |                     |        |        | 4           | |
|  |                         |    |                     |        |        |             | |
|  |                         |    |                     |        |        v             | |
|  |                         |    |                     |        |      task_exec       | |
|  |                         |    |                     |        |        +             | |
|  |                         |    |                     |        |        | 5           | |
|  |                         |    |                     +        |        |             | |
|  |                         |    |6 set_local_rank_to_rank      |        v             | |
|  |                         | +------------------------+---------> SparkTaskClient     | |
|  |                         |    |                     |        |                      | |
|  |                         |    |    7 code()         |        |                      | |
|  |                         | +---------------------------------------> 8 fn()         | |
|  |                         |    |                     |        |                      | |
|  +-------------------------+    |                     |        +----------------------+ |
+---------------------------------+                     +---------------------------------+

手機如下:

至此,spark on MPI 分析結束,我們下文介紹 spark on GLOO。

0xEE 個人資訊

★★★★★★關於生活和技術的思考★★★★★★

微信公眾賬號:羅西的思考

如果您想及時得到個人撰寫文章的訊息推送,或者想看看個人推薦的技術資料,敬請關注。

在這裡插入圖片描述

0xFF

mpirun,mpiexec和mpiexec.hydra有什麼區別和關係?

相關文章