[原始碼解析] 深度學習分散式訓練框架 horovod (14) --- 彈性訓練發現節點 & State
0x00 摘要
Horovod 是Uber於2017年釋出的一個易於使用的高效能的分散式訓練框架,在業界得到了廣泛應用。
本系列將通過原始碼分析來帶領大家瞭解 Horovod。本文是系列第十四篇,看看horovod 如何動態發現節點 和 狀態資訊。
本系列其他文章連結如下:
[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識
[原始碼解析] 深度學習分散式訓練框架 horovod (2) --- 從使用者角度切入
[原始碼解析] 深度學習分散式訓練框架 horovod (3) --- Horovodrun背後做了什麼
[原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver
[原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架
[原始碼解析] 深度學習分散式訓練框架 horovod (6) --- 後臺執行緒架構
[原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer
[原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark
[原始碼解析] 深度學習分散式訓練框架 horovod (9) --- 啟動 on spark
[原始碼解析] 深度學習分散式訓練框架 horovod (10) --- run on spark
[原始碼解析] 深度學習分散式訓練框架 horovod (11) --- on spark --- GLOO 方案
[原始碼解析] 深度學習分散式訓練框架 horovod (12) --- 彈性訓練總體架構
[原始碼解析] 深度學習分散式訓練框架 horovod (13) --- 彈性訓練之 Driver
0x01 設計點
本文對應架構圖中的 Host Discovery 部分,因為是被 Driver Main 呼叫,所以把兩部分一起展示出。
發現節點機制的幾個關鍵設計點如下:
- 有節點變化時候,如何即時發現?Horovod是通過定期呼叫完成。
- 發現節點變化時候,如何通知各個worker? Horovod通過構建了一個通知機制完成。即,每個worker把自己註冊到WorkerNotificationManager 之上,當有節點變化時候,WorkerNotificationManager 會逐一通知這些worker。
- worker得到通知之後,如何處理?Horovod 把worker的狀態在深度框架上進一步封裝成各種State,得到通知之後就會呼叫State的對應callback函式,或者同步狀態,或者進行其他處理。
0x02 發現機制
這部分程式碼主要在:horovod/runner/elastic/discovery.py。
2.1 發現指令碼
HostDiscoveryScript 的主要作用就是儲存指令碼(程式啟動時候設定進來),然後當執行 find_available_hosts_and_slots 的時候,呼叫這個發現指令碼,得到 host 資訊。
該指令碼的輸出的格式 就是呼叫 horovodrun 時候 的 host 引數格式,比如:
$ sh ./discover_hosts.sh # 執行指令碼,輸出節點資訊
10.68.32.2:4
10.68.32.3:4
10.68.32.4:4
定義如下:
class HostDiscoveryScript(HostDiscovery):
def __init__(self, discovery_script, slots):
self._discovery_script = discovery_script # 設定指令碼
self._default_slots = slots # 審定slots
super(HostDiscoveryScript, self).__init__()
def find_available_hosts_and_slots(self):
stdout = io.StringIO()
# 執行發現指令碼
exit_code = safe_shell_exec.execute(self._discovery_script, stdout=stdout)
# 讀取指令碼輸出,解析出來host資訊
host_slots = {}
lines = set(stdout.getvalue().strip().split('\n'))
for line in lines:
host = line
if ':' in line:
host, slots = line.split(':')
host_slots[host] = int(slots)
else:
host_slots[host] = self._default_slots
return host_slots
2.2 HostManager
HostManager 是 host discovery 的核心,作用是維護當前 host 以及 狀態,其主要變數是:
- self._current_hosts : 當前的 host 資訊,包括 slot,assign order 等等;
- self._hosts_state :當前的 host 狀態,包括黑名單,event 等;
- self._discovery :可以認為是對 發現指令碼 的一個封裝,用來動態執行 發現指令碼,獲取 host 資訊;
class HostManager(object):
def __init__(self, discovery):
self._current_hosts = DiscoveredHosts(host_slots={}, host_assignment_order=[])
self._hosts_state = defaultdict(HostState)
self._discovery = discovery
def update_available_hosts(self):
# TODO(travis): also check for hosts removed from the blacklist in the future
# 檢查更新,給出是新增,還是刪除節點
def check_update(cur_host_slots, prev_host_slots):
res = HostUpdateResult.no_update
for prev_h in prev_host_slots:
if prev_h not in cur_host_slots:
# prev_h is a removed host
res |= HostUpdateResult.removed
for h in cur_host_slots:
if h not in prev_host_slots:
# h is an added host
res |= HostUpdateResult.added
elif cur_host_slots[h] > prev_host_slots[h]:
# h has more slots added
res |= HostUpdateResult.added
elif cur_host_slots[h] < prev_host_slots[h]:
# h has removed some slots
res |= HostUpdateResult.removed
return res
prev_host_slots = self._current_hosts.host_slots
prev_host_assignment_order = self._current_hosts.host_assignment_order
host_slots = self._discovery.find_available_hosts_and_slots()
if prev_host_slots != host_slots: # 有修改
# 找到不在黑名單裡的host
available_hosts = set([host for host in host_slots.keys() if not self._hosts_state[host].is_blacklisted()])
# 找到host的order
host_assignment_order = HostManager.order_available_hosts(available_hosts, prev_host_assignment_order)
self._current_hosts = DiscoveredHosts(host_slots=host_slots,
host_assignment_order=host_assignment_order)
# 檢查更新
return check_update(self._current_hosts.host_slots, prev_host_slots)
else: # 沒修改就不更新
return HostUpdateResult.no_update
HostManager 核心邏輯是 update_available_hosts 方法,就是用來發現可用的 host。
2.2.1 order_available_hosts
order_available_hosts 的作用是:確保最老的host被賦予最低的rank,即rank 0,因為最老的host最有可能擁有原來訓練的模型以及訓練狀態,這些資訊需要在下一輪新迭代之前,發給所有worker。
@staticmethod
def order_available_hosts(available_hosts, prev_host_assignment_order):
# We need to ensure this list preserves relative order to ensure the oldest hosts are assigned lower ranks.
host_assignment_order = [host for host in prev_host_assignment_order if host in available_hosts]
known_hosts = set(host_assignment_order)
for host in available_hosts:
if host not in known_hosts:
host_assignment_order.append(host)
return host_assignment_order
2.3 配置
我們看看是發現指令碼如何配置進入HostManager之中。
首先,發現指令碼是在_run_elastic之中配置。
def _run_elastic(args):
# construct host discovery component
if args.host_discovery_script:
# 如果引數中有設定發現指令碼,則賦值為discover_hosts
discover_hosts = discovery.HostDiscoveryScript(args.host_discovery_script, args.slots)
elif args.hosts: # 如果引數設定好了hosts,則賦值為discover_hosts
_, available_host_slots = hosts.parse_hosts_and_slots(args.hosts)
if len(available_host_slots) < 2:
raise ValueError('Cannot run in fault tolerance mode with fewer than 2 hosts.')
discover_hosts = discovery.FixedHosts(available_host_slots)
else: # 丟擲異常
raise ValueError('One of --host-discovery-script, --hosts, or --hostnames must be provided')
# 配置進入setting
settings = elastic_settings.ElasticSettings(discovery=discover_hosts,
.....)
env = os.environ.copy()
config_parser.set_env_from_args(env, args)
gloo_run_elastic(settings, env, args.command)
其次,發現指令碼被設定到ElasticSettings之中。
class ElasticSettings(BaseSettings):
def __init__(self, discovery, min_np, max_np, elastic_timeout, reset_limit, **kwargs):
self.discovery = discovery
當啟動時候,會設定到 ElasticDriver 之中。
def start(self):
"""Starts the Horovod driver and services."""
self.rendezvous = RendezvousServer(self.settings.verbose)
self.driver = ElasticDriver(
rendezvous=self.rendezvous,
discovery=self.settings.discovery, # 在這裡設定發現指令碼
min_np=self.settings.min_np,
max_np=self.settings.max_np,
timeout=self.settings.elastic_timeout,
reset_limit=self.settings.reset_limit,
verbose=self.settings.verbose)
最後,建立HostManager時候,會設定發現指令碼。
class ElasticDriver(object):
def __init__(self, rendezvous, discovery, min_np, max_np, timeout=None, reset_limit=None, verbose=0):
self._rendezvous = rendezvous
self._host_manager = HostManager(discovery) # 設定指令碼
0x03 如何呼叫
3.1 無限迴圈執行緒
HostManager 的呼叫邏輯是在 ElasticDriver 類中。
ElasticDriver 在初始化時候,生成一個後臺執行緒 _discovery_thread。
self._discovery_thread = threading.Thread(target=self._discover_hosts)
3.1.1 定時探尋
在 _discovery_thread
之中,會執行_discover_hosts。
ElasticDriver._discover_hosts
會:
- 首先呼叫
self._host_manager.update_available_hosts(self._host_manager.current_hosts, update_res)
得到最新的host狀態; - 其次,如果新 host 狀態已經發生的變化,於是就呼叫 _notify_workers_host_changes 和 _wait_hosts_cond.notify_all 來通知大家有 host 變化了;
def _discover_hosts(self):
first_update = True
while not self._shutdown.is_set():
self._wait_hosts_cond.acquire()
try:
# 得到最新的host狀態
update_res = self._host_manager.update_available_hosts()
if update_res != HostUpdateResult.no_update:
self._notify_workers_host_changes(self._host_manager.current_hosts, update_res)
self._wait_hosts_cond.notify_all() # 通知大家有 host 變化
except RuntimeError as e:
if first_update:
# Misconfiguration, fail the job immediately
self._shutdown.set()
self._wait_hosts_cond.notify_all() # 通知大家有 host 變化
raise
# Transient error, retry until timeout
logging.warning(str(e))
finally:
self._wait_hosts_cond.release()
first_update = False
self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS)
邏輯如下,是一個 thread loop 定時執行:
<--------------------^
+ |
| thread loop |
| |
| +----------------+----------------------+
| | ElasticDriver._discovery_thread |
| | |
| | |
| | |
| | |
| | HostManager.update_available_hosts |
| | |
| +----------------+----------------------+
| ^
| |
v |
+-------------------->+
3.1.2 通知變化
如果發現有host 變化,則呼叫 self._notify_workers_host_changes
來通知。
即,當Driver的定時程式通過節點發現指令碼發現某一個節點被標記為新增或者移除時,它將 呼叫 _notify_workers_host_changes 傳送一個通知到所有workers。
邏輯如下:
<--------------------^
+ |
| thread loop |
| |
| +----------------+-----------------------------------------------+
| | ElasticDriver._discovery_thread |
| | |
| | |
| | HostManager.update_available_hosts |
| | + |
| | | |
| | | |
| | v |
| | YES |
| | update_res != no_update ??? +--------+ |
| | + | |
| | | | |
| | | v |
| | | NO |
| | | _notify_workers_host_changes |
| | v |
| +----------------------------------------------------------------+
| |
| |
| |
v |
+-------------------->+
具體如下:
def _notify_workers_host_changes(self, current_hosts, update_res):
next_host_assignments = {}
if current_hosts.count_available_slots() >= self._min_np:
# Assignments are required to be stable via contract
next_host_assignments, _ = self._get_host_assignments(current_hosts)
if next_host_assignments == self.host_assignments:
# Skip notifying workers when host changes would not result in changes of host assignments
return
coordinator_slot_info = self.get_coordinator_info()
# 獲取 WorkerNotificationClient
coordinator_client = self.get_worker_client(coordinator_slot_info)
timestamp = _epoch_time_s()
coordinator_client.notify_hosts_updated(timestamp, update_res) # 通知
get_worker_client 函式就是獲取 WorkerNotificationClient,然後呼叫 WorkerNotificationClient 來進行通知,所以下面我們接下來看 WorkerNotificationClient。
def get_worker_client(self, slot_info):
return self._worker_clients.get((slot_info.hostname, slot_info.local_rank))
具體如下:
<--------------------^
+ |
| thread loop |
| |
| +----------------+------------------------------------+
| | ElasticDriver._discovery_thread |
| | + |
| | | |
| | v |
| | HostManager.update_available_hosts |
| | + |
| | | |
| | | |
| | v YES | +---------------------------+
| | update_res != no_update ??? +-----+ | | |
| | + | | | |
| | | | | | WorkerNotificationClient |
| | | v | notify_hosts_updated | |
| | | NO | | |
| | | _notify_workers_host_changes+------------------------> | |
| | v | | |
| +-----------------------------------------------------+ +---------------------------+
| |
| |
| |
v |
+-------------------->+
手機如下:
3.2 如何通知
就是利用 WorkerNotificationClient 傳送 HostsUpdatedRequest。
3.2.1 WorkerNotificationClient
可以看到,WorkerNotificationService 繼承了 network.BasicService,所以 WorkerNotificationClient 就是作為 WorkerNotificationService 的操作介面,從而給 WorkerNotificationService 傳送 HostsUpdatedRequest。
class WorkerNotificationClient(network.BasicClient):
def __init__(self, addresses, key, verbose, match_intf=False):
super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
addresses,
key,
verbose,
match_intf=match_intf)
def notify_hosts_updated(self, timestamp, update_res):
self._send(HostsUpdatedRequest(timestamp, update_res))
3.2.2 WorkerNotificationService
WorkerNotificationService 會響應 HostsUpdatedRequest。
class WorkerNotificationService(network.BasicService):
NAME = 'worker notification service'
def __init__(self, key, nic, manager):
super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
key,
nic)
self._manager = manager
def _handle(self, req, client_address):
if isinstance(req, HostsUpdatedRequest):
self._manager.handle_hosts_updated(req.timestamp, req.res)
return network.AckResponse()
return super(WorkerNotificationService, self)._handle(req, client_address)
3.2.3 WorkerNotificationManager
handle_hosts_updated 會逐一通知註冊在WorkerNotificationManager 上的 listener(就是使用者程式碼中的 State)。
WorkerNotificationManager 是在 horovod/common/elastic.py 構建,每一個host上執行一個。
notification_manager = WorkerNotificationManager()
具體定義如下:
class WorkerNotificationManager(object):
def __init__(self):
self._lock = threading.Lock()
self._service = None
self._listeners = set()
def init(self, rendezvous_addr=None, rendezvous_port=None,
nic=None, hostname=None, local_rank=None):
with self._lock:
if self._service:
return
rendezvous_addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR)
if not rendezvous_addr:
return
rendezvous_port = rendezvous_port if rendezvous_port is not None else \
int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
local_rank = local_rank if local_rank is not None else \
int(os.environ.get(HOROVOD_LOCAL_RANK))
secret_key = secret.make_secret_key()
self._service = WorkerNotificationService(secret_key, nic, self)
value = (self._service.addresses(), secret_key)
put_data_into_kvstore(rendezvous_addr,
rendezvous_port,
PUT_WORKER_ADDRESSES,
self._create_id(hostname, local_rank),
value)
def register_listener(self, listener):
self._listeners.add(listener)
def remove_listener(self, listener):
self._listeners.remove(listener)
def handle_hosts_updated(self, timestamp, update_res):
for listener in self._listeners:
listener.on_hosts_updated(timestamp, update_res)
3.2.4 通知 State
我們再梳理以下流程:
- 當Driver的定時程式通過節點發現指令碼發現某一個節點被標記為新增或者移除時,它將傳送一個通知到所有workers。
- 每一個 worker 有自己對應的 State,都被儲存於
WorkerNotificationManager . _listeners
。 - _host_messages 會在state 之中註冊host的變化,就是往其 _host_messages 之中放入"host 有變化" 的訊息。
- 因為這個訊息不是一定要立即處理的,所以這裡只是先放入 State 的佇列之中。
邏輯如下:
<--------------------^
+ |
| thread loop |
| |
| +----------------+------------------------------------+
| | ElasticDriver._discovery_thread |
| | + |
| | | |
| | v |
| | HostManager.update_available_hosts |
| | + |
| | | |
| | | |
| | v YES |
| | update_res != no_update ??? +-----+ | +--------------------------+ +----------------------------+
| | + | | | | | |
| | | | | | WorkerNotificationClient | | WorkerNotificationService |
| | | v | notify_hosts_updated | | HostsUpdatedRequest | |
| | | NO | | | | |
| | | _notify_workers_host_changes+------------------------> | | +-------------------> | |
| | v | | | | |
| +-----------------------------------------------------+ +--------------------------+ +----------------+-----------+
| | |
| | |
| | handle_hosts_updated |
v | |
+-------------------->+ v
+------------------+-----------+
| |
| WorkerNotificationManager |
+-----------+ +----------+ +----------+ | |
| | | | | | | |
| State 1 | | State 2 | ...... | State n | <---------------------+ _listeners |
| | | | | | | |
+-----------+ +----------+ +----------+ | |
| |
^ ^ ^ | |
| | | | |
on_hosts_updated | | on_hosts_updated | on_hosts_updated | |
| | | | |
+--------------+-------------------+-------------------------+ handle_hosts_updated |
| |
+------------------------------+
手機如下:
3.2.5 何時處理
何時處理這個通知?在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被呼叫時,state.check_host_updates 會從 _host_messages 中讀取訊息,積累更新。
如 check_host_updates 方法中註釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時丟擲 HostsUpdateInterrupt 異常,具體同步使用 _bcast_object(然後內部呼叫到了 MPI)。
我們接下來就會在 State 的介紹之中,講解check_host_updates 。
0x04 狀態抽象
Horovod 實現了一個 State 物件,這是把機器訓練的模型又做了一步抽象。
每一個Worker擁有一個 State 物件。
-
Horovod 把所有需要在workers之間同步的變數都放進 hvd.elastic.State (比如model parameters,optimizer state,當前epoch和batch進度等等)物件之中。
-
State 物件的作用是定期儲存訓練狀態,在需要時候從 State 物件中恢復機器學習的狀態。這樣在某些worker發生意外錯誤時,可以避免因為狀態被損壞而無法恢復現場。
-
比如,假設一個worker剛好在引數更新過程中突然掛掉,而此時部分梯度更新可能只更新到一半,這個狀態是不可逆而又無法繼續,導致引數是被損壞狀態無法用於恢復訓練。
4.1 State
State 的作用是:在不同的 worker 之中跟蹤記憶體狀態。
主要變數&方法是:
- on_reset : 當需要重啟狀態時候呼叫;
- on_hosts_updated :當有 host 變化時候呼叫,即 向 _host_messages 這個 queue 放入一個訊息;
- commit :使用者會定期呼叫此函式,會儲存狀態(state)到記憶體,檢查 host 更改;
- 當有異常發生時,會丟擲一個 HorovodInternalError 異常,當 hvd.elastic.run 捕獲到這個異常後,會利用最新一次commit中恢復所有狀態。
- 因為commit狀態代價高昂(比如如引數量太大會導致耗時過長),所以需要在"每個batch的處理時間"與"如果出錯,訓練需要從多久前的狀態恢復"之間選取一個平衡點。比如,如果你每訓練10個batches就commit一次,你就把複製時間降低了10倍。但是當發生錯誤時,你需要回滾到10個batches前的狀態。
- check_host_updates : 會從
_host_messages
中讀取訊息,積累更新,如方法中註釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時丟擲異常。具體同步使用_bcast_object
(然後內部呼叫到了 MPI);- 如果節點發現指令碼可以預見到某個節點是需要被移除或新增,Elastic Horvod可以避免回滾操作。當Driver的定時程式通過節點發現指令碼發現某一個節點被標記為新增或者移除時,它將傳送一個通知到所有workers,於是在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被呼叫時,會丟擲一個 HostsUpdateInterrupt 異常。這個異常類似於 HorovodInternalError 異常,但是引數狀態等不會從最近一次commit中恢復,而是從當前實時的引數中恢復。
- 一般來說,如果你的硬體設施是可靠與穩定的,並且你的編排系統會在任務節點移除時提供足夠的告警,你就可低頻次呼叫 state.commit() 函式,同時只在每個batch結束時呼叫相對不耗時的 state.check_host_updates() 來檢查節點變更情況。
- _reset_callbacks :使用者可以註冊一些回撥函式到 hvd.elastic.State 物件中,用於響應worker成員發生變化的情況。
- 比如回撥函式可以處理如下情況:
- 當worker數量發生改變時,學習率需要根據新的world size進行相應改變。
- 對資料集進行重新分割槽。
- 這些回撥函式會在"Horovod被重啟之後"和"狀態在節點間同步之前"這兩個階段中間被呼叫。
- 比如回撥函式可以處理如下情況:
具體定義如下:
class State(object):
"""State representation used for tracking in memory state across workers.
Args:
bcast_object: Function used to broadcast a variable from rank 0 to the other workers.
get_rank: Function that returns the current rank of this worker.
"""
def __init__(self, bcast_object, get_rank):
self._bcast_object = bcast_object
self._rank = get_rank
self._host_messages = queue.Queue()
self._last_updated_timestamp = 0
self._reset_callbacks = []
def on_reset(self):
self._host_messages = queue.Queue()
self.reset()
for callback in self._reset_callbacks:
callback()
def on_hosts_updated(self, timestamp, update_res):
self._host_messages.put((timestamp, update_res))
def commit(self):
self.save()
self.check_host_updates()
def check_host_updates(self):
"""Checks that a notification has been sent indicating that hosts can be added or will be removed.
Raises a `HostsUpdatedInterrupt` if such a notification has been received.
"""
# Iterate through the update messages sent from the server. If the update timestamp
# is greater than the last update timestamp, then trigger a HostsUpdatedException.
# 遍歷更新訊息,如果更新時間戳大於上次更新時間戳,就觸發一個HostUpdateResult
last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
all_update = HostUpdateResult.no_update
while not self._host_messages.empty():
timestamp, update = self._host_messages.get()
if timestamp > last_updated_timestamp:
last_updated_timestamp = timestamp
all_update |= update
# In order to ensure all workers raise the exception at the same time, we need to sync
# the updated state across all the workers.
# TODO(travis): this should be a max allreduce to account for changes in rank 0
# 會從 `_host_messages` 中讀取訊息,積累更新,如方法中註釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時丟擲異常。具體同步使用 `_bcast_object`(然後內部呼叫到了 MPI)
prev_timestamp, self._last_updated_timestamp, all_update = \
self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))
# At this point, updated state is globally consistent across all ranks.
if self._last_updated_timestamp > prev_timestamp:
raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
因此,我們加入 Commit 之後,邏輯如圖:
<--------------------^
+ |
| thread loop |
| |
| +----------------+------------------------------------+
| | ElasticDriver._discovery_thread |
| | + |
| | | |
| | v |
| | HostManager.update_available_hosts |
| | + |
| | | |
| | | |
| | v YES |
| | update_res != no_update ??? +-----+ | +--------------------------+ +----------------------------+
| | + | | | | | |
| | | | | | WorkerNotificationClient | | WorkerNotificationService |
| | | v | notify_hosts_updated | | HostsUpdatedRequest | |
| | | NO | | | | |
| | | _notify_workers_host_changes+------------------------> | | +-------------------> | |
| | v | | | | |
| +-----------------------------------------------------+ +--------------------------+ +----------------+-----------+
| | |
| | |
| | _bcast_object handle_hosts_updated |
v | |
+-------------------->+ +-------------+----------------------+ v
| | | +------------------+-----------+
| | | | |
v v v | WorkerNotificationManager |
+--------------------+ +----+------+ +---+------+ +------+---+ | |
| | | | | | | | | |
| Python xxx.py +-------------------------------------> | State 1 | | State 2 | ...... | State n | <---------------------+ _listeners |
| | commit / check_host_updates | | | | | | | |
+--------------------+ +-----------+ +----------+ +----------+ | |
| |
^ ^ ^ | |
| | | | |
on_hosts_updated | | on_hosts_updated | on_hosts_updated | |
| | | | |
+--------------+-------------------+-------------------------+ handle_hosts_updated |
| |
+------------------------------+
具體如下:
我們接下來介紹各級派生類。
4.2 ObjectState
ObjectState 的目的是組裝成 simple Python objects。
class ObjectState(State):
"""State for simple Python objects.
Every object is specified as a keyword argument, and will be assigned as an attribute.
Args:
bcast_object: Horovod broadcast object function used to sync state dictionary.
get_rank: Horovod rank function used to identify is this process is the coordinator.
kwargs: Properties to sync, will be exposed as attributes of the object.
"""
def __init__(self, bcast_object, get_rank, **kwargs):
self._bcast_object = bcast_object
self._saved_state = kwargs
self._set_attrs()
super(ObjectState, self).__init__(bcast_object=bcast_object, get_rank=get_rank)
def save(self):
new_state = {}
for attr in self._saved_state.keys():
new_state[attr] = getattr(self, attr)
self._saved_state = new_state
def restore(self):
self._set_attrs()
def sync(self):
if self._saved_state:
self._saved_state = self._bcast_object(self._saved_state)
self._set_attrs()
def _set_attrs(self):
for attr, value in self._saved_state.items():
setattr(self, attr, value)
4.3 TensorFlowKerasState
Horovod 預設已提供標準的TensorFlow,Keras和PyTorch的狀態保持和恢復實現,如果需要在某些場景下自定義,可以過載 hvd.elastic.State 這個物件。
TensorFlowKerasState 是 TensorFlow Keras model and optimizer 的狀態抽象。
初始化函式中,會設定各種相關變數,比如廣播函式。
class TensorFlowKerasState(ObjectState):
def __init__(self, model, optimizer=None, backend=None, **kwargs):
self.model = model
if not _model_built(model):
raise ValueError('Model must be built first. Run `model.build(input_shape)`.')
self.optimizer = optimizer or model.optimizer
self.backend = backend
self._save_model()
if not backend or _executing_eagerly():
self._bcast_model = lambda: _broadcast_model(self.model, self.optimizer, backend=self.backend)
bcast_object = broadcast_object
else:
# For TensorFlow v1, we need to reuse the broadcast op to prevent incrementing the uids
bcast_op = broadcast_variables(_global_variables(), root_rank=0)
self._bcast_model = lambda: self.backend.get_session().run(bcast_op)
bcast_object = broadcast_object_fn(session=self.backend.get_session())
super(TensorFlowKerasState, self).__init__(bcast_object=bcast_object,
get_rank=rank,
**kwargs)
具體實現了幾個方法,基本就是 儲存,恢復 state,同步。
def save(self):
self._save_model()
super(TensorFlowKerasState, self).save()
def restore(self):
self._load_model()
super(TensorFlowKerasState, self).restore()
def sync(self):
self._bcast_model()
self._save_model()
super(TensorFlowKerasState, self).sync()
def _save_model(self):
if _executing_eagerly():
self._saved_model_state = [tf.identity(var) for var in self.model.variables]
self._saved_optimizer_state = [tf.identity(var) for var in self.optimizer.variables()]
else:
self._saved_model_state = self.model.get_weights()
self._saved_optimizer_state = self.optimizer.get_weights()
def _load_model(self):
if _executing_eagerly():
for var, saved_var in zip(self.model.variables, self._saved_model_state):
var.assign(saved_var)
for var, saved_var in zip(self.optimizer.variables(), self._saved_optimizer_state):
var.assign(saved_var)
else:
self.model.set_weights(self._saved_model_state)
self.optimizer.set_weights(self._saved_optimizer_state)
4.4 Restore
我們看到了,restore 會從記憶體中恢復模型。
def restore(self):
self._load_model()
super(TensorFlowKerasState, self).restore()
於是,我們有一個問題:何時呼叫restore?
發現是如果 horovod 捕獲了 HorovodInternalError 之後,會用 restore 來恢復。
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False
try:
while True:
if not skip_sync:
state.sync()
try:
return func(state, *args, **kwargs)
except HorovodInternalError:
state.restore() # 在這裡呼叫
skip_sync = False
except HostsUpdatedInterrupt as e:
skip_sync = e.skip_sync
reset()
state.on_reset()
finally:
notification_manager.remove_listener(state)
return wrapper
0x05 總結
我們再次重複一下,發現節點機制的幾個關鍵設計點:
- 有節點變化時候,如何即時發現?Horovod是通過定期呼叫完成。
- 發現節點變化時候,如何通知各個worker? Horovod通過構建了一個通知機制完成。即,每個worker把自己註冊到WorkerNotificationManager 之上,當有節點變化時候,WorkerNotificationManager 會逐一通知這些worker。
- worker得到通知之後,如何處理?Horovod 把worker的狀態在深度框架上進一步封裝成各種State,得到通知之後就會呼叫State的對應callback函式,或者同步狀態,或者進行其他處理。
簡化版總體邏輯如下:
+-----------------------------v
^ thread loop |
| |
+----------------+----------------------+ |
| ElasticDriver._discovery_thread | |
_notify_workers_host_changes | | |
| | |
+------------------+ | |
| | | |
| | HostManager.update_available_hosts | |
| | | |
| +-----------------+---------------------+ |
| ^ |
| | |
| | |
| +----------<---------------+ v
v
+---------------------------+ HostsUpdatedReques +----------------------------+ handle_hosts_updated +----------------------------+
| | | | | |
| WorkerNotificationClient +----------------------> | WorkerNotificationService | +------------------> | WorkerNotificationManager |
| | | | | |
+---------------------------+ +----------------------------+ +--------+-------------------+
|
|
| on_hosts_updated
|
v
+----+---+
| State |
+--------+
手機如下:
至此,發現節點部分介紹完畢,因為本文只是使用了 WorkerNotificationService 完成通知,但是沒有深入介紹,所以下一篇介紹內部廣播和通知機制。
0xEE 個人資訊
★★★★★★關於生活和技術的思考★★★★★★
微信公眾賬號:羅西的思考
如果您想及時得到個人撰寫文章的訊息推送,或者想看看個人推薦的技術資料,敬請關注。