[原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver
0x00 摘要
Horovod 是Uber於2017年釋出的一個易於使用的高效能的分散式訓練框架,在業界得到了廣泛應用。
本系列將通過原始碼分析來帶領大家瞭解 Horovod。本文是系列第四篇,看看如何獲取 host 之間的路由等網路資訊。
前面幾篇連結如下:
[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識
[原始碼解析] 深度學習分散式訓練框架 horovod (2) --- 從使用者角度切入
[原始碼解析] 深度學習分散式訓練框架 horovod (3) --- Horovodrun背後做了什麼
0x01 引子
在 horovod/runner/launch.py 檔案中,_run_static 函式中使用 driver_service.get_common_interfaces
來獲取路由資訊等。
def _run_static(args):
nics = driver_service.get_common_interfaces(settings, all_host_names,
remote_host_names, fn_cache)
因為這部分比較複雜( Driver 的概念很類似 Spark 之中 Driver 的概念),所以本文我們單獨來分析。
本文的分析問題點是:
- 為什麼要知道路由資訊?
- 當有多個host時候,horovod如何處理?
- 如何找到路由資訊?
- 怎麼互相互動?
- (後文會詳細分析)SparkDriverService,SparkTaskService,ElasticDriver, Worker 都有什麼區別和聯絡?
本文重點分析 HorovodRunDriverService 和 HorovodRunTaskService 相關。
先給出一個圖例,大家可以有些概念。
0x02 總體架構
從註釋可知,get_common_interfaces 完成了獲得路由資訊(所有host之間的共有路由介面集合)的功能,主要是呼叫 _driver_fn 來完成相關工作。
def get_common_interfaces(settings, all_host_names, remote_host_names=None, fn_cache=None):
'''
Find the set of common and routed interfaces on all the hosts.
'''
# 得到遠端host地址
if remote_host_names is None:
remote_host_names = network.filter_local_addresses(all_host_names)
if len(remote_host_names) > 0:
if settings.nics: # 如果引數有設定網路介面,就使用
# If args.nics is provided, we will use those interfaces. All the workers
# must have at least one of those interfaces available.
nics = settings.nics
else:
# Find the set of common, routed interfaces on all the hosts (remote
# and local) and specify it in the args to be used by NCCL. It is
# expected that the following function will find at least one interface
# otherwise, it will raise an exception.
local_host_names = set(all_host_names) - set(remote_host_names)
# 獲取其他host的網路介面
nics = _driver_fn(all_host_names, local_host_names, settings, fn_cache=fn_cache)
else:
nics = get_local_interfaces(settings) # 獲取本地的網路介面
return nics
2.1 get_local_interfaces
此函式比較簡單,目的是獲取本地的網路介面。
def get_local_interfaces(settings):
# If all the given hosts are local, find the interfaces with address
# 127.0.0.1
nics = set()
for iface, addrs in net_if_addrs().items():
if settings.nics and iface not in settings.nics:
continue
for addr in addrs:
if addr.family == AF_INET and addr.address == '127.0.0.1':
nics.add(iface)
break
return nics
2.2 _driver_fn
這是本文重點,獲取其他host 的網路介面,_driver_fn 的作用是:
- 啟動 service 服務;
- 使用 driver.addresses() 獲取 Driver 服務的地址(使用
self._addresses = self._get_local_addresses()
完成); - 使用 _launch_task_servers(利用 Driver 服務的地址)在每個 worker 之中啟動 task 服務,然後 task 服務會在 service 服務中註冊;
- 因為是一個環形,每個 worker 會探測 worker index + 1 的所有網路介面;
- 最後 _run_probe 返回一個所有 workers 上的所有路由介面的交集;
程式碼如下:
這裡需要注意的一點是:@cache.use_cache() 的使用:當第一次使用過之後,會把結果放入快取。
@cache.use_cache()
def _driver_fn(all_host_names, local_host_names, settings):
"""
launches the service service, launches the task service on each worker and
have them register with the service service. Each worker probes all the
interfaces of the worker index + 1 (in a ring manner) and only keeps the
routed interfaces. Function returns the intersection of the set of all the
routed interfaces on all the workers.
:param all_host_names: list of addresses. for example,
['worker-0','worker-1']
['10.11.11.11', '10.11.11.12']
:type all_host_names: list(string)
:param local_host_names: host names that resolve into a local addresses.
:type local_host_names: set
:param settings: the object that contains the setting for running horovod
:type settings: horovod.runner.common.util.settings.Settings
:return: example: ['eth0', 'eth1']
:rtype: list[string]
"""
# Launch a TCP server called service service on the host running horovod
# 啟動 service 服務
num_hosts = len(all_host_names)
driver = HorovodRunDriverService(num_hosts, settings.key, settings.nics)
# Have all the workers register themselves with the service service.
#(利用 Driver 服務的地址)在每個worker之中啟動 task 服務,然後task服務會在 service 服務中註冊
_launch_task_servers(all_host_names, local_host_names,
driver.addresses(), settings)
try:
# 返回一個所有 workers 上的所有路由介面的交集
return _run_probe(driver, settings, num_hosts)
finally:
driver.shutdown()
2.3 獲取路由介面
我們對 _run_probe 函式做進一步分析。
2.3.1 probe邏輯
_run_probe 函式就是當 所有 task 都啟動,註冊,probe 環中下一個worker 鄰居完成 之後,得到 介面集合。
- 利用 wait_for_initial_registration 等待所有 task 完成註冊;
- 對於所有 task,完成 task.notify_initial_registration_complete 通知;
- 利用 driver.wait_for_task_to_task_address_updates 等待 每一個 worker probe 完成;
- 利用 nics.intersection_update 得到介面集合;
def _run_probe(driver, settings, num_hosts):
# wait for all the hosts to register with the service service.
driver.wait_for_initial_registration(settings.start_timeout)
tasks = [
task_service.HorovodRunTaskClient(
index,
driver.task_addresses_for_driver(index),
settings.key,
settings.verbose) for index in range(
num_hosts)]
# Notify all the drivers that the initial registration is complete.
for task in tasks:
task.notify_initial_registration_complete()
# Each worker should probe the interfaces of the next worker in a ring
# manner and filter only the routed ones -- it should filter out
# interfaces that are not really connected to any external networks
# such as lo0 with address 127.0.0.1.
driver.wait_for_task_to_task_address_updates(settings.start_timeout)
# Determine a set of common interfaces for task-to-task communication.
nics = set(driver.task_addresses_for_tasks(0).keys())
for index in range(1, num_hosts):
nics.intersection_update(
driver.task_addresses_for_tasks(index).keys())
return nics
2.3.2 等待函式
probe 利用 wait_for_initial_registration 等待所有 task 完成註冊,具體等待函式如下:
def wait_for_initial_registration(self, timeout):
self._wait_cond.acquire()
try:
while len(self._all_task_addresses) < self._num_proc:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('tasks to start')
finally:
self._wait_cond.release()
def wait_for_task_to_task_address_updates(self, timeout):
self._wait_cond.acquire()
try:
while len(self._task_addresses_for_tasks) < self._num_proc:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for(
'tasks to update task-to-task addresses')
finally:
self._wait_cond.release()
0x03 基礎網路服務
前面提到,Horovod Driver 的概念很類似 Spark 之中 Driver 的概念。Spark應用程式執行時主要分為 Driver 和 Executor,Driver負載總體排程及UI展示,Executor負責Task執行。使用者的Spark應用程式執行在Driver上(某種程度上說,使用者的程式就是Spark Driver程式),經過Spark排程封裝成一個個Task,再將這些Task資訊發給Executor執行,Task資訊包括程式碼邏輯以及資料資訊,Executor不直接執行使用者的程式碼。
對於 Horovod 來說:
- HorovodRunDriverService 就是 Driver 的實現類。
- HorovodRunTaskService 提供了 Task 部分服務功能,這些 task 需要註冊到 HorovodRunDriverService 之中。
- 這套 driver & task 機制的底層由 "基礎網路服務" 支撐。
所以我們就仔細分析下基礎網路服務。
3.1 繼承關係
首先給出繼承關係,我們下面講解的 Driver 服務由 HorovodRunDriverService 提供,Task 服務由HorovodRunTaskService 提供。
這兩個類最終都繼承了 network.BasicService。
network.BasicService
^ ^
| |
+-------------------+ +-------------+
| |
+ +
driver_service.BasicDriverService task_service.BasicTaskService
^ ^
| |
| |
| |
+ +
HorovodRunDriverService HorovodRunTaskService
3.2 network.BasicService
BasicService 提供了一個網路伺服器功能。即通過find_port函式構建了一個ThreadingTCPServer
,對外提供服務。
class BasicService(object):
def __init__(self, service_name, key, nics):
self._service_name = service_name
self._wire = Wire(key)
self._nics = nics
self._server, _ = find_port(
lambda addr: socketserver.ThreadingTCPServer(
addr, self._make_handler()))
self._server._block_on_close = True
self._port = self._server.socket.getsockname()[1]
self._addresses = self._get_local_addresses()
self._thread = in_thread(target=self._server.serve_forever)
3.2.1 建立Server
建立伺服器程式碼如下,這裡是搜尋一個隨機埠,然後設定:
def find_port(server_factory):
min_port = 1024
max_port = 65536
num_ports = max_port - min_port
start_port = random.randrange(0, num_ports)
for port_offset in range(num_ports):
try:
port = min_port + (start_port + port_offset) % num_ports
addr = ('', port)
server = server_factory(addr)
return server, port
except Exception as e:
pass
raise Exception('Unable to find a port to bind to.')
3.2.2 Server功能
伺服器就是基本的功能,比如獲取本server地址,處理 ping,網路互動等。
def _make_handler(self):
server = self
class _Handler(socketserver.StreamRequestHandler):
def handle(self):
try:
req = server._wire.read(self.rfile)
resp = server._handle(req, self.client_address)
# A tuple is the usual response object followed by a utf8 text stream
if type(resp) == tuple:
(resp, stream) = resp
server._wire.write(resp, self.wfile)
server._wire.stream(stream, self.wfile)
else:
server._wire.write(resp, self.wfile)
except (EOFError, BrokenPipeError):
# Happens when client is abruptly terminated, don't want to pollute the logs.
pass
return _Handler
def _handle(self, req, client_address):
if isinstance(req, PingRequest):
return PingResponse(self._service_name, client_address[0])
raise NotImplementedError(req)
def _get_local_addresses(self):
result = {}
for intf, intf_addresses in psutil.net_if_addrs().items():
if self._nics and intf not in self._nics:
continue
for addr in intf_addresses:
if addr.family == socket.AF_INET:
if intf not in result:
result[intf] = []
result[intf].append((addr.address, self._port))
return result
def addresses(self):
return self._addresses.copy()
def shutdown(self):
self._server.shutdown()
self._server.server_close()
self._thread.join()
def get_port(self):
return self._port
3.3 network.BasicClient
HorovodRunDriverClient 和 HorovodRunTaskClient 這兩個類都繼承了network.BasicClient。
network.BasicClient 的作用就是連線 network.BasicService,與其互動。即 network.BasicClient 是一個操作介面。
network.BasicClient
^ ^
| |
+------------------+ +---------------+
| |
+ |
+
driver_service.BasicDriverClient task_service.BasicTaskClient
^ ^
| |
| |
+ +
HorovodRunDriverClient HorovodRunTaskClient
兩個主要 API 如下:
3.3.1 _probe
_probe 獲取 server 的網路介面。
def _probe(self, addresses):
result_queue = queue.Queue()
threads = []
for intf, intf_addresses in addresses.items():
for addr in intf_addresses:
thread = in_thread(target=self._probe_one, args=(intf, addr, result_queue))
threads.append(thread)
for t in threads:
t.join()
result = {}
while not result_queue.empty():
intf, addr = result_queue.get()
if intf not in result:
result[intf] = []
result[intf].append(addr)
return result
3.3.2 傳送訊息
_send 的作用是給server傳送訊息。
def _send(self, req, stream=None):
"""
Sends the request and returns the response object.
Streaming data response is transferred to the optional stream parameter.
"""
# Since all the addresses were vetted, use the first one.
addr = list(self._addresses.values())[0][0]
return self._send_one(addr, req, stream)
3.4 總結
我們可以看到,network.BasicService 會提供了一個server,這個 Service 都是通過 network.BasicClient 來訪問。基於此,Horovod 的HorovodRunDriverService 和 HorovodRunTaskService 這兩個類就可以互相互動,進行溝通。
0x04 Driver 服務
Driver 服務由 HorovodRunDriverService 提供,其功能主要是維護維護各種 task 地址以及相應關係。具體各種 task 地址 就是 Task 服務 來註冊的。
需要注意的是:HorovodRunDriverService 和 HorovodRunTaskService 都最終繼承了 network.BasicService,他們之間可以是異地執行互動。
4.1 HorovodRunDriverService
HorovodRunDriverService 是對 BasicDriverService 的封裝。
HorovodRunDriverClient 是 其 訪問介面。
class HorovodRunDriverService(driver_service.BasicDriverService):
NAME = 'horovod driver service'
def __init__(self, num_hosts, key, nics):
super(HorovodRunDriverService, self).__init__(num_hosts,
HorovodRunDriverService.NAME,
key, nics)
class HorovodRunDriverClient(driver_service.BasicDriverClient):
def __init__(self, driver_addresses, key, verbose, match_intf=False):
super(HorovodRunDriverClient, self).__init__(
HorovodRunDriverService.NAME,
driver_addresses,
key,
verbose,
match_intf=match_intf)
4.2 BasicDriverService
BasicDriverService基類 主要就是 維護各種 task 地址以及相應關係。
class BasicDriverService(network.BasicService):
def __init__(self, num_proc, name, key, nics):
super(BasicDriverService, self).__init__(name, key, nics)
self._num_proc = num_proc
self._all_task_addresses = {}
self._task_addresses_for_driver = {}
self._task_addresses_for_tasks = {}
self._task_index_host_hash = {}
self._task_host_hash_indices = {}
self._wait_cond = threading.Condition()
這裡的各種 task 地址就是 Task 服務 註冊到 Driver 的數值。
可以看到裡面有各種關於地址的變數,為了讓大家理解這些變數的作用,對於每一個變數我們舉例如下(這裡有些變數是專門為 spark 設計,都放到基類裡面有點奇怪):
4.2.1 _all_task_addresses
本變數是記錄了所有 task 的地址,變數舉例如下:
self._all_task_addresses = {
1: {
'lo' : [('1.1.1.1', 12345)],
'eth0' : [('10.10.10.01', 12345)]
},
0: {
'lo' : [('2.2.2.2', 54321)],
'eth0' : [('10.10.10.02', 54321)]
}
}
本變數由 task 呼叫 RegisterTaskRequest 來註冊。
if isinstance(req, RegisterTaskRequest):
self._wait_cond.acquire()
try:
assert 0 <= req.index < self._num_proc
self._all_task_addresses[req.index] = req.task_addresses
使用時候,有幾個方式,舉例如下:
比如 all_task_addresses 獲取 self._all_task_addresses[index].copy() 來決定 ping /check 的下一個跳轉。
4.2.2 _task_addresses_for_driver
本變數是記錄了所有 task 的地址,但是網路卡介面有多種,這裡選擇與 本 driver 地址匹配的地址。
變數舉例如下:
self._task_addresses_for_driver = {
1: {
'eth0' : [('10.10.10.01', 12345)]
},
0: {
'eth0' : [('10.10.10.02', 54321)]
}
}
本變數由 task 呼叫 RegisterTaskRequest 來註冊。
# Just use source address for service for fast probing.
self._task_addresses_for_driver[req.index] = \
self._filter_by_ip(req.task_addresses, client_address[0])
具體使用舉例如下:
def task_addresses_for_driver(self, index):
self._wait_cond.acquire()
try:
return self._task_addresses_for_driver[index].copy()
finally:
self._wait_cond.release()
driver用這個地址來生成 其內部 task 變數。
tasks = [
task_service.HorovodRunTaskClient(
index,
driver.task_addresses_for_driver(index),
settings.key,
settings.verbose) for index in range(
num_hosts)]
4.2.3 _task_addresses_for_tasks
該變數舉例如下:
self._task_addresses_for_tasks = {
1: {
'eth0' : [('10.10.10.01', 12345)]
},
0: {
'eth0' : [('10.10.10.02', 54321)]
}
}
本變數由RegisterTaskToTaskAddressesRequest註冊。
if isinstance(req, RegisterTaskToTaskAddressesRequest):
self.register_task_to_task_addresses(req.index, req.task_addresses)
return network.AckResponse()
def register_task_to_task_addresses(self, index, task_addresses):
self._wait_cond.acquire()
try:
assert 0 <= index < self._num_proc
self._task_addresses_for_tasks[index] = task_addresses # 這裡賦值
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
該變數被 task 用來獲取 某個 task 的一套網路介面,比如:
# Determine a set of common interfaces for task-to-task communication.
nics = set(driver.task_addresses_for_tasks(0).keys())
4.2.4 _task_index_host_hash
每一個 task 有一個對應的 host hash,該數值被 MPI 作為 host name 來操作。
self._task_index_host_hash = {
1: {
'ip-10-10-10-01-dfdsfdsfdsfdsf2'
},
0: {
'ip-10-10-10-02-treterwrtqwer'
}
}
具體使用如下。這個函式是 spark 相關會使用,具體是逐一通知 spark task 進入下一階段。
def task_indices(self):
self._wait_cond.acquire()
try:
return list(self._task_index_host_hash.keys())
finally:
self._wait_cond.release()
或者使用如下,是為了獲取某一個 host 對應的 host hash name
。
def task_index_host_hash(self, index):
self._wait_cond.acquire()
try:
assert 0 <= index < self._num_proc
return self._task_index_host_hash[index]
finally:
self._wait_cond.release()
4.2.5 _task_host_hash_indices
該變數舉例如下:
self._task_host_hash_indices = {
{
'ip-10-10-10-01-dfdsfdsfdsfdsf2' : [1]
},
{
'ip-10-10-10-02-treterwrtqwer' : [0]
}
}
具體是在註冊 RegisterTaskRequest 時候生成。
self._task_host_hash_indices[req.host_hash].append(req.index)
使用具體程式碼是:
def task_host_hash_indices(self):
self._wait_cond.acquire()
try:
return self._task_host_hash_indices.copy()
finally:
self._wait_cond.release()
具體是被 rsh 使用。rsh 就是在某一個 host 上,讓某一個 horovod rank 啟動。具體邏輯是:
- 獲取某一個 host 上所有的 task indices ;
- 利用 task_host_hash_indices 取出本程式 local rank 對應的 task index;
- 取出在 driver 中 task index 對應保持的 task address;
- 最後依據這個 task addresses 生成一個 SparkTaskClient,進行後續操作。
driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
task_indices = driver_client.task_host_hash_indices(host_hash)
task_index = task_indices[local_rank]
task_addresses = driver_client.all_task_addresses(task_index)
task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
task_client.stream_command_output(stdout, stderr)
task_client.run_command(command, env,
capture_stdout=stdout is not None,
capture_stderr=stderr is not None,
prefix_output_with_timestamp=prefix_output_with_timestamp)
4.3 總體邏輯
總體邏輯如下:
network.BasicService
^ ^
| |
+-------------------+ +-------------+
| |
+ +
driver_service.BasicDriverService task_service.BasicTaskService
^ ^
| |
| |
| |
| +
+----------------+------------------+ HorovodRunTaskService
| HorovodRunDriverService |
| |
| |
| _all_task_addresses |
| |
| _task_addresses_for_driver |
| |
| _task_addresses_for_tasks |
| |
| _task_index_host_hash |
| |
| _task_host_hash_indices |
| |
+-----------------------------------+
0x05 Task 服務
HorovodRunTaskService 提供了 Task 部分服務功能。整體邏輯是由幾個函式共同完成。
5.1 啟動具體服務
_launch_task_servers 用來啟動具體服務,其主要作用是:多執行緒執行,在每一個執行緒中,遠端執行 horovod.runner.task_fn
。
其中:
- 傳入引數中,all_host_names 就是程式啟動時候配置的所有host,比如 ["1.1.1.1", "2.2.2.2"];
- 使用了我們之前提到的 safe_shell_exec.execute 完成了安全執行保證;
- 使用我們前文提到的 get_remote_command 完成了遠端命令的獲取,即在命令之前加上了
ssh -o PasswordAuthentication=no -o StrictHostKeyChecking=no
等等配置; - 最終每個啟動的命令舉例如下:
ssh -o PasswordAuthentication=no -o StrictHostKeyChecking=no 1.1.1.1 python -m horovod.runner.task_fn xxxxxxx
; - 使用 execute_function_multithreaded 在每一個 host 上執行,啟動 task 服務;
具體程式碼如下:
def _launch_task_servers(all_host_names, local_host_names, driver_addresses,
settings):
"""
Executes the task server and service client task for registration on the
hosts.
:param all_host_names: list of addresses. for example,
['worker-0','worker-1']
['10.11.11.11', '10.11.11.12']
:type all_host_names: list(string)
:param local_host_names: names that are resolved to one of the addresses
of local hosts interfaces. For example,
set(['localhost', '127.0.0.1'])
:type local_host_names: set
:param driver_addresses: map of interfaces and their address and port for
the service. For example:
{
'lo': [('127.0.0.1', 34588)],
'docker0': [('172.122.10.1', 34588)],
'eth0': [('11.111.33.73', 34588)]
}
:type driver_addresses: map
:param settings: the object that contains the setting for running horovod
:type settings: horovod.runner.common.util.settings.Settings
:return:
:rtype:
"""
def _exec_command(command):
host_output = io.StringIO()
try:
# 完成了安全執行保證
exit_code = safe_shell_exec.execute(command,
stdout=host_output,
stderr=host_output)
finally:
host_output.close()
return exit_code
args_list = []
num_hosts = len(all_host_names)
for index in range(num_hosts):
host_name = all_host_names[index] # all_host_names 就是程式啟動時候配置的所有host,比如 ["1.1.1.1", "2.2.2.2"]
command = \
'{python} -m horovod.runner.task_fn {index} {num_hosts} ' \
'{driver_addresses} {settings}' \
.format(python=sys.executable,
index=codec.dumps_base64(index),
num_hosts=codec.dumps_base64(num_hosts),
driver_addresses=codec.dumps_base64(driver_addresses),
settings=codec.dumps_base64(settings))
if host_name not in local_host_names:
# 完成了遠端命令的獲取,即在命令之前加上了 `ssh -o PasswordAuthentication=no -o StrictHostKeyChecking=no`等等配置
command = get_remote_command(command,
host=host_name,
port=settings.ssh_port,
identity_file=settings.ssh_identity_file)
args_list.append([command])
# Each thread will use ssh command to launch the server on one task. If an
# error occurs in one thread, entire process will be terminated. Otherwise,
# threads will keep running and ssh session -- and the the task server --
# will be bound to the thread. In case, the horovod process dies, all
# the ssh sessions and all the task servers will die as well.
# 使用 execute_function_multithreaded 在每一個 host 上執行,啟動 task 服務
threads.execute_function_multithreaded(_exec_command,
args_list,
block_until_all_done=False)
5.2 具體服務邏輯
上段有:{python} -m horovod.runner.task_fn {index} {num_hosts} {driver_addresses} {settings}
執行具體服務邏輯,所以我們介紹下 horovod.runner.task_fn
。
_task_fn
函式完成了
- 生成了 HorovodRunTaskService 例項,賦值給 task;
- 使用
HorovodRunDriverClient . register_task
來向 Driver 服務註冊task(自己)的地址; - 使用
HorovodRunDriverClient . register_task_to_task_addresses
來向 Driver 服務註冊自己在Ring上 下一個鄰居的地址; - 每一個 task 都做這個操作,最後就得到了在這個 ring cluster 之中的一個路由介面;
具體程式碼如下:
def _task_fn(index, num_hosts, driver_addresses, settings):
task = task_service.HorovodRunTaskService(index, settings.key, settings.nics)
try:
driver = driver_service.HorovodRunDriverClient(
driver_addresses, settings.key, settings.verbose)
# 向 Driver 服務註冊task(自己)的地址
driver.register_task(index,
task.addresses(),
host_hash.host_hash())
task.wait_for_initial_registration(settings.start_timeout)
# Tasks ping each other in a circular fashion to determine interfaces
# reachable within the cluster.
next_task_index = (index + 1) % num_hosts
next_task_addresses = driver.all_task_addresses(next_task_index)
# We request interface matching to weed out all the NAT'ed interfaces.
next_task = task_service.HorovodRunTaskClient(
next_task_index,
next_task_addresses,
settings.key,
settings.verbose,
match_intf=True,
attempts=10)
# 向 Driver 服務註冊自己在Ring上 下一個鄰居的地址
driver.register_task_to_task_addresses(next_task_index,
next_task.addresses())
# Notify the next task that the address checks are completed.
next_task.task_to_task_address_check_completed()
# Wait to get a notification from previous task that its address checks
# are completed as well.
task.wait_for_task_to_task_address_check_finish_signal(settings.start_timeout)
finally:
task.shutdown()
if __name__ == '__main__':
index = codec.loads_base64(sys.argv[1])
num_hosts = codec.loads_base64(sys.argv[2])
driver_addresses = codec.loads_base64(sys.argv[3])
settings = codec.loads_base64(sys.argv[4])
_task_fn(index, num_hosts, driver_addresses, settings)
5.3 HorovodRunTaskService
HorovodRunTaskService 主要的作用是提供了兩個等待函式。因為具體路由操作是需要彼此通知,所以需要互相等待。
class HorovodRunTaskService(task_service.BasicTaskService):
NAME_FORMAT = 'horovod task service #%d'
def __init__(self, index, key, nics):
super(HorovodRunTaskService, self).__init__(
HorovodRunTaskService.NAME_FORMAT % index,
index, key, nics)
self.index = index
self._task_to_task_address_check_completed = False
def _handle(self, req, client_address):
if isinstance(req, TaskToTaskAddressCheckFinishedSignal):
self._wait_cond.acquire()
try:
self._task_to_task_address_check_completed = True
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return TaskToTaskAddressCheckFinishedSignalResponse(self.index)
return super(HorovodRunTaskService, self)._handle(req, client_address)
def wait_for_task_to_task_address_check_finish_signal(self, timeout):
self._wait_cond.acquire()
try:
while not self._task_to_task_address_check_completed:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('Task to task address check')
finally:
self._wait_cond.release()
class HorovodRunTaskClient(task_service.BasicTaskClient):
def __init__(self, index, task_addresses, key, verbose, match_intf=False, attempts=3):
super(HorovodRunTaskClient, self).__init__(
HorovodRunTaskService.NAME_FORMAT % index,
task_addresses, key, verbose,
match_intf=match_intf,
attempts=attempts)
self.index = index
def task_to_task_address_check_completed(self):
resp = self._send(TaskToTaskAddressCheckFinishedSignal(self.index))
return resp.index
邏輯如下:
_driver_fn
+
|
|
+---------------------------------------+-------------------------------------v
| |
| v
| _launch_task_servers
v +
driver = HorovodRunDriverService |
+ +--------------+-------------------+
| | |
| | |
v v v
+-------------------+---------------+ horovod.runner.task_fn ...... horovod.runner.task_fn
| HorovodRunDriverService | + +
| | | |
| | | |
| _all_task_addresses | | |
| | v v
| _task_addresses_for_driver | register_task +-----------+---------------+ +-------+--------------------+
| | | HorovodRunTaskService | | HorovodRunTaskService |
| _task_addresses_for_tasks | <--------------------------------+ | | |
| | | | wait | |
| _task_index_host_hash | | | <------> | |
| | <--------------------------------+ | | |
| _task_host_hash_indices | register_task_to_task_addresses | | | |
| | +---------------------------+ +----------------------------+
+-----------------------------------+ `
手機如下:
0x06 總結
本文總結如下:
- 因為 Horovod 分散式訓練 涉及到多個 hosts,所以如果要彼此訪問,需要知道路由資訊;
- 當所有 task 都啟動,註冊,probe 環中下一個worker 鄰居完成 之後,DriverService 會得到路由資訊(所有host之間的共有路由介面集合),返回給 Horovod 主體部分使用;
- network.BasicService 提供了網路服務功能;
- XXXService 都是通過 XXXClient作為介面才能訪問;
- HorovodRunDriverService 和 HorovodRunTaskService 都最終繼承了 network.BasicService,他們之間可以是異地執行互動。
- HorovodRunTaskService 提供了 Task 部分服務功能,這些 task 需要註冊到 Driver 之中(和Spark思路類似)。
- HorovodRunDriverService 是對 BasicDriverService 的封裝。BasicDriverService 就是 維護各種 task 地址以及相應關係,比如:
- _all_task_addresses :記錄了所有 task 的地址;
- _task_addresses_for_driver :記錄了所有 task 的地址,但是因為網路卡介面有多種,這裡選擇與 本driver 地址匹配的地址;
- _task_addresses_for_tasks :用來給某一個 task 分配一個地址,同時獲取本 task 的一套網路介面;
- _task_index_host_hash :每一個 task 有一個對應的 host hash。這個函式是 spark 相關會使用,具體是逐一通知 spark task 進入下一階段。或者是為了獲取某一個 host 對應的 host hash name;
- _task_host_hash_indices :具體是被 rsh 使用,由 rank 得到 在 driver 中 task index 對應保持的 task address;
- SparkDriverService,SparkTaskService,ElasticDriver, Worker 都有什麼區別和聯絡?
- HorovodRunDriverService 這裡只是用來得到路由資訊,記錄各種 Task 地址;
- SparkDriverService 除了記錄路由和地址之外,還提交執行任務(Command),因為具體在哪一個Spark Executor啟動之後,SparkDriverService 就需要知道 對應 SparkTaskService 的地址,這樣才能知道提交到哪裡;
- SparkTaskService 負責執行命令(拋棄了Spark Executor的邏輯,自己搞了一套),就是從 SparkDriverService 那裡獲得訓練函式,然後啟動 python 程式來執行;
- ElasticDriver 做得更多,因為還有彈性,需要容錯;
0xEE 個人資訊
★★★★★★關於生活和技術的思考★★★★★★
微信公眾賬號:羅西的思考
如果您想及時得到個人撰寫文章的訊息推送,或者想看看個人推薦的技術資料,敬請關注。
0xFF 參考
[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識