[原始碼解析] 深度學習分散式訓練框架 horovod (9) --- 啟動 on spark
0x00 摘要
Horovod 是Uber於2017年釋出的一個易於使用的高效能的分散式訓練框架,在業界得到了廣泛應用。
本系列將通過原始碼分析來帶領大家瞭解 Horovod。這幾篇介紹 horovod 如何執行在 spark 之上。本文是第九篇,介紹 horovod on spark 如何啟動。
本系列其他文章如下:
[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識
[原始碼解析] 深度學習分散式訓練框架 horovod (2) --- 從使用者角度切入
[原始碼解析] 深度學習分散式訓練框架 horovod (3) --- Horovodrun背後做了什麼
[原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver
[原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架
[原始碼解析] 深度學習分散式訓練框架 horovod (6) --- 後臺執行緒架構
[原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer
[原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark
0x01 總體架構圖
首先,我們還是要祭出架構圖,這樣大家可以按圖索驥。
總體來說,Horovod on Spark 的總體邏輯分為以下階段:
- 啟動 SparkDriverService 服務,利用 _make_spark_thread 啟動 Spark task,然後 horovod 會等待啟動結束;
- 多執行緒在 spark executor 之中啟動 spark task,每個task之中執行一個 SparkTaskService,SparkTaskService 會向 hovorod 主程式中的 SparkDriverTask 進行註冊,並且等待下一步執行啟動的指令;
- Horovod 收到所有 task 結束的資訊之後,通知各個 task,進入下一階段;
- Horovod 呼叫 mpi_run (又利用到 mpirun_rsh.py)在每一個 spark executor 上啟動 orted 程式,以啟動 MPI cluster;
- orted 在每一個 executor 之上執行訓練程式碼;
我們下面就具體看看如何啟動。
0x02 第一階段 :Horovod 啟動
本部分主要邏輯是:啟動 SparkDriverService 服務,利用 _make_spark_thread 啟動 Spark task,然後 horovod 會等待啟動結束。
2.1 Driver服務 :SparkDriverService
SparkDriverService 繼承了 driver_service.BasicDriverService,所以其內部啟動了一個 socket server,可以進行網路互動。
Horovod 利用 SparkDriverService 來和 Spark executor(通過其中執行的SparkTaskService)互動,比如收集資訊,讓 spark 啟動訓練job等等。這是一個 RPC 機制。
具體 SparkDriverService 的功能可以參見其內部處理的各種 Request,比如
- CodeRequest :SparkTaskService會用來請求使用者程式碼;
- TaskHostHashIndicesRequest :獲取 task host 地址;
- TaskIndexByRankRequest :從 rank 獲取到 task index;
- SetLocalRankToRankRequest :從 local rank 得到 rank 資訊;
- WaitForTaskShutdownRequest :等待 shutdown;
和前文介紹的 HorovodRunDriverService 有些類似。
其中,其成員變數 _fn 就是訓練函式,以後當 SparkTaskService 請求程式碼的時候,就通過 CodeResponse 把 _fn 直接傳送回去。這樣就解決了程式碼釋出問題。
class SparkDriverService(driver_service.BasicDriverService):
NAME = 'driver service'
def __init__(self, initial_np, num_proc, fn, args, kwargs, key, nics):
super(SparkDriverService, self).__init__(num_proc,
SparkDriverService.NAME,
key, nics)
self._initial_np = initial_np
self._fn = fn # 儲存使用者程式碼
self._args = args # 使用者引數
self._kwargs = kwargs
self._key = key
self._nics = nics # 網路卡資訊
self._ranks_to_indices = {}
self._spark_job_failed = False
self._lock = threading.Lock()
self._task_shutdown = threading.Event()
def _handle(self, req, client_address):
if isinstance(req, TaskHostHashIndicesRequest): # 獲取 task host 地址
return TaskHostHashIndicesResponse(self._task_host_hash_indices[req.host_hash])
if isinstance(req, SetLocalRankToRankRequest): # 從 local rank 得到 rank 資訊
self._lock.acquire()
try:
# get index for host and local_rank
indices = self._task_host_hash_indices[req.host]
index = indices[req.local_rank]
values = list(self._ranks_to_indices.values())
prev_pos = values.index(index) if index in values else None
if prev_pos is not None:
prev_rank = list(self._ranks_to_indices.keys())[prev_pos]
del self._ranks_to_indices[prev_rank]
# memorize rank's index
self._ranks_to_indices[req.rank] = index
finally:
self._lock.release()
return SetLocalRankToRankResponse(index)
if isinstance(req, TaskIndexByRankRequest): # 是從 rank 獲取到 task index
self._lock.acquire()
try:
return TaskIndexByRankResponse(self._ranks_to_indices[req.rank])
finally:
self._lock.release()
if isinstance(req, CodeRequest): # SparkTaskService會用來請求使用者程式碼
return CodeResponse(self._fn, self._args, self._kwargs)
if isinstance(req, WaitForTaskShutdownRequest): # 等待任務結束
self._task_shutdown.wait()
return network.AckResponse()
return super(SparkDriverService, self)._handle(req, client_address)
2.2 啟動spark task : _make_spark_thread
在 Horovod.spark.run 之中,_make_spark_thread 建立了 thread。這裡關鍵程式碼是:
mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
result = procs.mapPartitionsWithIndex(mapper).collect()
mapPartitionsWithIndex 這句程式碼會促使 Spark 在多個 Executor 之中執行 mapper 函式,並且得到執行結果。
即建立 settings.num_proc
個 Spark tasks,每個 task 會執行 mapper(_task_fn), 外部的 run 函式會等待這些執行結果。其實如果需要使用RDD,也許可以使用 foreachPartition
,這樣每個結點上將會在記憶體中持有RDD的一個分割槽。
def _make_spark_thread(spark_context, spark_job_group, driver, result_queue,
settings, use_gloo, is_elastic):
"""Creates `settings.num_proc` Spark tasks in a parallel thread."""
def run_spark():
"""Creates `settings.num_proc` Spark tasks, each executing `_task_fn` and waits for them to terminate."""
try:
spark_context.setJobGroup(spark_job_group, "Horovod Spark Run", interruptOnCancel=True)
procs = spark_context.range(0, numSlices=settings.max_np if settings.elastic else settings.num_proc)
# We assume that folks caring about security will enable Spark RPC encryption,
# thus ensuring that key that is passed here remains secret.
mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
# 促使 Spark 在多個 Executor 之中執行 mapper 函式,並且得到執行結果
result = procs.mapPartitionsWithIndex(mapper).collect()
result_queue.put(result)
except:
driver.notify_spark_job_failed()
raise
spark_thread = in_thread(target=run_spark, daemon=False)
return spark_thread
2.3 等待 spark task 啟動結束
啟動了 spark task 之後,horovod 主程式會呼叫如下來等待 task 全部 啟動完成。
# wait for all tasks to register, notify them and initiate task-to-task address registration
_notify_and_register_task_addresses(driver, settings)
即,run 函式中,當 _make_spark_thread 之後,horovod 主程式呼叫 _notify_and_register_task_addresses,從而呼叫 driver.wait_for_initial_registration(settings.start_timeout) ,進行總體等待。
等待的內容是:等待所有 num_proc tasks 來註冊。當所有 spark thread 都ready 之後,主 horovod 程式會繼續執行。
2.3.1 _notify_and_register_task_addresses
在 horovod 主程式之中,會使用 _notify_and_register_task_addresses
來等待這些 spark task 來註冊,從而呼叫 driver.wait_for_initial_registration(settings.start_timeout) ,進行總體等待。
注意,同時傳送註冊請求之後, spark task 自己也呼叫 task.wait_for_initial_registration 等待 horovod 再通知下一階段的啟動。
而在horovod 主程式的 _notify_and_register_task_addresses 其實也很複雜:
- 呼叫 driver.wait_for_initial_registration 等待task來註冊,需要等待 num_proc 個task;
- 利用 notify_and_register 註冊task,並且通知各個 task 開始下一步;
具體程式碼如下:
def _notify_and_register_task_addresses(driver, settings, notify=True):
# wait for num_proc tasks to register
# 等待task來註冊,需要等待 num_proc 個task
driver.wait_for_initial_registration(settings.start_timeout)
def notify_and_register(index): # 註冊task,並且通知各個 task 開始下一步
task_client = task_service.SparkTaskClient(index,
driver.task_addresses_for_driver(index),
settings.key, settings.verbose)
if notify:
task_client.notify_initial_registration_complete()
next_task_index = (index + 1) % settings.num_proc
next_task_addresses = driver.all_task_addresses(next_task_index)
task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)
for index in driver.task_indices():
in_thread(notify_and_register, (index,)) #在thread之中啟動task
driver.wait_for_task_to_task_address_updates(settings.start_timeout)
我們目前只能看其第一步 “等待註冊”。
2.3.2 driver.wait_for_initial_registration
在這裡 SparkDriverSerivce 首先等待所有 spark executor 註冊。
在 class BasicDriverService(network.BasicService): 有如下程式碼,可以看到,只有全部 _num_proc 註冊完成,當所有 spark thread 都ready 之後,主 horovod 程式會繼續執行。
這裡關鍵是:while len(self._all_task_addresses) < self._num_proc
就是等待 self._all_task_addresses 的數目達到 _num_proc。
class BasicDriverService(network.BasicService):
def wait_for_initial_registration(self, timeout):
self._wait_cond.acquire()
try:
# 等待 self._all_task_addresses 的數目達到 _num_proc
while len(self._all_task_addresses) < self._num_proc:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('tasks to start')
finally:
self._wait_cond.release()
2.4 等待
關於等待程式碼,我們要做一下特殊說明,具體看圖。
這裡有兩套 wait_for_initial_registration。可以認為是兩套 barrier。
就是:
- barrier 1 :SparkDriverSerivce 等待所有 SparkTaskSerivce ready;
- barrier 2 :所有 SparkTaskSerivce 需要一起執行,所以 SparkTaskSerivce們 都在等待 barrier 2。SparkDriverSerivce 會通知 這些 SparkTaskSerivce 一起發動;
2.3.1 Barrier 1 in Driver
在 run 函式中,當 _make_spark_thread 之後,horovod 主程式呼叫 _notify_and_register_task_addresses,從而呼叫 driver.wait_for_initial_registration(settings.start_timeout) ,進行總體等待。
等待的內容是:等待所有 num_proc tasks 來註冊。當所有 spark thread 都ready 之後,主 horovod 程式會繼續執行。這裡關鍵是:
while len(self._all_task_addresses) < self._num_proc
就是等待 self._all_task_addresses 的數目達到 _num_proc。
def wait_for_initial_registration(self, timeout):
self._wait_cond.acquire()
try:
while len(self._all_task_addresses) < self._num_proc:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('tasks to start')
finally:
self._wait_cond.release()
在 BasicDriverService 之中,如果收到了 spark executor 的註冊請求就進行處理,這裡最重要是:
self._all_task_addresses[req.index] = req.task_addresses
當所有的 spark executor 都註冊了,這裡就等待成功。
2.3.2 Barrier 2 in task
每個 spark thread 在 _task_fn 之中執行,就是在 spark task 之中執行。這裡也可以看出來是 Spark task 的一個總體流程:
- 首先 呼叫
register_task
; - 其次 呼叫
task.wait_for_initial_registration(settings.start_timeout)
; - 然後 呼叫
wait_for_command_termination
來等待結束;
task.wait_for_initial_registration 會等待 self._initial_registration_complete = True 這個條件,就是等待 register_task
註冊完成。
每個 Spark Executor 都有一個 SparkTaskService,所以 每個spark task 都有自己的 _initial_registration_complete。
hovorod.run 主程式會逐一通知每個 SparkTaskService 的 _initial_registration_complete。
即,哪個 SparkTaskService 好了,就通知哪個 SparkTaskService 的 _initial_registration_complete。這樣,這個 SparkTaskService 就可以正式執行了。
2.3.3 總體等待流程
總體等待流程具體如圖,數字就是執行順序:
- SparkDriverSerivce 呼叫 driver.wait_for_initial_registration 來等待 SparkTaskSerivce 的註冊,這是 barrier 1;
- SparkTaskSerivce 1 進行註冊,然後 SparkTaskSerivce 1 自己也呼叫 task.wait_for_initial_registration 等待 horovod 再通知下一階段的啟動,這是 barrier 2;
- SparkTaskSerivce 2 進行註冊,然後 SparkTaskSerivce 2 自己也呼叫 task.wait_for_initial_registration 等待 horovod 再通知下一階段的啟動,這是 barrier 2;
- hovorod.run 主程式在發現所有 task 都註冊之後,barrier 1 等待結束,會逐一通知每個 SparkTaskService 的 _initial_registration_complete。只有 4 完成之後,兩個 SparkTaskSerivce 才能繼續執行 5,6;
- SparkTaskSerivce 1 對於 barrier 2 等待結束,繼續執行;
- SparkTaskSerivce 2 對於 barrier 2 等待結束,繼續執行;
SparkTaskSerivce 1 SparkTaskSerivce 2 SparkDriverSerivce
+ + +
| | |
| | |
| | |
| | | 1
| | |
| | |
| | v
| |
| | +--------------------------------------+
| | | barrier 1 |
| | 2 | |
| 3 +-------> | |
| | | |
+-----------------------------------> | driver.wait_for_initial_registration |
| | | |
| | | |
| | | |
| | +--------------------+-----------------+
| | |
| | |
+-----------+----------------------+ | 4 |
|barrier 2 | <---------------------------------+
| | | |
|task.wait_for_initial_registration| | |
| | | |
+-----------+----------------------+ | |
| | |
| +-------------+----------------------+ |
| | barrier 2 | 4 |
| 6 | +<------+
| | task.wait_for_initial_registration | |
| | | |
| +-------------+----------------------+ |
| | |
| | |
| | 5 |
| | |
v v v
我們接下來詳細介紹 task 啟動內容 和 driver 後續工作。
0x03 第二階段 :Spark Task 啟動
本階段我們詳細介紹下 Spark Task 的啟動過程。
這部分主要功能是:多執行緒在 spark executor 之中啟動 spark task,每個spark task會執行_task_fn
函式,_task_fn
函式會執行一個 SparkTaskService,SparkTaskSerivce 會向 hovorod 主程式中的 SparkDriverTask 進行註冊,並且等待下一步執行啟動的指令;
此時程式(不是訓練程式,而是 SparkTaskService)已經在 Spark Executor內部執行了。我們看看在 spark Executor 之中,是如何啟動執行 SparkTaskService 的。
3.1 具體spark啟動邏輯 :_task_fn
Horovod 在 thread 裡面通過 _make_mapper 來讓 Spark 執行 _task_fn。
def _make_mapper(driver_addresses, settings, use_gloo, is_elastic):
def _mapper(index, _):
yield _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic)
return _mapper
_task_fn
的作用是為了註冊 horovod 進入到 spark task。即,在每一個 spark task (executor) 之中啟動一個 SparkTaskService。
一定要注意:這些 SparkTaskService 是執行在 spark executor 之中,通過網路與 horovod 之中的 SparkDriverService 互動。
可以看到,_task_fn 的總體邏輯是:
- 啟動 SparkTaskService;
- 通過 driver_service.SparkDriverClient.register_task 來向 horovod 中的 Driver 註冊;
- 通過 task.wait_for_initial_registration(settings.start_timeout) 來等待下一步啟動的開始指示;
- 如果下一步開始啟動了,則呼叫 task.wait_for_command_termination() 等待結束;
具體如下:
def _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic):
settings.key = key
hosthash = host_hash(salt='{}-{}'.format(index, time.time()) if is_elastic else None)
os.environ['HOROVOD_HOSTNAME'] = hosthash
# 啟動 SparkTaskService,SparkTaskService本身包括一個socket server,可以和driver互動
task = task_service.SparkTaskService(index, settings.key, settings.nics,...)
try:
driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
# 向 horovod 中的 Driver 註冊
driver_client.register_task(index, task.addresses(), hosthash)
# 這裡依然執行在spark task之中,但因為不是SparkTaskService,所以只是做協助工作,最後靜靜等待
if not is_elastic:
# 等待下一步啟動的開始指示
task.wait_for_initial_registration(settings.start_timeout)
task_indices_on_this_host = driver_client.task_host_hash_indices(hosthash)
local_rank_zero_index = task_indices_on_this_host[0]
else:
local_rank_zero_index = None
if is_elastic:
...... # 後續文章會介紹
elif use_gloo or index == local_rank_zero_index:
# Either Gloo or first task with MPI.
# 使用Gloo或者使用MPI的第一個task,讓這個task做操作
task.wait_for_command_start(settings.start_timeout)
# 等待結束
task.wait_for_command_termination()
else:
# The other tasks with MPI need to wait for the first task to finish.
# 讓其他的task等待第一個task結束
first_task_addresses = driver_client.all_task_addresses(local_rank_zero_index)
first_task_client = \
task_service.SparkTaskClient(local_rank_zero_index,
first_task_addresses, settings.key,
settings.verbose)
# 呼叫 task.wait_for_command_termination() 等待結束
first_task_client.wait_for_command_termination()
return task.fn_result()
finally:
task.shutdown()
3.2 SparkTaskService
再次強調如下程式碼:
task = task_service.SparkTaskService(index, settings.key, settings.nics,...)
每一個_task_fn 中都定義了一個 SparkTaskService,即每一個 Spark Executor 都會生成一個(或者多個) SparkTaskService,在 spark task 之中執行並且作用。
3.2.1 SparkTaskService 定義
SparkTaskService 定義如下,因為繼承了BasicTaskService,所以其內部最終也會啟動一個 socket server,以便同 horovod 中的 SparkDriverService 互動:
class SparkTaskService(task_service.BasicTaskService):
NAME_FORMAT = 'task service #%d'
def __init__(self, index, key, nics, minimum_command_lifetime_s, verbose=0):
# on a Spark cluster we need our train function to see the Spark worker environment
# this includes PYTHONPATH, HADOOP_TOKEN_FILE_LOCATION and _HOROVOD_SECRET_KEY
env = os.environ.copy()
# we inject the secret key here
env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)
# we also need to provide the current working dir to mpirun_exec_fn.py
env['HOROVOD_SPARK_WORK_DIR'] = os.getcwd()
super(SparkTaskService, self).__init__(SparkTaskService.NAME_FORMAT % index,
index, key, nics, env, verbose)
self._key = key
self._minimum_command_lifetime_s = minimum_command_lifetime_s
self._minimum_command_lifetime = None
3.2.2 基本功能
SparkTaskService 的基本功能如下。
- _run_command 將會被用來在 spark 之中啟動訓練job;
- _handle 會處理 GetTaskToTaskAddressesRequest,用來獲取 task 地址,也會處理ResourcesRequest,返回資源;
- _get_resources 將返回 spark 資源;
- wait_for_command_termination 會等待命令執行結束;
具體程式碼如下:
def _run_command(self, command, env, event,
stdout=None, stderr=None, index=None,
prefix_output_with_timestamp=False):
# 在 spark 之中啟動訓練job
super(SparkTaskService, self)._run_command(command, env, event,
stdout, stderr, index,
prefix_output_with_timestamp)
if self._minimum_command_lifetime_s is not None:
self._minimum_command_lifetime = timeout.Timeout(self._minimum_command_lifetime_s,
message='Just measuring runtime')
def _handle(self, req, client_address):
# 返回資源
if isinstance(req, ResourcesRequest):
return ResourcesResponse(self._get_resources())
# 獲取 task 地址
if isinstance(req, GetTaskToTaskAddressesRequest):
next_task_index = req.task_index
next_task_addresses = req.all_task_addresses
# We request interface matching to weed out all the NAT'ed interfaces.
next_task_client = \
SparkTaskClient(next_task_index, next_task_addresses,
self._key, self._verbose,
match_intf=True)
return GetTaskToTaskAddressesResponse(next_task_client.addresses())
return super(SparkTaskService, self)._handle(req, client_address)
def _get_resources(self):
# 返回 spark 資源
if LooseVersion(pyspark.__version__) >= LooseVersion('3.0.0'):
task_context = pyspark.TaskContext.get()
if task_context:
return task_context.resources()
else:
print("Not running inside Spark worker, no resources available")
return dict()
def wait_for_command_termination(self):
"""
Waits for command termination. Ensures this method takes at least
self._minimum_command_lifetime_s seconds to return after command started.
"""
try:
# 等待命令執行結束
return super(SparkTaskService, self).wait_for_command_termination()
finally:
# command terminated, make sure this method takes at least
# self._minimum_command_lifetime_s seconds after command started
# the client that started the command needs some time to connect again
# to wait for the result (see horovod.spark.driver.rsh).
if self._minimum_command_lifetime is not None:
time.sleep(self._minimum_command_lifetime.remaining())
3.3 註冊Task
下一步程式碼就是用來向 Driver 註冊 本 task。
driver_client.register_task(index, task.addresses(), hosthash)
3.3.1 傳送註冊請求
註冊具體通過如下完成,這裡呼叫了 network.py 的 _send 函式,就是通過 socket,spark executor 和 horovod driver 進行了網路互動:
class BasicDriverClient(network.BasicClient):
def register_task(self, index, task_addresses, host_hash):
self._send(RegisterTaskRequest(index, task_addresses, host_hash))
3.3.2 Driver處理
我們先來到 Horovod 中執行的 Driver來看看(下一節內容,這裡提前看看)。
在 BasicDriverService 之中,如果收到了RegisterTaskRequest請求就進行處理,這裡最重要是:
self._all_task_addresses[req.index] = req.task_addresses
這樣,self._all_task_addresses 的數目就增加了。
而我們之前提到了,horovod 正在 driver.wait_for_initial_registration 上面等待,其關鍵是:
while len(self._all_task_addresses) < self._num_proc
如果self._all_task_addresses
的數目達到了_num_proc
,driver.wait_for_initial_registration 就結束了,就順利執行。
具體處理 RegisterTaskRequest 的程式碼如下,BasicDriverService 之中有各種成員變數,用來維護各種所需資訊,我們在前文 [原創 原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver 中已經詳細講解過,_handle函式的RegisterTaskRequest 處理就是用來更新這些成員變數:
class BasicDriverService(network.BasicService):
def _handle(self, req, client_address):
if isinstance(req, RegisterTaskRequest):
self._wait_cond.acquire()
try:
self._all_task_addresses[req.index] = req.task_addresses
# Just use source address for service for fast probing.
self._task_addresses_for_driver[req.index] = \
self._filter_by_ip(req.task_addresses, client_address[0])
# Remove host hash earlier registered under this index.
if req.index in self._task_index_host_hash:
earlier_host_hash = self._task_index_host_hash[req.index]
if earlier_host_hash != req.host_hash:
self._task_host_hash_indices[earlier_host_hash].remove(req.index)
# Make index -> host hash map.
self._task_index_host_hash[req.index] = req.host_hash
# Make host hash -> indices map.
if req.host_hash not in self._task_host_hash_indices:
self._task_host_hash_indices[req.host_hash] = []
self._task_host_hash_indices[req.host_hash].append(req.index)
# TODO: this sorting is a problem in elastic horovod
self._task_host_hash_indices[req.host_hash].sort()
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()
3.4 Task 等待下一步通知
前面提到了,當 spark task 向 driver 傳送註冊請求之後,Spark task 通過 task.wait_for_initial_registration(settings.start_timeout) 來等待下一步啟動的開始指示。就是 driver 認為你一景註冊完成了,可以開始進入下一步了。
task.wait_for_initial_registration 會等待 self._initial_registration_complete = True 這個條件,就是等待 register_task
註冊完成。
class BasicTaskService(network.BasicService):
def wait_for_initial_registration(self, timeout):
self._wait_cond.acquire()
try:
while not self._initial_registration_complete:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('tasks to start')
finally:
self._wait_cond.release()
每個 Spark Executor 都有一個 SparkTaskService,所以 每個spark task 都有自己的 _initial_registration_complete。
hovorod.run 主程式會逐一通知每個 SparkTaskService 的 _initial_registration_complete。即,哪個 SparkTaskService 好了,就通知哪個 SparkTaskService 的 _initial_registration_complete。
hovorod.run 主程式 是通過傳送 NotifyInitialRegistrationCompleteRequest完成這一步的。
def notify_initial_registration_complete(self):
self._send(NotifyInitialRegistrationCompleteRequest())
BasicTaskService 在等待 NotifyInitialRegistrationCompleteRequest,如果收到了,就設定為 True,這樣wait_for_initial_registration 就等待結束了。
if isinstance(req, NotifyInitialRegistrationCompleteRequest):
self._wait_cond.acquire()
try:
self._initial_registration_complete = True
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()
就說明當本 thread 註冊在 horovod 之後,就算本 spark thread 啟動成功了。
+-------------------------------------+ +----------------------------------------------------+
| Horovod Main thread | | Spark Executor |
| | | _task_fn |
| | | + |
| | | | |
| | | | |
| | | v |
| +-------------------------------+ | | +---------------------+------------------------+ |
| | SparkDriverService | | | | SparkTaskService | |
| | | | | | + | |
| | | | 1 register | | | | |
| | self._all_task_addresses <----------------------------------------+ | |
| | | | | | | | |
| | + | | | | | | |
| | | | | | | | | |
| | | 3 | | | | | | |
| | | | | | | | 2 | |
| | v | | | | | | |
| | self._wait_cond.notify_all() | | | | | | |
| | + | | | | v | |
| | | | | | | +---------+---------------------------+ | |
| | | | | | | | | | |
| | | | | | | | task.wait_for_initial_registration | | |
| | | | | | | | | | |
| | | | | | | +-------------------------------------+ | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | v | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| +-------------------------------+ | | +----------------------------------------------+ |
+-------------------------------------+ +----------------------------------------------------+
手機如下:
0x04 第三階段:Driver 通知 task 註冊成功
本階段的作用是:Horovod 收到所有 task 結束的資訊之後,通知各個 task,進入下一階段。
4.1 _notify_and_register_task_addresses
前面提到。在 horovod 主程式之中,會使用 _notify_and_register_task_addresses
來等待這些 spark task 來註冊,從而呼叫 driver.wait_for_initial_registration(settings.start_timeout) ,進行總體等待。
注意,同時傳送註冊請求之後, spark task 自己也呼叫 task.wait_for_initial_registration 等待horovod 再通知下一階段的啟動。
而 _notify_and_register_task_addresses 中其實也很複雜:
- 呼叫 driver.wait_for_initial_registration 等待task來註冊;(目前這一步已經完成)
- 利用 notify_and_register 註冊task,並且通知各個 task 開始下一步;(我們這裡進入後面這兩步)
- 利用 driver.wait_for_task_to_task_address_updates 再次確認下所有 task 都OK;
def _notify_and_register_task_addresses(driver, settings, notify=True):
# wait for num_proc tasks to register
driver.wait_for_initial_registration(settings.start_timeout)
def notify_and_register(index):
# 註冊task,並且通知各個 task 開始下一步
task_client = task_service.SparkTaskClient(index,
driver.task_addresses_for_driver(index),
settings.key, settings.verbose)
if notify:
task_client.notify_initial_registration_complete()
next_task_index = (index + 1) % settings.num_proc
next_task_addresses = driver.all_task_addresses(next_task_index)
task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)
for index in driver.task_indices():
in_thread(notify_and_register, (index,)) # 註冊task,並且通知各個 task 開始下一步
# 再次確認下所有 task 都OK
driver.wait_for_task_to_task_address_updates(settings.start_timeout)
4.2 notify_and_register
可以看到 notify_and_register 的作用就是:
- 呼叫 task_client.notify_initial_registration_complete() 通知 spark task 註冊成功了,這樣就讓所有等待 task.wait_for_initial_registration 的 spark executor 一起執行下一階段。
- 呼叫 driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses) 來讓 Driver 完成註冊。
def wait_for_task_to_task_address_updates(self, timeout):
self._wait_cond.acquire()
try:
while len(self._task_addresses_for_tasks) < self._initial_np:
self.check_for_spark_job_failure()
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
finally:
self._wait_cond.release()
4.3 wait_for_task_to_task_address_updates
這裡會再次確認所有 spark task 都OK。
def wait_for_task_to_task_address_updates(self, timeout):
self._wait_cond.acquire()
try:
while len(self._task_addresses_for_tasks) < self._initial_np:
self.check_for_spark_job_failure()
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
finally:
self._wait_cond.release()
4.4 等待 In Task
在 Spark task 之中,如果收到了下一步啟動指示之後,會呼叫 wait_for_command_termination 進行等待。
其實,這一步也就意味著 spark exector 自己本身的邏輯任務結束了,因為以後都是 SparkTaskService 自己獨立完成的動作,它來負責訓練程式碼的啟動。既然 _task_fn
的邏輯任務已經結束,那麼靜靜地等待即可。
4.4.1 wait_for_command_termination
在 horovod-master/horovod/spark/task/task_service.py
def wait_for_command_termination(self):
"""
Waits for command termination. Ensures this method takes at least
self._minimum_command_lifetime_s seconds to return after command started.
"""
try:
return super(SparkTaskService, self).wait_for_command_termination()
finally:
# command terminated, make sure this method takes at least
# self._minimum_command_lifetime_s seconds after command started
# the client that started the command needs some time to connect again
# to wait for the result (see horovod.spark.driver.rsh).
if self._minimum_command_lifetime is not None:
time.sleep(self._minimum_command_lifetime.remaining())
在 horovod-master/horovod/runner/common/service/task_service.py 中可以看到,就是等待訓練程式碼所在的 thread 結束。
def wait_for_command_termination(self):
self._command_thread.join() # 馬上會說明
4.4.2 _command_thread
這裡對 _command_thread 略作說明。
在 SparkTaskService 處理 RunCommandRequest 時候,執行 Command 的 thread 就是被賦值為 _command_thread。
class BasicTaskService(network.BasicService):
def _handle(self, req, client_address):
if isinstance(req, RunCommandRequest): # 執行命令請求
self._wait_cond.acquire()
try:
if self._command_thread is None:
if self._command_env:
env = self._command_env.copy()
self._add_envs(env, req.env)
req.env = env
self._command_abort = threading.Event()
self._command_stdout = Pipe() if req.capture_stdout else None
self._command_stderr = Pipe() if req.capture_stderr else None
# 配置各種引數資訊
args = (req.command, req.env, self._command_abort,
self._command_stdout, self._command_stderr,
self._index,
req.prefix_output_with_timestamp)
# 啟動一個新執行緒來執行命令
self._command_thread = in_thread(self._run_command, args)
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()
邏輯如下:
+-------------------------------------+ +----------------------------------------------------+
| Horovod Main thread | | Spark Executor |
| | | _task_fn |
| | | + |
| | | | |
| | | | |
| | | v |
| +-------------------------------+ | | +---------------------+------------------------+ |
| | SparkDriverService | | | | SparkTaskService | |
| | | | | | + | |
| | | | 1 register | | | | |
| | self._all_task_addresses <----------------------------------------+ | |
| | | | | | | | |
| | + | | | | | | |
| | | | | | | | | |
| | | 3 | | | | | | |
| | | | | | | | 2 | |
| | v | | | | | | |
| | self._wait_cond.notify_all() | | | | | | |
| | + | | | | v | |
| | | | + + + +---------+---------------------------+ | |
| | | 4 | RegistrationComplete | | | |
| | | +-----------------+-------------+--+---> | task.wait_for_initial_registration | | |
| | | | | | | | | | |
| | | | | | | +---------+---------------------------+ | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | 5 | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | v | |
| | | | | | | wait_for_command_termination | |
| | | | 6 | RunCommand | | + | |
| | | | | | | | | |
| | +-----------------------------------------------> | 7 | |
| | | | | | | v | |
| | v | | | | self._command_thread.join() | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| +-------------------------------+ | | +----------------------------------------------+ |
+-------------------------------------+ +----------------------------------------------------+
手機如下:
至此,第一階段完成,我們下一篇繼續,敬請期待。
0x05 總結
總體來說,Horovod on Spark 的總體邏輯分為以下階段:
- 啟動 SparkDriverService 服務,利用 _make_spark_thread 啟動 Spark task,然後 horovod 會等待啟動結束;
- 多執行緒在 spark executor 之中啟動 spark task,每個task之中執行一個 SparkTaskService,SparkTaskService 會向 hovorod 主程式中的 SparkDriverTask 進行註冊,並且等待下一步執行啟動的指令;
- Horovod 收到所有 task 結束的資訊之後,通知各個 task,進入下一階段;
- Horovod 呼叫 mpi_run (又利用到 mpirun_rsh.py)在每一個 spark executor 上啟動 orted,以啟動 MPI cluster;
- orted 在每一個 executor 之上執行訓練程式碼;
本文介紹了前三個階段,即啟動階段。下文介紹後續兩個階段,敬請期待。
0xEE 個人資訊
★★★★★★關於生活和技術的思考★★★★★★
微信公眾賬號:羅西的思考
如果您想及時得到個人撰寫文章的訊息推送,或者想看看個人推薦的技術資料,敬請關注。