[原始碼解析] TensorFlow 分散式之 ClusterCoordinator
本文我們主要來看看ParameterServerStrategy如何分發計算,也就是ClusterCoordinator如何運作。這是TF分散式的最後一篇。
安利兩個github,都是非常好的學習資料,推薦。
https://github.com/yuhuiaws/ML-study
https://github.com/Jack47/hack-SysML
另外推薦西門宇少的最新大作讓Pipeline在Transformer LM上沿著Token level並行起來——TeraPipe。
本系列其他文章是:
[翻譯] TensorFlow 分散式之論文篇 "Implementation of Control Flow in TensorFlow"
[原始碼解析] TensorFlow 分散式環境(1) --- 總體架構
[原始碼解析] TensorFlow 分散式環境(2)---Master 靜態邏輯
[原始碼解析] TensorFlow 分散式環境(3)--- Worker 靜態邏輯
[原始碼解析] TensorFlow 分散式環境(4) --- WorkerCache
[原始碼解析] TensorFlow 分散式環境(5) --- Session
[原始碼解析] TensorFlow 分散式環境(7) --- Worker 動態邏輯
[原始碼解析] TensorFlow 分散式環境(8) --- 通訊機制
[原始碼解析] TensorFlow 分散式 DistributedStrategy 之基礎篇
[原始碼解析] TensorFlow 分散式之 MirroredStrategy
[原始碼解析] TensorFlow 分散式之 MirroredStrategy 分發計算
[原始碼解析] TensorFlow 分散式之 ParameterServerStrategy V1
[原始碼解析] TensorFlow 分散式之 ParameterServerStrategy V2
1. 思路
TensorFlow 2 推薦使用一種基於中央協調的架構來進行引數伺服器訓練。每個工作者和引數伺服器都執行一個 tf.distribution.Server,在此基礎上,一個協調者任務負責在工作者和引數伺服器上建立資源,排程功能,並協調訓練。協調器使用 tf.distribution.experimental.coordinator.ClusterCoordinator 來協調叢集,使用 tf.distribution.experimental.ParameterServerStrategy 來定義引數伺服器上的變數和工作者的計算。
ClusterCoordinator 是一個用於安排和協調遠端函式執行的物件。該類用於建立容錯(fault-tolerant)資源和排程函式到遠端 TensorFlow 伺服器。目前該類不支援獨立使用,它應該與旨在與之合作的 tf.distribution 策略一起使用。ClusterCoordinator 類目前只適用於和 tf.distribution.experimental.ParameterServerStrategy 一起工作。
1.1 使用
在使用 ParameterServerStrategy 定義所有的計算後,使用者可以使用 tf.distribution.experimental.coordinator.ClusterCoordinator 類來建立資源並將訓練步驟分配給遠端工作者。
首先,我們來建立一個 ClusterCoordinator 物件並傳入策略物件。
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)
其次,創一個屬於每個工作者(per-worker)的資料集和一個迭代器。在下面程式碼的 per_worker_dataset_fn 中,建議將 dataset_fn 包裹到 strategy.distribution_datasets_from_function 中,以允許無縫高效的把資料預取(prefetching )到 GPU。
@tf.function
def per_worker_dataset_fn():
return strategy.distribute_datasets_from_function(dataset_fn)
per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
最後一步是使用 ClusterCoordinator.schedule 將計算分配給遠端工作者。
- schedule 方法把一個 tf.function 插入佇列,並立即返回一個 future-like 的 RemoteValue 。佇列之中的函式將被派發給後臺執行緒中的遠端工作者,RemoteValue 將被非同步填充結果。
- 使用者可以使用 join 方法( ClusterCoordinator.join )來等待所有被規劃(scheduled)的函式執行。
@tf.function
def step_fn(iterator):
return next(iterator)
num_epoches = 4
steps_per_epoch = 5
for i in range(num_epoches):
accuracy.reset_states()
for _ in range(steps_per_epoch):
coordinator.schedule(step_fn, args=(per_worker_iterator,))
# Wait at epoch boundaries.
coordinator.join()
print ("Finished epoch %d, accuracy is %f." % (i, accuracy.result().numpy()))
下面是如何得到 RemoteValue 的結果。
loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
print ("Final loss is %f" % loss.fetch())
使用者也可以啟動所有的步驟(steps),並在等待完成時做一些事情。
for _ in range(total_steps):
coordinator.schedule(step_fn, args=(per_worker_iterator,))
while not coordinator.done():
time.sleep(10)
# Do something like logging metrics or writing checkpoints.
1.2 問題點
依據前面的程式碼,我們總結出來問題點如下:
- Worker 如何知道使用哪些裝置?
- 如何具體執行使用者函式?
- 如何獲取資料?
接下來我們就嘗試通過分析程式碼來回答這些問題。
2. 定義
ClusterCoordinator 的主要思路如下。
- 協調者不是訓練工作者之一,相反,它負責建立資源,如變數和資料集,排程 "tf.function",儲存檢查點等等。
- 為了使訓練工作順利進行,協調者派遣 "tf.function" 在遠端工作者上執行。
- 在收到協調者的請求後,工作者通過從引數伺服器讀取變數、執行操作和更新引數伺服器上的變數來執行 "tf.function"。
- 每個工作者只處理來自協調者的請求,並與引數伺服器進行通訊。而不與叢集中的其他工作者直接互動。
ClusterCoordinator 定義具體如下,我們可以看到,其主要是配置了 _strategy 成員變數,生成了 _cluster 成員變數。
@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
def __new__(cls, strategy):
# ClusterCoordinator is kept as a single instance to a given Strategy .
if strategy._cluster_coordinator is None:
strategy._cluster_coordinator = super(
ClusterCoordinator, cls).__new__(cls)
return strategy._cluster_coordinator
def __init__(self, strategy):
"""Initialization of a ClusterCoordinator instance.
Args:
strategy: a supported tf.distribute.Strategy object. Currently, only
tf.distribute.experimental.ParameterServerStrategy is supported.
Raises:
ValueError: if the strategy being used is not supported.
"""
if not getattr(self, "_has_initialized", False):
if not isinstance(strategy,
parameter_server_strategy_v2.ParameterServerStrategyV2):
raise ValueError(
"Only tf.distribute.experimental.ParameterServerStrategy "
"is supported to work with "
" tf.distribute.experimental.coordinator.ClusterCoordinator "
"currently.")
self._strategy = strategy
self.strategy.extended._used_with_coordinator = True
self._cluster = Cluster(strategy)
self._has_initialized = True
def __del__(self):
self._cluster.stop()
@property
def strategy(self):
"""Returns the Strategy associated with the ClusterCoordinator ."""
return self._strategy
2.1 Schedule
由 ClusterCoordinator 物件提供的最重要的 API 是 schedule,其會分派 tf.function 到一個工作者,以便非同步執行,具體如下:
- 該方法是非阻塞的,因為它把 fn 插入佇列,並立即返回 tf.distribution.experimental.coordinator.RemoteValue 物件。fn 排隊等待稍後執行。
- 在佇列之中排隊的函式將被派發給後臺執行緒中的遠端工作者來非同步執行,他們的 RemoteValue 將被非同步賦值。
- 由於 schedule 不需要分配一個工作者,傳遞進來的 tf.function 可以在任何可用的工作者上執行。
- 可以呼叫 fetch 來等待函式執行完成,並從遠端工作者那裡獲取其輸出。另一方面,也可以呼叫 tf.distribution.experimental.coordinator.ClusterCoordinator.join 來等待所有預定的函式完成。
失敗和容錯的策略如下:
- 由於工作者在執行函式的任何時候都可能失敗,所以函式有可能被部分執行,但是 tf.distribution.experimental.coordinator.ClusterCoordinator 保證在這些事件中,函式最終將在任何可用的工作者上執行。
- schedule 保證 fn 至少在工作者上執行一次;如果其對應的工作者在執行過程中失敗,由於函式的執行不是原子性的,所以一個函式可能被執行多次。
- 如果被執行的工作者在結束之前變得不可用,該函式將在另一個可用的工作者上重試。
- 如果任何先前安排的函式出現錯誤,schedule 將丟擲其中任何一個錯誤,並清除到目前為止收集的錯誤。使用者可以在返回的 tf.distribution.experimental.coordinator.RemoteValue 上呼叫 fetch 來檢查它們是否已經執行、失敗或取消,如果需要,可以重新安排相應的函式。當 schedule 引發異常時,它保證沒有任何函式仍在執行。
Schedule 的具體定義如下,資料迭代器作為引數之一會和 fn 一起被傳入。
def schedule(self, fn, args=None, kwargs=None):
"""Schedules fn to be dispatched to a worker for asynchronous execution.
This method is non-blocking in that it queues the fn which will be
executed later and returns a
tf.distribute.experimental.coordinator.RemoteValue object immediately.
fetch can be called on it to wait for the function execution to finish
and retrieve its output from a remote worker. On the other hand, call
tf.distribute.experimental.coordinator.ClusterCoordinator.join to wait for
all scheduled functions to finish.
schedule guarantees that fn will be executed on a worker at least once;
it could be more than once if its corresponding worker fails in the middle
of its execution. Note that since worker can fail at any point when
executing the function, it is possible that the function is partially
executed, but tf.distribute.experimental.coordinator.ClusterCoordinator
guarantees that in those events, the function will eventually be executed on
any worker that is available.
If any previously scheduled function raises an error, schedule will raise
any one of those errors, and clear the errors collected so far. What happens
here, some of the previously scheduled functions may have not been executed.
User can call fetch on the returned
tf.distribute.experimental.coordinator.RemoteValue to inspect if they have
executed, failed, or cancelled, and reschedule the corresponding function if
needed.
When schedule raises, it guarantees that there is no function that is
still being executed.
At this time, there is no support of worker assignment for function
execution, or priority of the workers.
args and kwargs are the arguments passed into fn , when fn is
executed on a worker. They can be
tf.distribute.experimental.coordinator.PerWorkerValues and in this case,
the argument will be substituted with the corresponding component on the
target worker. Arguments that are not
tf.distribute.experimental.coordinator.PerWorkerValues will be passed into
fn as-is. Currently, tf.distribute.experimental.coordinator.RemoteValue
is not supported to be input args or kwargs .
Args:
fn: A tf.function ; the function to be dispatched to a worker for
execution asynchronously. Regular python funtion is not supported to be
scheduled.
args: Positional arguments for fn .
kwargs: Keyword arguments for fn .
Returns:
A tf.distribute.experimental.coordinator.RemoteValue object that
represents the output of the function scheduled.
Raises:
Exception: one of the exceptions caught by the coordinator from any
previously scheduled function, since the last time an error was thrown
or since the beginning of the program.
"""
if not isinstance(fn,
(def_function.Function, tf_function.ConcreteFunction)):
raise TypeError(
" tf.distribute.experimental.coordinator.ClusterCoordinator.schedule "
" only accepts a tf.function or a concrete function.")
# Slot variables are usually created during function tracing time; thus
# schedule needs to be called within the strategy.scope() .
with self.strategy.scope():
self.strategy.extended._being_scheduled = True
remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
self.strategy.extended._being_scheduled = False
return remote_value
2.2 Join
Join 方法的作用是阻塞直到所有預定的函式都執行完畢,其具體特點如下:
- 如果任何先前安排的函式產生錯誤,join 將因為丟擲一個錯誤而失敗,並清除到目前為止收集的錯誤。如果發生這種情況,一些先前安排的函式可能沒有被執行。
- 使用者可以對返回的 tf.distribution.experimental.coordinator.RemoteValue 呼叫 fetch 來檢查它們是否已經執行、失敗或取消了。
- 如果一些已經取消的函式需要重新安排,使用者應該再次呼叫 schedule 。
- 當 join 返回或丟擲異常時,它保證沒有任何函式仍在執行。
def join(self):
"""Blocks until all the scheduled functions have finished execution.
If any previously scheduled function raises an error, join will fail by
raising any one of those errors, and clear the errors collected so far. If
this happens, some of the previously scheduled functions may have not been
executed. Users can call fetch on the returned
tf.distribute.experimental.coordinator.RemoteValue to inspect if they have
executed, failed, or cancelled. If some that have been cancelled need to be
rescheduled, users should call schedule with the function again.
When join returns or raises, it guarantees that there is no function that
is still being executed.
Raises:
Exception: one of the exceptions caught by the coordinator by any
previously scheduled function since the last time an error was thrown or
since the beginning of the program.
"""
self._cluster.join()
2.3 Done
Done 方法返回所有分發的函式是否已經執行完畢。如果任何先前分發的函式引發錯誤,done'將會失敗。
def done(self):
"""Returns whether all the scheduled functions have finished execution.
If any previously scheduled function raises an error, done will fail by
raising any one of those errors.
When done returns True or raises, it guarantees that there is no function
that is still being executed.
Returns:
Whether all the scheduled functions have finished execution.
Raises:
Exception: one of the exceptions caught by the coordinator by any
previously scheduled function since the last time an error was thrown or
since the beginning of the program.
"""
return self._cluster.done()
2.4 Fetch
Fetch 會獲取 remote values 的結果。
def fetch(self, val):
"""Blocking call to fetch results from the remote values.
This is a wrapper around
tf.distribute.experimental.coordinator.RemoteValue.fetch for a
RemoteValue structure; it returns the execution results of
RemoteValue s. If not ready, wait for them while blocking the caller.
Example:
```python
strategy = ...
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy)
def dataset_fn():
return tf.data.Dataset.from_tensor_slices([1, 1, 1])
with strategy.scope():
v = tf.Variable(initial_value=0)
@tf.function
def worker_fn(iterator):
def replica_fn(x):
v.assign_add(x)
return v.read_value()
return strategy.run(replica_fn, args=(next(iterator),))
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
distributed_iterator = iter(distributed_dataset)
result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
assert coordinator.fetch(result) == 1
```
Args:
val: The value to fetch the results from. If this is structure of
tf.distribute.experimental.coordinator.RemoteValue , fetch() will be
called on the individual
tf.distribute.experimental.coordinator.RemoteValue to get the result.
Returns:
If val is a tf.distribute.experimental.coordinator.RemoteValue or a
structure of tf.distribute.experimental.coordinator.RemoteValue s,
return the fetched tf.distribute.experimental.coordinator.RemoteValue
values immediately if they are available, or block the call until they are
available, and return the fetched
tf.distribute.experimental.coordinator.RemoteValue values with the same
structure. If val is other types, return it as-is.
"""
def _maybe_fetch(val):
if isinstance(val, RemoteValue):
return val.fetch()
else:
return val
return nest.map_structure(_maybe_fetch, val)
3. 資料
除了排程遠端函式,ClusterCoordinator 還幫助在所有工作者上建立資料集,並當一個工作者從失敗中恢復時重建這些資料集。使用者可以通過呼叫 dataset_fn 來在worker裝置上建立資料集。使用例子如下:
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy=strategy)
@tf.function
def worker_fn(iterator):
return next(iterator)
def per_worker_dataset_fn():
return strategy.distribute_datasets_from_function(
lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn)
per_worker_iter = iter(per_worker_dataset)
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
assert remote_value.fetch() == 3
3.1 建立資料集
上面程式碼使用了 create_per_worker_dataset 在worker上建立資料集,這些資料集由 dataset_fn 生成,並返回一個代表這些資料集的集合。在這樣的資料集集合上呼叫 iter 會返回一個 tf.distribution.experimental.coordinator.PerWorkerValues,它是一個迭代器的集合,其中的迭代器已經被放置在各個工作者上。
需要注意,不支援在迭代器的 "PerWorkerValues"上直接呼叫 "next"。該迭代器應該是作為一個引數傳遞給 tf.distribution.experimental.coordinator.ClusterCoordinator.schedule 。當計劃的函式即將被工作者執行時,該函式將收到與該工作者相對應的單個迭代器。該函式可以對該迭代器呼叫 next 方法。
目前,schedule 方法假定工作者都是相同的,因此假設不同工作者上的資料集是一樣的,除非它們包含 dataset.shuffle 操作,並且沒有設定隨機種子,在這種情況下,它們的洗牌方式會不同。正因為如此,建議將資料集無限地重複,並安排有限的步驟,而不是依賴於資料集的 OutOfRangeError 來結束。
def create_per_worker_dataset(self, dataset_fn):
"""Create dataset on workers by calling dataset_fn on worker devices.
This creates the given dataset generated by dataset_fn on workers
and returns an object that represents the collection of those individual
datasets. Calling iter on such collection of datasets returns a
tf.distribute.experimental.coordinator.PerWorkerValues , which is a
collection of iterators, where the iterators have been placed on respective
workers.
Calling next on a PerWorkerValues of iterator is unsupported. The
iterator is meant to be passed as an argument into
tf.distribute.experimental.coordinator.ClusterCoordinator.schedule . When
the scheduled function is about to be executed by a worker, the
function will receive the individual iterator that corresponds to the
worker. The next method can be called on an iterator inside a
scheduled function when the iterator is an input of the function.
Currently the schedule method assumes workers are all the same and thus
assumes the datasets on different workers are the same, except they may be
shuffled differently if they contain a dataset.shuffle operation and a
random seed is not set. Because of this, we also recommend the datasets to
be repeated indefinitely and schedule a finite number of steps instead of
relying on the OutOfRangeError from a dataset.
Args:
dataset_fn: The dataset function that returns a dataset. This is to be
executed on the workers.
Returns:
An object that represents the collection of those individual
datasets. iter is expected to be called on this object that returns
a tf.distribute.experimental.coordinator.PerWorkerValues of the
iterators (that are on the workers).
"""
return values_lib.get_per_worker_dataset(dataset_fn, self)
get_per_worker_dataset 則返回 PerWorkerDatasetFromDataset 或者 PerWorkerDatasetFromDatasetFunction。
def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
if callable(dataset_or_dataset_fn):
return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
coordinator)
else:
return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)
3.2 PerWorkerDistributedDataset
PerWorkerDistributedDataset 代表了從一個資料集建立的工作者使用的分散式資料集。
class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
"""Represents worker-distributed datasets created from a dataset."""
def __init__(self, dataset, coordinator):
"""Makes an iterable from datasets created by the given dataset.
It creates a dataset_fn which deserializes a dataset from a graph under the
hood.
Args:
dataset: A tf.data.Dataset, a DistributedDataset or a
DistributedDatasetsFromFunction
coordinator: a ClusterCoordinator object, used to create dataset
resources.
"""
if isinstance(dataset, input_lib.DistributedDataset):
original_dataset = dataset._original_dataset
serialized = serialize_dataset_to_graph(original_dataset)
def dataset_fn():
deserialized = deserialize_dataset_from_graph(
serialized, original_dataset.element_spec)
dataset.build(dataset_to_replace=deserialized)
return dataset
elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
def dataset_fn():
dataset.build()
return dataset
elif isinstance(dataset, dataset_ops.Dataset):
serialized = serialize_dataset_to_graph(dataset)
def dataset_fn():
return deserialize_dataset_from_graph(serialized, dataset.element_spec)
else:
raise ValueError("Unexpected dataset type!")
super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)
3.3 PerWorkerDatasetFromDatasetFunction
PerWorkerDistributedDataset 代表了從一個資料集方法建立的工作者使用的分散式資料集。
在 iter 之中有:
-
呼叫 _create_per_worker_iterator 得到一個 iter(dataset)。
-
呼叫 self._coordinator._create_per_worker_resources 為每工作者生成一個 iterator。
-
最後返回一個 PerWorkerDistributedIterator。
class PerWorkerDatasetFromDatasetFunction(object):
"""Represents worker-distributed datasets created from dataset function."""
def __init__(self, dataset_fn, coordinator):
"""Makes an iterable from datasets created by the given function.
Args:
dataset_fn: A function that returns a Dataset .
coordinator: a ClusterCoordinator object, used to create dataset
resources.
"""
def disallow_variable_creation(next_creator, **kwargs):
raise ValueError("Creating variables in dataset_fn is not allowed.")
if isinstance(dataset_fn, def_function.Function):
with variable_scope.variable_creator_scope(disallow_variable_creation):
dataset_fn = dataset_fn.get_concrete_function()
elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
with variable_scope.variable_creator_scope(disallow_variable_creation):
dataset_fn = def_function.function(dataset_fn).get_concrete_function()
self._dataset_fn = dataset_fn
self._coordinator = coordinator
self._element_spec = None
def __iter__(self):
# We would like users to create iterators outside tf.function s so that we
# can track them.
if (not context.executing_eagerly() or
ops.get_default_graph().building_function):
raise RuntimeError(
"__iter__() is not supported inside of tf.function or in graph mode.")
def _create_per_worker_iterator():
dataset = self._dataset_fn()
return iter(dataset)
# If PerWorkerDatasetFromDatasetFunction.__iter__ is called multiple
# times, for the same object it should only create and register resource
# once. Using object id to distinguish different iterator resources.
per_worker_iterator = self._coordinator._create_per_worker_resources(
_create_per_worker_iterator)
# Setting type_spec of each RemoteValue so that functions taking these
# RemoteValues as inputs can be traced.
for iterator_remote_value in per_worker_iterator._values:
iterator_remote_value._type_spec = (
input_lib.get_iterator_spec_from_dataset(
self._coordinator.strategy, self._dataset_fn.structured_outputs))
return PerWorkerDistributedIterator(per_worker_iterator._values)
@property
def element_spec(self):
"""The type specification of an element of this dataset.
This property is subject to change without notice.
"""
return self._dataset_fn.structured_outputs.element_spec
3.4 _create_per_worker_resources
_create_per_worker_resources 會呼叫各個工作者的方法來讓每個工作者得到資料。
def _create_per_worker_resources(self, fn, args=None, kwargs=None):
"""Synchronously create resources on the workers.
The resources are represented by
tf.distribute.experimental.coordinator.RemoteValue s.
Args:
fn: The function to be dispatched to all workers for execution
asynchronously.
args: Positional arguments for fn .
kwargs: Keyword arguments for fn .
Returns:
A tf.distribute.experimental.coordinator.PerWorkerValues object, which
wraps a tuple of tf.distribute.experimental.coordinator.RemoteValue
objects.
"""
results = []
for w in self._cluster.workers:
results.append(w.create_resource(fn, args=args, kwargs=kwargs))
return PerWorkerValues(tuple(results))
3.5 PerWorkerValues
PerWorkerValues 是一個容納 value 列表的容器,每個工作者對應一個 value。Tf.distribution.experimental.coordinator.PerWorkerValues 包含一個值的集合,其中每個值都位於其相應的工作者上,當被用作 tf.distribution.experimental.coordinator.ClusterCoordinator.schedule() 的 args 或 kwargs 時,某一個工作者的特定值將被傳遞到該工作者上執行的函式中。
建立 tf.distribution.experimental.coordinator.PerWorkerValues 物件的唯一路徑是通過在 ClusterCoordinator.create_per_worker_dataset 返回的分散式資料集例項上呼叫 iter 。目前還不支援建立自定義 tf.distribution.experimental.coordinator.PerWorkerValues 的機制。
@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
"""A container that holds a list of values, one value per worker.
tf.distribute.experimental.coordinator.PerWorkerValues contains a collection
of values, where each of the values is located on its corresponding worker,
and upon being used as one of the args or kwargs of
tf.distribute.experimental.coordinator.ClusterCoordinator.schedule() , the
value specific to a worker will be passed into the function being executed at
that corresponding worker.
Currently, the only supported path to create an object of
tf.distribute.experimental.coordinator.PerWorkerValues is through calling
iter on a ClusterCoordinator.create_per_worker_dataset -returned
distributed dataset instance. The mechanism to create a custom
tf.distribute.experimental.coordinator.PerWorkerValues is not yet supported.
"""
def __init__(self, values):
for v in values:
if not isinstance(v, RemoteValue):
raise AssertionError(
" PerWorkerValues should only take RemoteValue s.")
self._values = tuple(values)
@property
def _type_spec(self):
return PerWorkerValuesTypeSpec(
self._values[0]._type_spec,
type(self))
獲取資料的邏輯如下:
4. Cluster
Cluster 才是業務執行者。
4.1 定義
Cluster 是一個工作者叢集。在初始化方法之中,會做如下處理:
- 設定如何忽略引數伺服器暫時錯誤。
- 設定工作者的裝置名字。
- 生成一系列工作者。
這裡要注意的是如何忽略因為工作者瞬時連線錯誤而報告的故障。
- 工作者和引數伺服器之間的瞬時連線問題會由工作者轉達給協調者,這將導致協調者認為存在引數伺服器故障。
- 瞬時與永久的引數伺服器故障之間的區別是工作者報告的數量。當這個環境變數設定為正整數 K 時,協調器忽略最多 K 個失敗報告,也就是說,只有超過 K 個執行錯誤,並且這些錯誤是因為同一個引數伺服器例項導致的,我們才認為引數伺服器例項遇到了失敗。
class Cluster(object):
"""A cluster with workers.
We assume all function errors are fatal and based on this assumption our
error reporting logic is:
1) Both schedule and join can raise a non-retryable error which is the
first error seen by the coordinator from any previously scheduled functions.
2) When an error is raised, there is no guarantee on how many previously
scheduled functions have been executed; functions that have not been executed
will be thrown away and marked as cancelled.
3) After an error is raised, the internal state of error will be cleared.
I.e. functions can continue to be scheduled and subsequent calls of schedule
or join will not raise the same error again.
Attributes:
failure_handler: The failure handler used to handler worker preemption
failure.
workers: a list of Worker objects in the cluster.
"""
def __init__(self, strategy):
"""Initializes the cluster instance."""
self._num_workers = strategy._num_workers
self._num_ps = strategy._num_ps
# 如何忽略引數伺服器暫時錯誤
self._transient_ps_failures_threshold = int(
os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
self._potential_ps_failures_lock = threading.Lock()
self._potential_ps_failures_count = [0] * self._num_ps
self._closure_queue = _CoordinatedClosureQueue()
self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
self)
# 設定 worker 的裝置名字
worker_device_strings = [
"/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
]
# 生成 Workers
self.workers = [
Worker(i, w, self) for i, w in enumerate(worker_device_strings)
]
4.2 Schedule
這個類提供的最重要的API是 "schedule"/"join" 這對函式。"schedule" API是非阻塞的,它把一個 "tf.function "插入佇列,並立即返回一個 "RemoteValue"。
def schedule(self, function, args, kwargs):
"""Schedules function to be dispatched to a worker for execution.
Args:
function: The function to be dispatched to a worker for execution
asynchronously.
args: Positional arguments for fn .
kwargs: Keyword arguments for fn .
Returns:
A RemoteValue object.
"""
closure = Closure(
function,
self._closure_queue._cancellation_mgr,
args=args,
kwargs=kwargs)
self._closure_queue.put(closure)
return closure.output_remote_value
def join(self):
"""Blocks until all scheduled functions are executed."""
self._closure_queue.wait()
具體邏輯如下,虛線表示資料集被傳入,這裡的 Queue 是 from six.moves import queue 引入的 queue.Queue,我們接下來在_CoordinatedClosureQueue之中會見到。
或者我們從官方文件圖來看,目前完成的是左邊圓圈部分。
4.3 停止
停止程式碼如下,具體是呼叫佇列的處理方法。
def stop(self):
"""Stop worker, worker preemption threads, and the closure queue."""
self.failure_handler.stop()
for worker in self.workers:
worker.stop()
self._closure_queue.stop()
def done(self):
"""Returns true if all scheduled functions are executed."""
return self._closure_queue.done()
5. 任務 Closure
Closure 的作用是把任務封裝起來,並且提供了其他功能。
class Closure(object):
"""Hold a function to be scheduled and its arguments."""
def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
if not callable(function):
raise ValueError("Function passed to ClusterCoordinator.schedule must "
"be a callable object.")
self._args = args or ()
self._kwargs = kwargs or {}
_disallow_remote_value_as_input(self._args)
_disallow_remote_value_as_input(self._kwargs)
if isinstance(function, def_function.Function):
replica_args = _select_worker_slice(0, self._args)
replica_kwargs = _select_worker_slice(0, self._kwargs)
# Note: no need to handle function registration failure since this kind of
# failure will not raise exceptions as designed in the runtime. The
# coordinator has to rely on subsequent operations that raise to catch
# function registration failure.
# Record the function tracing overhead. Note that we pass in the tracing
# count of the def_function.Function as a state tracker, so that metrics
# will only record the time for actual function tracing (i.e., excluding
# function cache lookups).
with metric_utils.monitored_timer(
"function_tracing", state_tracker=function._get_tracing_count):
self._concrete_function = function.get_concrete_function(
*nest.map_structure(_maybe_as_type_spec, replica_args),
**nest.map_structure(_maybe_as_type_spec, replica_kwargs))
elif isinstance(function, tf_function.ConcreteFunction):
self._concrete_function = function
if hasattr(self, "_concrete_function"):
# If we have a concrete function, we get to retrieve the output type spec
# via the structured_output.
output_type_spec = func_graph.convert_structure_to_signature(
self._concrete_function.structured_outputs)
self._function = cancellation_mgr.get_cancelable_function(
self._concrete_function)
else:
# Otherwise (i.e. what is passed in is a regular python function), we have
# no such information.
output_type_spec = None
self._function = function
self.output_remote_value = RemoteValueImpl(self, output_type_spec)
5.1 執行
Closure 的 execute_on 負責執行,具體是在指定的裝置上執行 self._function,就是使用者自定義的 function。需要注意的是,with context.executor_scope(worker.executor) 使用了 context。
def execute_on(self, worker):
"""Executes the closure on the given worker.
Args:
worker: a Worker object.
"""
replica_args = _select_worker_slice(worker.worker_index, self._args)
replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)
e = (
_maybe_rebuild_remote_values(worker, replica_args) or
_maybe_rebuild_remote_values(worker, replica_kwargs))
if e:
if not isinstance(e, InputError):
e = InputError(e)
self.output_remote_value._set_error(e)
return
with ops.device(worker.device_name): # 在指定裝置上
with context.executor_scope(worker.executor): # 通過上下文
with metric_utils.monitored_timer("closure_execution"):
output_values = self._function( # 執行使用者的引數
*nest.map_structure(_maybe_get_remote_value, replica_args),
**nest.map_structure(_maybe_get_remote_value, replica_kwargs))
self.output_remote_value._set_values(output_values)
Self._function 是使用者自定義的 function,我們再給出一個方法示例,可以看出來可以使用 strategy.run 把訓練方法分發到遠端工作者進行訓練。
@tf.function
def worker_fn(iterator):
def replica_fn(inputs):
batch_data, labels = inputs
# calculate gradient, applying gradient, metrics update etc.
strategy.run(replica_fn, args=(next(iterator),))
5.2 取消
使用者可以設定取消 Closure,就是在返回值之中做下設定。
def mark_cancelled(self):
self.output_remote_value._set_error(
errors.CancelledError(
None, None, "The corresponding function is "
"cancelled. Please reschedule the function."))
5.3 ResourceClosure
ResourceClosure 是派生類,把 Closure 用 RemoteValue 包裝起來。實際上使用的都是 ResourceClosure。
class ResourceClosure(Closure):
def build_output_remote_value(self):
if self._output_remote_value_ref is None:
# We need to remember the Closure object in the RemoteValue here.
ret = RemoteValueImpl(self, self._output_type_spec)
self._output_remote_value_ref = weakref.ref(ret)
return ret
else:
return self._output_remote_value_ref()
6. 佇列
_CoordinatedClosureQueue 是任務所在的佇列。
6.1 定義
from six.moves import queue
class _CoordinatedClosureQueue(object):
"""Manage a queue of closures, inflight count and errors from execution.
This class is thread-safe.
"""
def __init__(self):
# self._inflight_closure_count only tracks the number of inflight closures
# that are "in generation". Once an error occurs, error generation is
# incremented and all subsequent arriving closures (from inflight) are
# considered "out of generation".
self._inflight_closure_count = 0
self._queue_lock = threading.Lock()
# Condition indicating that all pending closures (either queued or inflight)
# have been processed, failed, or cancelled.
self._stop_waiting_condition = threading.Condition(self._queue_lock)
# Condition indicating that an item becomes available in queue (not empty).
self._closures_queued_condition = threading.Condition(self._queue_lock)
self._should_process_closures = True
# Condition indicating that a queue slot becomes available (not full).
# Note that even with "infinite" queue size, there is still a "practical"
# size limit for the queue depending on host memory capacity, and thus the
# queue will eventually become full with a lot of enqueued closures.
self._queue_free_slot_condition = threading.Condition(self._queue_lock)
# Condition indicating there is no inflight closures.
self._no_inflight_closure_condition = threading.Condition(self._queue_lock)
# Use to cancel in-flight closures.
self._cancellation_mgr = cancellation.CancellationManager()
self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
self._error = None
# The following is a lock to make sure when wait is called and before it
# returns no put can be executed during this period. It is because wait
# won't know what to do with newly put closures. This lock adds an cutoff
# for wait so that closures put into the queue while waiting would not be
# taken responsible by this wait .
#
# We cannot reuse the self._queue_lock since when wait waits for a
# condition, the self._queue_lock will be released.
#
# We don't use a reader/writer's lock on purpose to reduce the complexity
# of the code.
self._put_wait_lock = threading.Lock()
6.2 插入取出
Put 和 get 方法分別負責插入和取出。
def put(self, closure):
"""Put a closure into the queue for later execution.
If mark_failed was called before put , the error from the first
invocation of mark_failed will be raised.
Args:
closure: The Closure to put into the queue.
"""
with self._put_wait_lock, self._queue_lock:
self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
self._queue.put(closure, block=False)
self._raise_if_error()
self._closures_queued_condition.notify()
def get(self, timeout=None):
"""Return a closure from the queue to be executed."""
with self._queue_lock:
while self._queue.empty() and self._should_process_closures:
if not self._closures_queued_condition.wait(timeout=timeout):
return None
if not self._should_process_closures:
return None
closure = self._queue.get(block=False)
self._queue_free_slot_condition.notify()
self._inflight_closure_count += 1
return closure
Put_back 則負責把 closure 重新放回queue。
def put_back(self, closure):
"""Put the closure back into the queue as it was not properly executed."""
with self._queue_lock:
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to put_back.")
if self._error:
closure.mark_cancelled()
else:
self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
self._queue.put(closure, block=False)
self._closures_queued_condition.notify()
self._inflight_closure_count -= 1
if self._inflight_closure_count == 0:
self._no_inflight_closure_condition.notifyAll()
6.3 等待
方法 wait 會等待所有 closures 結束。
def wait(self, timeout=None):
"""Wait for all closures to be finished before returning.
If mark_failed was called before or during wait , the error from the
first invocation of mark_failed will be raised.
Args:
timeout: A float specifying a timeout for the wait in seconds.
Returns:
True unless the given timeout expired, in which case it returns False.
"""
with self._put_wait_lock, self._queue_lock:
while (not self._error and
(not self._queue.empty() or self._inflight_closure_count > 0)):
if not self._stop_waiting_condition.wait(timeout=timeout):
return False
self._raise_if_error()
return True
6.4 異常&結束
Mark_failed 和 done 則是處理結束和異常的一套組合。
def mark_failed(self, e):
"""Sets error and unblocks any wait() call."""
with self._queue_lock:
# TODO(yuefengz): maybe record all failure and give users more
# information?
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to mark_failed.")
if self._error is None:
self._error = e
self._inflight_closure_count -= 1
if self._inflight_closure_count == 0:
self._no_inflight_closure_condition.notifyAll()
self._stop_waiting_condition.notifyAll()
def done(self):
"""Returns true if the queue is empty and there is no inflight closure.
If mark_failed was called before done , the error from the first
invocation of mark_failed will be raised.
"""
with self._queue_lock:
self._raise_if_error()
return self._queue.empty() and self._inflight_closure_count == 0
6.5 停止
Stop 和 _cancel_all_closures 負責暫停 closures。
def stop(self):
with self._queue_lock:
self._should_process_closures = False
self._closures_queued_condition.notifyAll()
def _cancel_all_closures(self):
"""Clears the queue and sets remaining closures cancelled error.
This method expects self._queue_lock to be held prior to entry.
"""
self._cancellation_mgr.start_cancel()
while self._inflight_closure_count > 0:
self._no_inflight_closure_condition.wait()
while True:
try:
closure = self._queue.get(block=False)
self._queue_free_slot_condition.notify()
closure.mark_cancelled()
except queue.Empty:
break
# The cancellation manager cannot be reused once cancelled. After all
# closures (queued or inflight) are cleaned up, recreate the cancellation
# manager with clean state.
# Note on thread-safety: this is triggered when one of theses
# ClusterCoordinator APIs are called: schedule , wait , and done . At the
# same time, no new closures can be constructed (which reads the
# _cancellation_mgr to get cancellable functions).
self._cancellation_mgr = cancellation.CancellationManager()
def _raise_if_error(self):
"""Raises the error if one exists.
If an error exists, cancel the closures in queue, raises it, and clear
the error.
This method expects self._queue_lock to be held prior to entry.
"""
if self._error:
logging.error("Start cancelling closures due to error %r: %s",
self._error, self._error)
self._cancel_all_closures()
try:
raise self._error
finally:
self._error = None
7.4 Worker
Worker 是函式的執行者。
7.1 定義
Worker 的定義如下,其啟動了一個執行緒來執行 _process_queue。
class Worker(object):
"""A worker in a cluster.
Attributes:
worker_index: The index of the worker in the cluster.
device_name: The device string of the worker, e.g. "/job:worker/task:1".
executor: The worker's executor for remote function execution.
failure_handler: The failure handler used to handler worker preemption
failure.
"""
def __init__(self, worker_index, device_name, cluster):
self.worker_index = worker_index
self.device_name = device_name
# 這裡會有一個executor
self.executor = executor.new_executor(enable_async=False)
self.failure_handler = cluster.failure_handler
self._cluster = cluster
self._resource_remote_value_refs = []
self._should_worker_thread_run = True
# Worker threads need to start after Worker 's initialization.
threading.Thread(target=self._process_queue,
name="WorkerClosureProcessingLoop-%d" % self.worker_index,
daemon=True).start()
New_executor 會呼叫 TFE_NewExecutor。
def new_executor(enable_async):
handle = pywrap_tfe.TFE_NewExecutor(enable_async)
return Executor(handle)
TFE_NewExecutor 定義在 tensorflow/c/eager/c_api_experimental.cc,其生成了 TFE_Executor。
TFE_Executor* TFE_NewExecutor(bool is_async) {
return new TFE_Executor(is_async);
}
TFE_Executor 定義如下,Executor類是會話執行器的抽象,在 TF2 之中,也有 EagerExecutor。
struct TFE_Executor {
explicit TFE_Executor(bool async)
: owned_executor(new tensorflow::EagerExecutor(async)) {}
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
: owned_executor(nullptr), unowned_executor(executor) {}
tensorflow::EagerExecutor* executor() {
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
}
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
tensorflow::EagerExecutor* unowned_executor;
};
7.2 處理
_process_queue 方法會從 queue 之中取出 Closure,然後執行任務。
def _process_queue(self):
"""Function running in a worker thread to process closure queues."""
self._maybe_delay()
while self._should_worker_thread_run:
closure = self._cluster._closure_queue.get()
if not self._should_worker_thread_run or closure is None:
return
self._process_closure(closure)
# To properly stop the worker and preemption threads, it is important that
# ClusterCoordinator object is not held onto so its __del__ can be
# called. By removing the reference to the closure that has already been
# processed, we ensure that the closure object is released, while
# getting the next closure at above self._cluster._closure_queue.get()
# call.
del closure
7.2.1 等待
_process_queue 之中首先會呼叫 _maybe_delay 等待環境變數配置。
def _maybe_delay(self):
"""Delay if corresponding env vars are set."""
# If the following two env vars variables are set. Scheduling for workers
# will start in a staggered manner. Worker i will wait for
# TF_COORDINATOR_SCHEDULE_START_DELAY * i seconds, not exceeding
# TF_COORDINATOR_SCHEDULE_START_DELAY_MAX .
delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
delay_cap = int(
os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
if delay_cap:
delay_secs = min(delay_secs * self.worker_index, delay_cap)
if delay_secs > 0:
logging.info("Worker %d sleeping for %d seconds before running function",
self.worker_index, delay_secs)
time.sleep(delay_secs)
7.2.2 處理任務
_process_queue 之中接著會呼叫 _process_closure 來執行 closure。
def _process_closure(self, closure):
"""Runs a closure with preemption handling."""
try:
with self._cluster.failure_handler.wait_on_failure(
on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),
on_recovery_fn=self._set_resources_aborted,
worker_device_name=self.device_name):
closure.execute_on(self)
with metric_utils.monitored_timer("remote_value_fetch"):
# Copy the remote tensor to local (the coordinator) in case worker
# becomes unavailable at a later time.
closure.output_remote_value.get()
self._cluster._closure_queue.mark_finished()
except Exception as e:
# Avoid logging the derived cancellation error
if not isinstance(e, errors.CancelledError):
logging.error(
"/job:worker/task:%d encountered the following error when "
"processing closure: %r:%s", self.worker_index, e, e)
closure.output_remote_value._set_error(e)
self._cluster._closure_queue.mark_failed(e)
7.3 資料
我們接下來看看如何把資料讀取放到工作者上執行。前面提到了,在 _create_per_worker_resources 會呼叫 create_resource,為每一個工作者建立其自己的資源。
def create_resource(self, function, args=None, kwargs=None):
"""Synchronously creates a per-worker resource represented by a RemoteValue .
Args:
function: the resource function to be run remotely. It should be a
tf.function , a concrete function or a Python function.
args: positional arguments to be passed to the function.
kwargs: keyword arguments to be passed to the function.
Returns:
one or several RemoteValue objects depending on the function return
values.
"""
# Some notes about the concurrency: currently all the activities related to
# the same worker such as creating resources, setting resources' aborted
# status, and executing closures happen on the same thread. This allows us
# to have simpler logic of concurrency.
closure = ResourceClosure(
function,
self._cluster.closure_queue._cancellation_mgr,
args=args,
kwargs=kwargs)
resource_remote_value = closure.build_output_remote_value()
self._register_resource(resource_remote_value)
# The following is a short-term solution to lazily create resources in
# parallel.
resource_remote_value._set_aborted()
return resource_remote_value
_register_resource 則會把每個 Worker 的資源註冊到 Worker 之上。
def _register_resource(self, resource_remote_value):
if not isinstance(resource_remote_value, RemoteValue):
raise ValueError("Resource being registered is not of type "
" tf.distribute.experimental.coordinator.RemoteValue .")
self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))
邏輯如下,虛線表述資料流。使用者通過 put 方法向佇列之中放入 Closure,Worker 通過 put 方法從佇列獲取 Closure 執行。
7.4 停止
Stop 等一系列方法負責停止。
def stop(self):
"""Ensure the worker thread is closed."""
self._should_worker_thread_run = False
def _set_resources_aborted(self):
for weakref_resource in self._resource_remote_value_refs:
resource = weakref_resource()
if resource:
resource._set_aborted() # pylint: disable=protected-access
def _set_dead(self):
raise NotImplementedError("_set_dead is not implemented.")
7.5 與 Strategy 聯絡
至此,我們其實還沒有正式和 Strategy 聯絡起來,我們再用一個例子來看看,這裡會發現,傳遞給 coordinator 的方法之中,會呼叫 strategy.run(replica_fn, args=(next(iterator),)),這樣就和 strategy 聯絡起來了。
strategy = ...
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy)
def dataset_fn():
return tf.data.Dataset.from_tensor_slices([1, 1, 1])
with strategy.scope():
v = tf.Variable(initial_value=0)
@tf.function
def worker_fn(iterator):
def replica_fn(x):
v.assign_add(x)
return v.read_value()
return strategy.run(replica_fn, args=(next(iterator),)) # 這裡正式聯絡起來
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
distributed_iterator = iter(distributed_dataset)
result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
assert coordinator.fetch(result) == 1
8. Failover
8.1 策略
應對失敗的總體策略大致如下:
-
當發現一個工作者失敗了,Coordinator 把 function 再次放入佇列,然後發給另一個工作者執行,同時啟動一個後臺執行緒等待恢復,如果恢復了,則用資源來重建這個工作者,繼續分配工作。
-
因此,一些工作者的失敗並不妨礙叢集繼續工作,這使得叢集之中的例項可以偶爾不可用(例如,可搶佔或spot 例項)。但是協調者和引數伺服器必須始終可用,這樣叢集才能取得進展。
8.2 工作者失敗
當發生工作者失敗(failure)時候,具體邏輯如下:
- ClusterCoordinator 類與 tf.distribution.experimental.ParameterServerStrategy 一起使用時,具有內建的工作者故障容錯功能。也就是說,當一些工作者由於任何原因,協調器無法聯絡上它們,這些工作者的訓練進度將繼續由其餘工作者完成。
- 在工作者恢復時,之前提供的資料集函式(對於自定義訓練迴圈,可以是 ClusterCoordinator.create_per_worker_dataset,或者是 tf.keras.utils.experimental.DatasetCreator 用於 Model.fit )將被呼叫到工作者身上,以重新建立資料集。
- 當一個失敗的工作者恢復之後,在使用通過 create_per_worker_dataset 建立的資料被重新建立後,它將被新增到函式執行中。
8.3 引數伺服器或者協調器故障
當引數伺服器失敗時,schedule,join 或 done 會引發 tf.errors.UnavailableError。在這種情況下,除了重置失敗的引數伺服器外,使用者還應該重新啟動協調器,使其重新連線到工作者和引數伺服器,重新建立變數,並載入檢查點。如果協調器發生故障,在使用者把它重置回來之後,程式會自動連線到工作者和引數伺服器,並從檢查點繼續前進。因為協調器本身也可能變得不可用。因此建議使用某些工具以便不丟失訓練進度:
- 因此,在使用者的程式中,必須定期儲存檢查點檔案,並在程式開始時恢復。如果 "tf.keras.optimizers.Optimizer" 被應用 checkpoint,在從檢查點恢復後,其 "iterations" 屬性會大致顯示已經進行的步驟數。這可以用來決定在訓練完成前還需要多少個 epochs 和步驟(steps)。
- 對於 Model.fit,你應該使用 BackupAndRestore 回撥,它可以自動處理進度的儲存和恢復。
- 對於一個自定義的訓練迴圈,你應該定期檢查模型變數,並在訓練開始前從檢查點(如果有的話)載入模型變數。如果優化器有檢查點,訓練進度可以從 optimizer.iterations 中大致推斷出來。
checkpoint_manager = tf.train.CheckpointManager(
tf.train.Checkpoint(model=model, optimizer=optimizer),
checkpoint_dir,
max_to_keep=3)
if checkpoint_manager.latest_checkpoint:
checkpoint = checkpoint_manager.checkpoint
checkpoint.restore(
checkpoint_manager.latest_checkpoint).assert_existing_objects_matched()
global_steps = int(optimizer.iterations.numpy())
starting_epoch = global_steps // steps_per_epoch
for _ in range(starting_epoch, num_epoches):
for _ in range(steps_per_epoch):
coordinator.schedule(step_fn, args=(per_worker_iterator,))
coordinator.join()
checkpoint_manager.save()
8.4 返回 RemoteValue
如果一個函式被成功執行,就可以成功獲取到 RemoteValue。這是因為目前在執行完一個函式後,返回值會立即被複制到協調器。如果在複製過程中出現任何工作者故障,該函式將在另一個可用的工作者上重試。因此,如果你想優化效能,你可以安排(schedule)一個沒有返回值的函式。
8.5 錯誤報告
一旦協調器發現一個錯誤,如來自引數伺服器的 UnavailableError 或其他應用錯誤,如來自 tf.debugging.check_numerics 的 InvalidArgument,它將在引發錯誤之前取消所有 pending 和排隊(queued)的函式。獲取它們相應的 RemoteValue 將引發一個 CancelledError 。
在引發錯誤後,協調器將不會引發相同的錯誤或任何引發一個來自已取消函式的錯誤。
ClusterCoordinator 假設所有的函式錯誤都是致命的,基於這個假設,其的錯誤報告邏輯是:
- Schedule 和 join 都可以引發一個不可重試的錯誤,這是協調者從任何先前安排的函式中看到的第一個錯誤。
- 當一個錯誤被丟擲時,不保證有多少先前安排的功能被執行;沒有被執行的功能將被丟棄並被標記為取消。
- 在一個錯誤被丟擲後,錯誤的內部狀態將被清除。
8.6 WorkerPreemptionHandler
WorkerPreemptionHandler 是處理失敗的主要模組,其定義如下:
class WorkerPreemptionHandler(object):
"""Handles worker preemptions."""
def __init__(self, server_def, cluster):
self._server_def = server_def
self._cluster = cluster
self._cluster_update_lock = threading.Lock()
self._cluster_due_for_update_or_finish = threading.Event()
self._worker_up_cond = threading.Condition(self._cluster_update_lock)
self._error_from_recovery = None
self._should_preemption_thread_run = True
self._preemption_handler_thread = threading.Thread(
target=self._preemption_handler,
name="WorkerPreemptionHandler",
daemon=True)
self._preemption_handler_thread.start()
8.6.1 配置
在 Cluster 生成時,會把 WorkerPreemptionHandler 配置進來。
self.failure_handler = WorkerPreemptionHandler(context.get_server_def(), self)
8.6.2 等待
在處理 closure 時,會用 wait_on_failure 包裹一層用來處理錯誤。
def _process_closure(self, closure):
"""Runs a closure with preemption handling."""
assert closure is not None
try:
with self._cluster.failure_handler.wait_on_failure(
on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),
on_recovery_fn=self._set_resources_aborted,
worker_device_name=self.device_name):
closure.execute_on(self)
WorkerPreemptionHandler 的 wait_on_failure 方法如下:
@contextlib.contextmanager
def wait_on_failure(self,
on_failure_fn=None,
on_transient_failure_fn=None,
on_recovery_fn=None,
worker_device_name="(unknown)"):
"""Catches worker preemption error and wait until failed workers are back.
Args:
on_failure_fn: an optional function to run if preemption happens.
on_transient_failure_fn: an optional function to run if transient failure
happens.
on_recovery_fn: an optional function to run when a worker is recovered
from preemption.
worker_device_name: the device name of the worker instance that is passing
through the failure.
Yields:
None.
"""
try:
yield
except (errors.OpError, InputError) as e:
# If the error is due to temporary connectivity issues between worker and
# ps, put back closure, ignore error and do not mark worker as failure.
if self._cluster._record_and_ignore_transient_ps_failure(e):
if on_transient_failure_fn:
on_transient_failure_fn()
return
# Ignoring derived CancelledErrors to tolerate transient failures in
# PS-worker communication, which initially exposed as an UnavailableError
# and then lead to sub-function cancellation, subsequently getting
# reported from worker to chief as CancelledError.
# We do not mark either worker or PS as failed due to only CancelledError.
# If there are real (non-transient) failures, they must also be reported
# as other errors (UnavailableError most likely) in closure executions.
if isinstance(e, errors.CancelledError) and "/job:" in str(e):
if on_transient_failure_fn:
on_transient_failure_fn()
return
# This reraises the error, if it's not considered recoverable; otherwise,
# the following failure recovery logic run. At this time, only worker
# unavailability is recoverable. PS unavailability as well as other
# errors in the user function is not recoverable.
self._validate_preemption_failure(e)
if on_failure_fn:
on_failure_fn()
with self._cluster_update_lock:
self._cluster_due_for_update_or_finish.set()
self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
if self._error_from_recovery:
try:
raise self._error_from_recovery
finally:
self._error_from_recovery = None
if on_recovery_fn:
with self.wait_on_failure(
on_recovery_fn=on_recovery_fn,
on_transient_failure_fn=on_transient_failure_fn,
worker_device_name=worker_device_name):
on_recovery_fn()
_validate_preemption_failure 定義如下:
def _validate_preemption_failure(self, e):
"""Validates that the given exception represents worker preemption."""
# Only categorize the failure as a worker preemption if the cancellation
# manager did not attempt to cancel the blocking operations.
if _is_worker_failure(e) and (
not self._cluster._closure_queue._cancellation_mgr.is_cancelled):
return
raise e
8.6.3 handler
WorkerPreemptionHandler 有一個後臺執行緒 _preemption_handler_thread。
self._preemption_handler_thread = threading.Thread(
target=self._preemption_handler,
name="WorkerPreemptionHandler",
daemon=True)
self._preemption_handler_thread.start()
_preemption_handler 會進行必要的錯誤處理。
def _preemption_handler(self):
"""A loop that handles preemption.
This loop waits for signal of worker preemption and upon worker preemption,
it waits until all workers are back and updates the cluster about the
restarted workers.
"""
assert self._should_preemption_thread_run
while True:
self._cluster_due_for_update_or_finish.wait()
if not self._should_preemption_thread_run:
break
with self._cluster_update_lock:
try:
context.context().update_server_def(self._server_def)
# Cluster updated successfully, clear the update signal, and notify
# all workers that they are recovered from failure.
self._worker_up_cond.notify_all()
# The check for _should_preemption_thread_run is necessary since the
# stop may have already set _cluster_due_for_update_or_finish.
if self._should_preemption_thread_run:
self._cluster_due_for_update_or_finish.clear()
except Exception as e:
try:
self._validate_preemption_failure(e)
except Exception as ps_e:
# In this case, a parameter server fails. So we raise this error to
# the caller of wait_on_failure .
self._error_from_recovery = ps_e
self._worker_up_cond.notify_all()
if self._should_preemption_thread_run:
self._cluster_due_for_update_or_finish.clear()
# NOTE: Since the first RPC (GetStatus) of update_server_def is
# currently blocking by default, error should only happen if:
# (1) More workers failed while waiting for the previous workers to
# come back;
# (2) Worker failed when exchanging subsequent RPCs after the first
# RPC returns.
# Consider adding backoff retry logic if we see the error logged
# too frequently.
9. 總結
依據前面的程式碼,我們總結出來問題點如下:
-
Worker 如何知道使用哪些裝置?答案是:在叢集建立工作者時候,會給每一個工作者設定一個裝置。
-
如何具體執行使用者函式?答案是:在工作者執行 Closure 時候,會在指定執行在本工作者裝置上,然後執行指定的方法(Self._function)。Self._function 是使用者自定義的 function,其中可以使用 strategy.run 把訓練方法分發到遠端工作者進行訓練。
-
如何獲取資料?答案是:為每個工作者建立一個 PerWorkerValues,PerWorkerValues 是一個容納 value 列表的容器,每個工作者從對應 PerWorkerValues 之中獲取資料。