[Source Code Analysis] TensorFlow Distributed: ClusterCoordinator

Posted by 羅西的思考 on 2022-05-21


In this article we look at how ParameterServerStrategy distributes computation, that is, how ClusterCoordinator works. This is the last article in the TF distributed series.

Two GitHub repositories that are excellent learning material, both highly recommended:

https://github.com/yuhuiaws/ML-study

https://github.com/Jack47/hack-SysML

Also recommended is 西門宇少's latest piece on TeraPipe: token-level pipeline parallelism for Transformer language models.

The other articles in this series are:

[Translation] TensorFlow distributed, paper series: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow distributed, paper series: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow distributed environment (1): overall architecture

[Source Code Analysis] TensorFlow distributed environment (2): Master static logic

[Source Code Analysis] TensorFlow distributed environment (3): Worker static logic

[Source Code Analysis] TensorFlow distributed environment (4): WorkerCache

[Source Code Analysis] TensorFlow distributed environment (5): Session

[Source Code Analysis] TensorFlow distributed environment (7): Worker dynamic logic

[Source Code Analysis] TensorFlow distributed environment (8): communication mechanism

[Translation] Distributed training with TensorFlow

[Source Code Analysis] TensorFlow DistributedStrategy: fundamentals

[Source Code Analysis] TensorFlow distributed variables

[Source Code Analysis] TensorFlow distributed: MirroredStrategy

[Source Code Analysis] TensorFlow distributed: how MirroredStrategy distributes computation

[Source Code Analysis] TensorFlow distributed: ParameterServerStrategy V1

[Source Code Analysis] TensorFlow distributed: ParameterServerStrategy V2

1. Approach

TensorFlow 2 recommends a central-coordination-based architecture for parameter server training. Every worker and parameter server runs a tf.distribute.Server; on top of them, a coordinator task is responsible for creating resources on the workers and parameter servers, dispatching functions, and coordinating training. The coordinator uses tf.distribute.experimental.coordinator.ClusterCoordinator to coordinate the cluster and tf.distribute.experimental.ParameterServerStrategy to define the variables on the parameter servers and the computation on the workers.
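As an illustration of this setup, the sketch below shows one way the coordinator task might be configured. The cluster layout, host addresses, and the use of TF_CONFIG with TFConfigClusterResolver are assumptions made for the example, not part of the source being analyzed here.

import json
import os
import tensorflow as tf

# Hypothetical cluster: 1 coordinator (chief), 2 workers, 1 parameter server.
# Each worker/ps task runs a tf.distribute.Server; the chief runs this program.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222"],
        "ps": ["host3:2222"],
    },
    "task": {"type": "chief", "index": 0},
})

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)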

ClusterCoordinator is an object used to schedule and coordinate remote function execution. It is used to create fault-tolerant resources and to dispatch functions to remote TensorFlow servers. It currently does not support standalone use; it should be used together with a tf.distribute strategy designed to work with it, and at the moment it only works with tf.distribute.experimental.ParameterServerStrategy.

1.1 Usage

After all the computation has been defined with ParameterServerStrategy, the user can use the tf.distribute.experimental.coordinator.ClusterCoordinator class to create resources and dispatch training steps to remote workers.

First, create a ClusterCoordinator object and pass in the strategy object.

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

Next, create a per-worker dataset and an iterator. In the per_worker_dataset_fn below, it is recommended to wrap dataset_fn in strategy.distribute_datasets_from_function to allow seamless and efficient prefetching of data to GPUs.

@tf.function
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

The last step is to use ClusterCoordinator.schedule to dispatch the computation to remote workers.

  • The schedule method enqueues a tf.function and immediately returns a future-like RemoteValue. The queued functions are dispatched to remote workers by background threads, and their RemoteValues are filled in asynchronously.
  • The user can call join (ClusterCoordinator.join) to wait for all scheduled functions to finish.
@tf.function
def step_fn(iterator):
  return next(iterator)

num_epoches = 4
steps_per_epoch = 5
for i in range(num_epoches):
  accuracy.reset_states()
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  # Wait at epoch boundaries.
  coordinator.join()
  print ("Finished epoch %d, accuracy is %f." % (i, accuracy.result().numpy()))

The following shows how to obtain the result of a RemoteValue.

loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
print ("Final loss is %f" % loss.fetch())

The user can also launch all the steps and then do something else while waiting for them to complete.

for _ in range(total_steps):
  coordinator.schedule(step_fn, args=(per_worker_iterator,))
while not coordinator.done():
  time.sleep(10)
  # Do something like logging metrics or writing checkpoints.

1.2 Questions

Based on the code above, we can summarize the open questions as follows:

  • How does a worker know which devices to use?
  • How exactly is the user function executed?
  • How is the data obtained?

Next, we try to answer these questions by analyzing the code.

2. Definition

The main ideas behind ClusterCoordinator are as follows.

  • The coordinator is not one of the training workers; instead, it is responsible for creating resources such as variables and datasets, scheduling tf.functions, saving checkpoints, and so on.
  • To make training progress, the coordinator dispatches tf.functions to be executed on remote workers.
  • Upon receiving a request from the coordinator, a worker executes the tf.function by reading variables from the parameter servers, running the ops, and updating the variables on the parameter servers.
  • Each worker only handles requests from the coordinator and communicates with the parameter servers; it does not interact directly with other workers in the cluster.

The definition of ClusterCoordinator is shown below. We can see that it mainly stores the _strategy member variable and creates the _cluster member variable.

@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
    
  def __new__(cls, strategy):
    #  ClusterCoordinator  is kept as a single instance to a given  Strategy .
    if strategy._cluster_coordinator is None:
      strategy._cluster_coordinator = super(
          ClusterCoordinator, cls).__new__(cls)
    return strategy._cluster_coordinator

  def __init__(self, strategy):
    """Initialization of a  ClusterCoordinator  instance.

    Args:
      strategy: a supported  tf.distribute.Strategy  object. Currently, only
         tf.distribute.experimental.ParameterServerStrategy  is supported.

    Raises:
      ValueError: if the strategy being used is not supported.
    """
    if not getattr(self, "_has_initialized", False):
      if not isinstance(strategy,
                        parameter_server_strategy_v2.ParameterServerStrategyV2):
        raise ValueError(
            "Only  tf.distribute.experimental.ParameterServerStrategy  "
            "is supported to work with "
            " tf.distribute.experimental.coordinator.ClusterCoordinator  "
            "currently.")
      self._strategy = strategy
      self.strategy.extended._used_with_coordinator = True
      self._cluster = Cluster(strategy)
      self._has_initialized = True

  def __del__(self):
    self._cluster.stop()

  @property
  def strategy(self):
    """Returns the  Strategy  associated with the  ClusterCoordinator ."""
    return self._strategy

2.1 Schedule

The most important API provided by the ClusterCoordinator object is schedule, which dispatches a tf.function to a worker for asynchronous execution. Specifically:

  • The method is non-blocking: it enqueues fn and immediately returns a tf.distribute.experimental.coordinator.RemoteValue object; fn is queued and executed later.
  • Functions in the queue are dispatched to remote workers by background threads and executed asynchronously, and their RemoteValues are assigned asynchronously.
  • Since schedule does not require a worker to be assigned, the passed-in tf.function can be executed on any available worker.
  • fetch can be called to wait for the function to finish and retrieve its output from the remote worker; alternatively, tf.distribute.experimental.coordinator.ClusterCoordinator.join can be called to wait for all scheduled functions to finish.

The failure and fault-tolerance policy is as follows (see the sketch after this list):

  • Since a worker may fail at any point while executing a function, a function may be partially executed, but tf.distribute.experimental.coordinator.ClusterCoordinator guarantees that in such events the function will eventually be executed on some available worker.
  • schedule guarantees that fn is executed on a worker at least once; because execution is not atomic, a function may be executed more than once if its worker fails in the middle of execution.
  • If the worker executing the function becomes unavailable before it finishes, the function is retried on another available worker.
  • If any previously scheduled function raises an error, schedule raises one of those errors and clears the errors collected so far. The user can call fetch on the returned tf.distribute.experimental.coordinator.RemoteValue to check whether it has executed, failed, or been cancelled, and reschedule the corresponding function if needed. When schedule raises, it guarantees that no function is still executing.
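The following is a minimal sketch of how that error behavior might be handled in a training loop, reusing step_fn, per_worker_iterator, and steps_per_epoch from the usage example above:

results = []
for _ in range(steps_per_epoch):
  results.append(coordinator.schedule(step_fn, args=(per_worker_iterator,)))

try:
  coordinator.join()
except Exception as e:  # first error seen from any scheduled function
  print("A scheduled function failed: %r" % e)
  # Inspect individual RemoteValues; fetching a cancelled one raises
  # CancelledError, and those functions can be scheduled again if needed.
  for remote_value in results:
    try:
      remote_value.fetch()
    except Exception as fetch_error:
      print("Step did not finish cleanly: %r" % fetch_error)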

The definition of schedule is shown below; the data iterator is passed in together with fn as one of its arguments.

  def schedule(self, fn, args=None, kwargs=None):
    """Schedules  fn  to be dispatched to a worker for asynchronous execution.

    This method is non-blocking in that it queues the  fn  which will be
    executed later and returns a 
     tf.distribute.experimental.coordinator.RemoteValue  object immediately.
     fetch  can be called on it to wait for the function execution to finish
    and retrieve its output from a remote worker. On the other hand, call
     tf.distribute.experimental.coordinator.ClusterCoordinator.join  to wait for
    all scheduled functions to finish.

     schedule  guarantees that  fn  will be executed on a worker at least once;
    it could be more than once if its corresponding worker fails in the middle
    of its execution. Note that since worker can fail at any point when
    executing the function, it is possible that the function is partially
    executed, but  tf.distribute.experimental.coordinator.ClusterCoordinator 
    guarantees that in those events, the function will eventually be executed on
    any worker that is available.

    If any previously scheduled function raises an error,  schedule  will raise
    any one of those errors, and clear the errors collected so far. What happens
    here, some of the previously scheduled functions may have not been executed.
    User can call  fetch  on the returned
     tf.distribute.experimental.coordinator.RemoteValue  to inspect if they have
    executed, failed, or cancelled, and reschedule the corresponding function if
    needed.

    When  schedule  raises, it guarantees that there is no function that is
    still being executed.

    At this time, there is no support of worker assignment for function
    execution, or priority of the workers.

     args  and  kwargs  are the arguments passed into  fn , when  fn  is
    executed on a worker. They can be
     tf.distribute.experimental.coordinator.PerWorkerValues  and in this case,
    the argument will be substituted with the corresponding component on the
    target worker. Arguments that are not
     tf.distribute.experimental.coordinator.PerWorkerValues  will be passed into
     fn  as-is. Currently,  tf.distribute.experimental.coordinator.RemoteValue 
    is not supported to be input  args  or  kwargs .

    Args:
      fn: A  tf.function ; the function to be dispatched to a worker for
        execution asynchronously. Regular python funtion is not supported to be
        scheduled.
      args: Positional arguments for  fn .
      kwargs: Keyword arguments for  fn .

    Returns:
      A  tf.distribute.experimental.coordinator.RemoteValue  object that
      represents the output of the function scheduled.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was thrown
        or since the beginning of the program.
    """
    if not isinstance(fn,
                      (def_function.Function, tf_function.ConcreteFunction)):
      raise TypeError(
          " tf.distribute.experimental.coordinator.ClusterCoordinator.schedule "
          " only accepts a  tf.function  or a concrete function.")
    # Slot variables are usually created during function tracing time; thus
    #  schedule  needs to be called within the  strategy.scope() .
    with self.strategy.scope():
      self.strategy.extended._being_scheduled = True  
      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
      self.strategy.extended._being_scheduled = False  
      return remote_value

2.2 Join

The join method blocks until all scheduled functions have finished executing. Its behavior is as follows:

  • If any previously scheduled function raises an error, join fails by raising one of those errors and clears the errors collected so far; if this happens, some of the previously scheduled functions may not have been executed.
  • The user can call fetch on the returned tf.distribute.experimental.coordinator.RemoteValue to check whether it has executed, failed, or been cancelled.
  • If some cancelled functions need to be rescheduled, the user should call schedule again.
  • When join returns or raises, it guarantees that no function is still executing.
  def join(self):
    """Blocks until all the scheduled functions have finished execution.

    If any previously scheduled function raises an error,  join  will fail by
    raising any one of those errors, and clear the errors collected so far. If
    this happens, some of the previously scheduled functions may have not been
    executed. Users can call  fetch  on the returned
     tf.distribute.experimental.coordinator.RemoteValue  to inspect if they have
    executed, failed, or cancelled. If some that have been cancelled need to be
    rescheduled, users should call  schedule  with the function again.

    When  join  returns or raises, it guarantees that there is no function that
    is still being executed.

    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    self._cluster.join()

2.3 Done

The done method returns whether all dispatched functions have finished executing. If any previously dispatched function raises an error, done fails by raising that error.

  def done(self):
    """Returns whether all the scheduled functions have finished execution.

    If any previously scheduled function raises an error,  done  will fail by
    raising any one of those errors.

    When  done  returns True or raises, it guarantees that there is no function
    that is still being executed.

    Returns:
      Whether all the scheduled functions have finished execution.
    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    return self._cluster.done()

2.4 Fetch

fetch retrieves the results of remote values.

  def fetch(self, val):
    """Blocking call to fetch results from the remote values.

    This is a wrapper around
     tf.distribute.experimental.coordinator.RemoteValue.fetch  for a
     RemoteValue  structure; it returns the execution results of
     RemoteValue s. If not ready, wait for them while blocking the caller.

    Example:
    ```python
    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])

    with strategy.scope():
      v = tf.Variable(initial_value=0)

    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),))

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1
    ```

    Args:
      val: The value to fetch the results from. If this is structure of
         tf.distribute.experimental.coordinator.RemoteValue ,  fetch()  will be
        called on the individual
         tf.distribute.experimental.coordinator.RemoteValue  to get the result.

    Returns:
      If  val  is a  tf.distribute.experimental.coordinator.RemoteValue  or a
      structure of  tf.distribute.experimental.coordinator.RemoteValue s,
      return the fetched  tf.distribute.experimental.coordinator.RemoteValue 
      values immediately if they are available, or block the call until they are
      available, and return the fetched
       tf.distribute.experimental.coordinator.RemoteValue  values with the same
      structure. If  val  is other types, return it as-is.
    """

    def _maybe_fetch(val):
      if isinstance(val, RemoteValue):
        return val.fetch()
      else:
        return val

    return nest.map_structure(_maybe_fetch, val)

3. Data

Besides scheduling remote functions, ClusterCoordinator also helps create datasets on all workers and rebuild them when a worker recovers from failure. The user creates datasets on the worker devices by providing a dataset_fn. A usage example follows:

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

@tf.function
def worker_fn(iterator):
  return next(iterator)

def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(
      lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

per_worker_dataset = coordinator.create_per_worker_dataset(
    per_worker_dataset_fn)
per_worker_iter = iter(per_worker_dataset)
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
assert remote_value.fetch() == 3

3.1 Creating datasets

The code above uses create_per_worker_dataset to create datasets on the workers. These datasets are generated by dataset_fn, and an object representing the collection of those datasets is returned. Calling iter on such a collection returns a tf.distribute.experimental.coordinator.PerWorkerValues, which is a collection of iterators that have been placed on the respective workers.

Note that calling next directly on a PerWorkerValues of iterators is not supported. The iterator is meant to be passed as an argument to tf.distribute.experimental.coordinator.ClusterCoordinator.schedule; when the scheduled function is about to be executed by a worker, the function receives the single iterator corresponding to that worker and may call next on it.

Currently the schedule method assumes that all workers are identical, and therefore that the datasets on different workers are the same, except that they may be shuffled differently if they contain a dataset.shuffle operation with no random seed set. Because of this, it is recommended to repeat the dataset indefinitely and schedule a finite number of steps, rather than relying on the dataset's OutOfRangeError to end training; a sketch of such a dataset_fn follows.
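A sketch of a dataset_fn following that recommendation (the in-memory tensors and batch size are placeholders; distribute_datasets_from_function passes a tf.distribute.InputContext into the function):

def dataset_fn(input_context):
  # Toy in-memory data; real code would read files or a data service.
  features = tf.random.uniform((800, 10))
  labels = tf.random.uniform((800, 1))
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  dataset = dataset.shuffle(800)  # no seed: each worker shuffles differently
  dataset = dataset.repeat()      # repeat indefinitely
  batch_size = input_context.get_per_replica_batch_size(64)
  return dataset.batch(batch_size).prefetch(2)

@tf.function
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
# ... then schedule a finite number of steps per epoch instead of relying on
# OutOfRangeError from the dataset.

The definition of create_per_worker_dataset is: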

  def create_per_worker_dataset(self, dataset_fn):
    """Create dataset on workers by calling  dataset_fn  on worker devices.

    This creates the given dataset generated by dataset_fn on workers
    and returns an object that represents the collection of those individual
    datasets. Calling  iter  on such collection of datasets returns a
     tf.distribute.experimental.coordinator.PerWorkerValues , which is a
    collection of iterators, where the iterators have been placed on respective
    workers.

    Calling  next  on a  PerWorkerValues  of iterator is unsupported. The
    iterator is meant to be passed as an argument into
     tf.distribute.experimental.coordinator.ClusterCoordinator.schedule . When
    the scheduled function is about to be executed by a worker, the
    function will receive the individual iterator that corresponds to the
    worker. The  next  method can be called on an iterator inside a
    scheduled function when the iterator is an input of the function.

    Currently the  schedule  method assumes workers are all the same and thus
    assumes the datasets on different workers are the same, except they may be
    shuffled differently if they contain a  dataset.shuffle  operation and a
    random seed is not set. Because of this, we also recommend the datasets to
    be repeated indefinitely and schedule a finite number of steps instead of
    relying on the  OutOfRangeError  from a dataset.

    Args:
      dataset_fn: The dataset function that returns a dataset. This is to be
        executed on the workers.

    Returns:
      An object that represents the collection of those individual
      datasets.  iter  is expected to be called on this object that returns
      a  tf.distribute.experimental.coordinator.PerWorkerValues  of the
      iterators (that are on the workers).
    """
    return values_lib.get_per_worker_dataset(dataset_fn, self)

get_per_worker_dataset returns either a PerWorkerDatasetFromDataset or a PerWorkerDatasetFromDatasetFunction.

def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
  if callable(dataset_or_dataset_fn):
    return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
                                               coordinator)
  else:
    return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)

3.2 PerWorkerDatasetFromDataset

PerWorkerDatasetFromDataset represents worker-distributed datasets created from a tf.data dataset.

class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
  """Represents worker-distributed datasets created from a dataset."""

  def __init__(self, dataset, coordinator):
    """Makes an iterable from datasets created by the given dataset.

    It creates a dataset_fn which deserializes a dataset from a graph under the
    hood.

    Args:
      dataset: A tf.data.Dataset, a DistributedDataset or a
        DistributedDatasetsFromFunction
      coordinator: a  ClusterCoordinator  object, used to create dataset
        resources.
    """
    if isinstance(dataset, input_lib.DistributedDataset):
      original_dataset = dataset._original_dataset
      serialized = serialize_dataset_to_graph(original_dataset)

      def dataset_fn():
        deserialized = deserialize_dataset_from_graph(
            serialized, original_dataset.element_spec)
        dataset.build(dataset_to_replace=deserialized)
        return dataset
      
    elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
      def dataset_fn():
        dataset.build()
        return dataset
      
    elif isinstance(dataset, dataset_ops.Dataset):
      serialized = serialize_dataset_to_graph(dataset)

      def dataset_fn():
        return deserialize_dataset_from_graph(serialized, dataset.element_spec)
      
    else:
      raise ValueError("Unexpected dataset type!")

    super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)

3.3 PerWorkerDatasetFromDatasetFunction

PerWorkerDatasetFromDatasetFunction represents worker-distributed datasets created from a dataset function.

Its __iter__ method does the following:

  • Call _create_per_worker_iterator to obtain an iter(dataset).
  • Call self._coordinator._create_per_worker_resources to create one iterator per worker.
  • Finally, return a PerWorkerDistributedIterator.

class PerWorkerDatasetFromDatasetFunction(object):
  """Represents worker-distributed datasets created from dataset function."""

  def __init__(self, dataset_fn, coordinator):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a  Dataset .
      coordinator: a  ClusterCoordinator  object, used to create dataset
        resources.
    """

    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in  dataset_fn  is not allowed.")

    if isinstance(dataset_fn, def_function.Function):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
    self._dataset_fn = dataset_fn
    self._coordinator = coordinator
    self._element_spec = None

  def __iter__(self):
    # We would like users to create iterators outside  tf.function s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # If PerWorkerDatasetFromDatasetFunction.__iter__ is called multiple
    # times, for the same object it should only create and register resource
    # once. Using object id to distinguish different iterator resources.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (
          input_lib.get_iterator_spec_from_dataset(
              self._coordinator.strategy, self._dataset_fn.structured_outputs))

    return PerWorkerDistributedIterator(per_worker_iterator._values)

  @property
  def element_spec(self):
    """The type specification of an element of this dataset.

    This property is subject to change without notice.
    """
    return self._dataset_fn.structured_outputs.element_spec

3.4 _create_per_worker_resources

_create_per_worker_resources invokes the given function on every worker so that each worker gets its own data.

def _create_per_worker_resources(self, fn, args=None, kwargs=None):
  """Synchronously create resources on the workers.

  The resources are represented by
   tf.distribute.experimental.coordinator.RemoteValue s.

  Args:
    fn: The function to be dispatched to all workers for execution
      asynchronously.
    args: Positional arguments for  fn .
    kwargs: Keyword arguments for  fn .

  Returns:
    A  tf.distribute.experimental.coordinator.PerWorkerValues  object, which
    wraps a tuple of  tf.distribute.experimental.coordinator.RemoteValue 
    objects.
  """
  results = []
  for w in self._cluster.workers:
    results.append(w.create_resource(fn, args=args, kwargs=kwargs))  
  return PerWorkerValues(tuple(results))

3.5 PerWorkerValues

PerWorkerValues is a container holding a list of values, one per worker. tf.distribute.experimental.coordinator.PerWorkerValues contains a collection of values where each value is located on its corresponding worker; when it is used as one of the args or kwargs of tf.distribute.experimental.coordinator.ClusterCoordinator.schedule(), the value specific to a worker is passed into the function executed on that worker.

Currently the only way to create a tf.distribute.experimental.coordinator.PerWorkerValues object is to call iter on the distributed dataset instance returned by ClusterCoordinator.create_per_worker_dataset; a mechanism to create custom tf.distribute.experimental.coordinator.PerWorkerValues is not yet supported.

@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
  """A container that holds a list of values, one value per worker.

   tf.distribute.experimental.coordinator.PerWorkerValues  contains a collection
  of values, where each of the values is located on its corresponding worker,
  and upon being used as one of the  args  or  kwargs  of
   tf.distribute.experimental.coordinator.ClusterCoordinator.schedule() , the
  value specific to a worker will be passed into the function being executed at
  that corresponding worker.

  Currently, the only supported path to create an object of
   tf.distribute.experimental.coordinator.PerWorkerValues  is through calling
   iter  on a  ClusterCoordinator.create_per_worker_dataset -returned
  distributed dataset instance. The mechanism to create a custom
   tf.distribute.experimental.coordinator.PerWorkerValues  is not yet supported.
  """

  def __init__(self, values):
    for v in values:
      if not isinstance(v, RemoteValue):
        raise AssertionError(
            " PerWorkerValues  should only take  RemoteValue s.")
    self._values = tuple(values)

  @property
  def _type_spec(self):
    return PerWorkerValuesTypeSpec(
        self._values[0]._type_spec,  
        type(self))

The data-retrieval logic is shown in the figure below.

4. Cluster

Cluster is the component that actually does the work.

4.1 Definition

Cluster represents a cluster of workers. Its initialization method does the following:

  • Configure how transient parameter server errors are ignored.
  • Set the device names of the workers.
  • Create the Worker objects.

What deserves attention here is how failures caused by transient worker connectivity errors are ignored.

  • Transient connectivity problems between a worker and a parameter server are relayed by the worker to the coordinator, which would make the coordinator believe that the parameter server has failed.
  • The distinction between a transient and a permanent parameter server failure is the number of worker reports. When the corresponding environment variable is set to a positive integer K, the coordinator ignores up to K failure reports; that is, only when more than K execution errors are attributed to the same parameter server instance do we consider that parameter server instance to have failed (see the sketch after this list).
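For example, the threshold can be raised before the coordinator (and hence the Cluster) is constructed; a minimal sketch:

import os

# Tolerate up to 5 failure reports per parameter server instance before
# treating it as a real PS failure (the default, as the code below shows, is 3).
# This must be set before the ClusterCoordinator/Cluster is created, because
# the value is read once in Cluster.__init__.
os.environ["TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES"] = "5"

The Cluster class itself is defined as follows.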
class Cluster(object):
  """A cluster with workers.

  We assume all function errors are fatal and based on this assumption our
  error reporting logic is:
  1) Both  schedule  and  join  can raise a non-retryable error which is the
  first error seen by the coordinator from any previously scheduled functions.
  2) When an error is raised, there is no guarantee on how many previously
  scheduled functions have been executed; functions that have not been executed
  will be thrown away and marked as cancelled.
  3) After an error is raised, the internal state of error will be cleared.
  I.e. functions can continue to be scheduled and subsequent calls of  schedule 
  or  join  will not raise the same error again.

  Attributes:
    failure_handler: The failure handler used to handler worker preemption
      failure.
    workers: a list of  Worker  objects in the cluster.
  """

  def __init__(self, strategy):
    """Initializes the cluster instance."""

    self._num_workers = strategy._num_workers
    self._num_ps = strategy._num_ps

    # How transient parameter server errors are ignored.
    self._transient_ps_failures_threshold = int(
        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
    self._potential_ps_failures_lock = threading.Lock()
    self._potential_ps_failures_count = [0] * self._num_ps

    self._closure_queue = _CoordinatedClosureQueue()
    self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
                                                   self)
    
    # Set the worker device names.
    worker_device_strings = [
        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
    ]
    # Create the Worker objects.
    self.workers = [
        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
    ]

4.2 Schedule

The most important API provided by this class is the schedule/join pair. The schedule API is non-blocking: it puts a tf.function into the queue and immediately returns a RemoteValue.

  def schedule(self, function, args, kwargs):
    """Schedules  function  to be dispatched to a worker for execution.

    Args:
      function: The function to be dispatched to a worker for execution
        asynchronously.
      args: Positional arguments for  fn .
      kwargs: Keyword arguments for  fn .

    Returns:
      A  RemoteValue  object.
    """
    closure = Closure(
        function,
        self._closure_queue._cancellation_mgr, 
        args=args,
        kwargs=kwargs)
    self._closure_queue.put(closure)
    return closure.output_remote_value

  def join(self):
    """Blocks until all scheduled functions are executed."""
    self._closure_queue.wait()

The concrete logic is shown in the figure below, where the dotted line indicates the dataset being passed in. The Queue here is queue.Queue, imported via from six.moves import queue; we will meet it again in _CoordinatedClosureQueue below.

Or, looking at the diagram from the official documentation, what has been completed so far is the circled part on the left.

4.3 Stopping

The stop code is shown below; concretely, it stops the failure handler, each worker, and the closure queue.

  def stop(self):
    """Stop worker, worker preemption threads, and the closure queue."""
    self.failure_handler.stop()

    for worker in self.workers:
      worker.stop()
    self._closure_queue.stop()

  def done(self):
    """Returns true if all scheduled functions are executed."""
    return self._closure_queue.done()


5. The task: Closure

A Closure encapsulates a task, namely the function to be scheduled and its arguments, and provides some additional functionality.

class Closure(object):
  """Hold a function to be scheduled and its arguments."""

  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
    if not callable(function):
      raise ValueError("Function passed to  ClusterCoordinator.schedule  must "
                       "be a callable object.")
    self._args = args or ()
    self._kwargs = kwargs or {}

    _disallow_remote_value_as_input(self._args)
    _disallow_remote_value_as_input(self._kwargs)

    if isinstance(function, def_function.Function):
      replica_args = _select_worker_slice(0, self._args)
      replica_kwargs = _select_worker_slice(0, self._kwargs)

      # Note: no need to handle function registration failure since this kind of
      # failure will not raise exceptions as designed in the runtime. The
      # coordinator has to rely on subsequent operations that raise to catch
      # function registration failure.

      # Record the function tracing overhead. Note that we pass in the tracing
      # count of the def_function.Function as a state tracker, so that metrics
      # will only record the time for actual function tracing (i.e., excluding
      # function cache lookups).
      with metric_utils.monitored_timer(
          "function_tracing", state_tracker=function._get_tracing_count):  
        self._concrete_function = function.get_concrete_function(
            *nest.map_structure(_maybe_as_type_spec, replica_args),
            **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
    elif isinstance(function, tf_function.ConcreteFunction):
      self._concrete_function = function

    if hasattr(self, "_concrete_function"):
      # If we have a concrete function, we get to retrieve the output type spec
      # via the structured_output.
      output_type_spec = func_graph.convert_structure_to_signature(
          self._concrete_function.structured_outputs)
      self._function = cancellation_mgr.get_cancelable_function(
          self._concrete_function)
    else:
      # Otherwise (i.e. what is passed in is a regular python function), we have
      # no such information.
      output_type_spec = None
      self._function = function

    self.output_remote_value = RemoteValueImpl(self, output_type_spec)

5.1 Execution

Closure.execute_on is responsible for execution; concretely, it runs self._function, the user-defined function, on the specified device. Note that with context.executor_scope(worker.executor) the execution goes through the worker's executor context.

  def execute_on(self, worker):
    """Executes the closure on the given worker.

    Args:
      worker: a Worker object.
    """
    replica_args = _select_worker_slice(worker.worker_index, self._args)
    replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)

    e = (
        _maybe_rebuild_remote_values(worker, replica_args) or
        _maybe_rebuild_remote_values(worker, replica_kwargs))
    if e:
      if not isinstance(e, InputError):
        e = InputError(e)
      self.output_remote_value._set_error(e) 
      return

    with ops.device(worker.device_name):  # on the designated worker device
      with context.executor_scope(worker.executor):  # within this worker's executor
        with metric_utils.monitored_timer("closure_execution"):
          output_values = self._function(  # run the user function with per-worker args
              *nest.map_structure(_maybe_get_remote_value, replica_args),
              **nest.map_structure(_maybe_get_remote_value, replica_kwargs))
    self.output_remote_value._set_values(output_values) 

self._function is the user-defined function. Here is another example of such a function; note that strategy.run can be used inside it to dispatch the training computation to the remote worker.

@tf.function
def worker_fn(iterator):

  def replica_fn(inputs):
    batch_data, labels = inputs
    # calculate gradients, apply gradients, update metrics, etc.

  strategy.run(replica_fn, args=(next(iterator),))

5.2 Cancellation

The user can cancel a Closure; cancellation simply sets an error on its return value.

  def mark_cancelled(self):
    self.output_remote_value._set_error(  
        errors.CancelledError(
            None, None, "The corresponding function is "
            "cancelled. Please reschedule the function."))

5.3 ResourceClosure

ResourceClosure is a derived class that ties a Closure to its output RemoteValue (kept via a weak reference); in practice, ResourceClosure is what actually gets used for resources.

class ResourceClosure(Closure):

  def build_output_remote_value(self):
    if self._output_remote_value_ref is None:
      # We need to remember the Closure object in the  RemoteValue  here.
      ret = RemoteValueImpl(self, self._output_type_spec)
      self._output_remote_value_ref = weakref.ref(ret)
      return ret
    else:
      return self._output_remote_value_ref()

6. The queue

_CoordinatedClosureQueue is the queue in which the tasks live.

6.1 Definition

from six.moves import queue

class _CoordinatedClosureQueue(object):
  """Manage a queue of closures, inflight count and errors from execution.

  This class is thread-safe.
  """

  def __init__(self):
    #  self._inflight_closure_count  only tracks the number of inflight closures
    # that are "in generation". Once an error occurs, error generation is
    # incremented and all subsequent arriving closures (from inflight) are
    # considered "out of generation".
    self._inflight_closure_count = 0

    self._queue_lock = threading.Lock()

    # Condition indicating that all pending closures (either queued or inflight)
    # have been processed, failed, or cancelled.
    self._stop_waiting_condition = threading.Condition(self._queue_lock)

    # Condition indicating that an item becomes available in queue (not empty).
    self._closures_queued_condition = threading.Condition(self._queue_lock)
    self._should_process_closures = True

    # Condition indicating that a queue slot becomes available (not full).
    # Note that even with "infinite" queue size, there is still a "practical"
    # size limit for the queue depending on host memory capacity, and thus the
    # queue will eventually become full with a lot of enqueued closures.
    self._queue_free_slot_condition = threading.Condition(self._queue_lock)

    # Condition indicating there is no inflight closures.
    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)

    # Use to cancel in-flight closures.
    self._cancellation_mgr = cancellation.CancellationManager()

    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
    self._error = None

    # The following is a lock to make sure when  wait  is called and before it
    # returns no  put  can be executed during this period. It is because  wait 
    # won't know what to do with newly put closures. This lock adds an cutoff
    # for  wait  so that closures put into the queue while waiting would not be
    # taken responsible by this  wait .
    #
    # We cannot reuse the  self._queue_lock  since when  wait  waits for a
    # condition, the  self._queue_lock  will be released.
    #
    # We don't use a reader/writer's lock on purpose to reduce the complexity
    # of the code.
    self._put_wait_lock = threading.Lock()


6.2 Enqueue and dequeue

The put and get methods are responsible for inserting and removing closures, respectively.

  def put(self, closure):
    """Put a closure into the queue for later execution.

    If  mark_failed  was called before  put , the error from the first
    invocation of  mark_failed  will be raised.

    Args:
      closure: The  Closure  to put into the queue.
    """
    with self._put_wait_lock, self._queue_lock:
      self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
      self._queue.put(closure, block=False)
      self._raise_if_error()
      self._closures_queued_condition.notify()

  def get(self, timeout=None):
    """Return a closure from the queue to be executed."""
    with self._queue_lock:
      while self._queue.empty() and self._should_process_closures:
        if not self._closures_queued_condition.wait(timeout=timeout):
          return None
      if not self._should_process_closures:
        return None
      closure = self._queue.get(block=False)
      self._queue_free_slot_condition.notify()
      self._inflight_closure_count += 1
      return closure

put_back puts a closure back into the queue when it was not properly executed.

  def put_back(self, closure):
    """Put the closure back into the queue as it was not properly executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closures to put_back.")
      if self._error:
        closure.mark_cancelled()
      else:
        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
        self._queue.put(closure, block=False)
        self._closures_queued_condition.notify()
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()

6.3 Waiting

The wait method waits for all closures to finish.

  def wait(self, timeout=None):
    """Wait for all closures to be finished before returning.

    If  mark_failed  was called before or during  wait , the error from the
    first invocation of  mark_failed  will be raised.

    Args:
      timeout: A float specifying a timeout for the wait in seconds.

    Returns:
      True unless the given timeout expired, in which case it returns False.
    """
    with self._put_wait_lock, self._queue_lock:
      while (not self._error and
             (not self._queue.empty() or self._inflight_closure_count > 0)):
        if not self._stop_waiting_condition.wait(timeout=timeout):
          return False
      self._raise_if_error()
      return True

6.4 Errors and completion

mark_failed and done together handle errors and completion.

  def mark_failed(self, e):
    """Sets error and unblocks any wait() call."""
    with self._queue_lock:
      # TODO(yuefengz): maybe record all failure and give users more
      # information?
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closures to mark_failed.")
      if self._error is None:
        self._error = e
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()
      self._stop_waiting_condition.notifyAll()

  def done(self):
    """Returns true if the queue is empty and there is no inflight closure.

    If  mark_failed  was called before  done , the error from the first
    invocation of  mark_failed  will be raised.
    """
    with self._queue_lock:
      self._raise_if_error()
      return self._queue.empty() and self._inflight_closure_count == 0


6.5 Stopping

stop and _cancel_all_closures are responsible for stopping the closures.

  def stop(self):
    with self._queue_lock:
      self._should_process_closures = False
      self._closures_queued_condition.notifyAll()

  def _cancel_all_closures(self):
    """Clears the queue and sets remaining closures cancelled error.

    This method expects self._queue_lock to be held prior to entry.
    """
    self._cancellation_mgr.start_cancel()
    while self._inflight_closure_count > 0:
      self._no_inflight_closure_condition.wait()
    while True:
      try:
        closure = self._queue.get(block=False)
        self._queue_free_slot_condition.notify()
        closure.mark_cancelled()
      except queue.Empty:
        break
    # The cancellation manager cannot be reused once cancelled. After all
    # closures (queued or inflight) are cleaned up, recreate the cancellation
    # manager with clean state.
    # Note on thread-safety: this is triggered when one of theses
    # ClusterCoordinator APIs are called:  schedule ,  wait , and  done . At the
    # same time, no new closures can be constructed (which reads the
    # _cancellation_mgr to get cancellable functions).
    self._cancellation_mgr = cancellation.CancellationManager()

  def _raise_if_error(self):
    """Raises the error if one exists.

    If an error exists, cancel the closures in queue, raises it, and clear
    the error.

    This method expects self._queue_lock to be held prior to entry.
    """
    if self._error:
      logging.error("Start cancelling closures due to error %r: %s",
                    self._error, self._error)
      self._cancel_all_closures()
      try:
        raise self._error  
      finally:
        self._error = None

7. Worker

A Worker is the executor of functions.

7.1 Definition

The definition of Worker is shown below; it starts a thread that runs _process_queue.

class Worker(object):
  """A worker in a cluster.

  Attributes:
    worker_index: The index of the worker in the cluster.
    device_name: The device string of the worker, e.g. "/job:worker/task:1".
    executor: The worker's executor for remote function execution.
    failure_handler: The failure handler used to handler worker preemption
      failure.
  """

  def __init__(self, worker_index, device_name, cluster):
    self.worker_index = worker_index
    self.device_name = device_name
    # Each worker owns its own executor for remote function execution.
    self.executor = executor.new_executor(enable_async=False)
    self.failure_handler = cluster.failure_handler
    self._cluster = cluster
    self._resource_remote_value_refs = []
    self._should_worker_thread_run = True

    # Worker threads need to start after  Worker 's initialization.
    threading.Thread(target=self._process_queue,
                     name="WorkerClosureProcessingLoop-%d" % self.worker_index,
                     daemon=True).start()

new_executor calls TFE_NewExecutor.

def new_executor(enable_async):
  handle = pywrap_tfe.TFE_NewExecutor(enable_async)
  return Executor(handle)

TFE_NewExecutor is defined in tensorflow/c/eager/c_api_experimental.cc and creates a TFE_Executor.

TFE_Executor* TFE_NewExecutor(bool is_async) {
  return new TFE_Executor(is_async);
}

TFE_Executor is defined as follows. The Executor class is an abstraction of the session executor; in TF2 there is also the EagerExecutor.

struct TFE_Executor {
  explicit TFE_Executor(bool async)
      : owned_executor(new tensorflow::EagerExecutor(async)) {}

  explicit TFE_Executor(tensorflow::EagerExecutor* executor)
      : owned_executor(nullptr), unowned_executor(executor) {}

  tensorflow::EagerExecutor* executor() {
    return owned_executor == nullptr ? unowned_executor : owned_executor.get();
  }

  std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
  tensorflow::EagerExecutor* unowned_executor;
};

7.2 Processing

The _process_queue method takes Closures out of the queue and executes them.

  def _process_queue(self):
    """Function running in a worker thread to process closure queues."""
    self._maybe_delay()
    while self._should_worker_thread_run:
      closure = self._cluster._closure_queue.get()  
      if not self._should_worker_thread_run or closure is None:
        return
      self._process_closure(closure)
      # To properly stop the worker and preemption threads, it is important that
      #  ClusterCoordinator  object is not held onto so its  __del__  can be
      # called. By removing the reference to the  closure  that has already been
      # processed, we ensure that the  closure  object is released, while
      # getting the next  closure  at above  self._cluster._closure_queue.get() 
      # call.
      del closure

7.2.1 Waiting

_process_queue first calls _maybe_delay, which waits according to environment-variable configuration.

  def _maybe_delay(self):
    """Delay if corresponding env vars are set."""
    # If the following two env vars variables are set. Scheduling for workers
    # will start in a staggered manner. Worker i will wait for
    #  TF_COORDINATOR_SCHEDULE_START_DELAY  * i seconds, not exceeding
    #  TF_COORDINATOR_SCHEDULE_START_DELAY_MAX .
    delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
    delay_cap = int(
        os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
    if delay_cap:
      delay_secs = min(delay_secs * self.worker_index, delay_cap)
    if delay_secs > 0:
      logging.info("Worker %d sleeping for %d seconds before running function",
                   self.worker_index, delay_secs)
    time.sleep(delay_secs)
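A sketch of staggering worker start-up with these environment variables (set in the coordinator process before scheduling starts, since the worker threads live in the coordinator):

import os

# Per the _maybe_delay logic above, worker i will sleep for
# min(10 * i, 60) seconds before it starts processing closures.
os.environ["TF_COORDINATOR_SCHEDULE_START_DELAY"] = "10"
os.environ["TF_COORDINATOR_SCHEDULE_START_DELAY_MAX"] = "60"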

7.2.2 Processing a task

_process_queue then calls _process_closure to execute the closure.

  def _process_closure(self, closure):
    """Runs a closure with preemption handling."""
    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),  
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)
        with metric_utils.monitored_timer("remote_value_fetch"):
          # Copy the remote tensor to local (the coordinator) in case worker
          # becomes unavailable at a later time.
          closure.output_remote_value.get()
        self._cluster._closure_queue.mark_finished()  
    except Exception as e:  
      # Avoid logging the derived cancellation error
      if not isinstance(e, errors.CancelledError):
        logging.error(
            "/job:worker/task:%d encountered the following error when "
            "processing closure: %r:%s", self.worker_index, e, e)
      closure.output_remote_value._set_error(e)  
      self._cluster._closure_queue.mark_failed(e)  


7.3 Data

Next, let us look at how data reading is placed on the workers. As mentioned earlier, _create_per_worker_resources calls create_resource to create a dedicated resource for each worker.

  def create_resource(self, function, args=None, kwargs=None):
    """Synchronously creates a per-worker resource represented by a  RemoteValue .

    Args:
      function: the resource function to be run remotely. It should be a
         tf.function , a concrete function or a Python function.
      args: positional arguments to be passed to the function.
      kwargs: keyword arguments to be passed to the function.

    Returns:
      one or several RemoteValue objects depending on the function return
      values.
    """
    # Some notes about the concurrency: currently all the activities related to
    # the same worker such as creating resources, setting resources' aborted
    # status, and executing closures happen on the same thread. This allows us
    # to have simpler logic of concurrency.
    closure = ResourceClosure(
        function,
        self._cluster.closure_queue._cancellation_mgr,  
        args=args,
        kwargs=kwargs)
    resource_remote_value = closure.build_output_remote_value()
    self._register_resource(resource_remote_value)

    # The following is a short-term solution to lazily create resources in
    # parallel.
    resource_remote_value._set_aborted() 
    return resource_remote_value

_register_resource then registers each worker's resources on that Worker.

def _register_resource(self, resource_remote_value):
  if not isinstance(resource_remote_value, RemoteValue):
    raise ValueError("Resource being registered is not of type "
                     " tf.distribute.experimental.coordinator.RemoteValue .")
  self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))

The logic is shown below, where the dotted lines represent the data flow: the user puts Closures into the queue with the put method, and the Workers take Closures out of the queue with the get method and execute them. A simplified standalone sketch of this pattern follows.
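The sketch below is a self-contained, simplified illustration of that producer/consumer pattern in plain Python (it is not the actual TensorFlow code): the scheduling side only enqueues closures, while worker threads dequeue and run them.

import queue
import threading

task_queue = queue.Queue()

def worker_loop(worker_index):
  # Analogous to Worker._process_queue: block on the queue, run each closure.
  while True:
    closure = task_queue.get()
    if closure is None:  # sentinel telling the thread to stop
      break
    print("worker", worker_index, "->", closure())

threads = [threading.Thread(target=worker_loop, args=(i,), daemon=True)
           for i in range(2)]
for t in threads:
  t.start()

# Analogous to Cluster.schedule: the caller only enqueues work and returns.
for step in range(4):
  task_queue.put(lambda step=step: "step %d done" % step)

# Shut the worker threads down with sentinels.
for _ in threads:
  task_queue.put(None)
for t in threads:
  t.join()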

7.4 Stopping

stop and a few related methods are responsible for stopping the worker.

  def stop(self):
    """Ensure the worker thread is closed."""
    self._should_worker_thread_run = False

  def _set_resources_aborted(self):
    for weakref_resource in self._resource_remote_value_refs:
      resource = weakref_resource()
      if resource:
        resource._set_aborted()  # pylint: disable=protected-access

  def _set_dead(self):
    raise NotImplementedError("_set_dead is not implemented.")

7.5 Connecting back to the Strategy

So far we have not formally connected things back to the Strategy, so let us look at one more example. Note that the function passed to the coordinator calls strategy.run(replica_fn, args=(next(iterator),)), and this is where the Strategy comes in.

    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])

    with strategy.scope():
      v = tf.Variable(initial_value=0)

    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),))  # this is where the Strategy comes in

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1

8. Failover

8.1 Strategy

The overall strategy for handling failures is roughly as follows:

  • When a worker is found to have failed, the Coordinator puts the function back into the queue and dispatches it to another worker, while starting a background thread that waits for recovery; once the worker recovers, its resources are rebuilt and it continues to be assigned work.

  • Therefore, the failure of some workers does not prevent the cluster from making progress, which allows instances in the cluster to be occasionally unavailable (for example preemptible or spot instances). The coordinator and the parameter servers, however, must always be available for the cluster to make progress.

8.2 Worker failure

When a worker failure occurs, the logic is as follows:

  • When the ClusterCoordinator class is used together with tf.distribute.experimental.ParameterServerStrategy, it comes with built-in worker fault tolerance. That is, when some workers become unreachable by the coordinator for whatever reason, their share of the training progress is carried on by the remaining workers.
  • When a worker recovers, the previously provided dataset function (ClusterCoordinator.create_per_worker_dataset for a custom training loop, or tf.keras.utils.experimental.DatasetCreator for Model.fit) is invoked on it to recreate the datasets.
  • Once a failed worker has recovered and the data created via create_per_worker_dataset has been rebuilt on it, it is added back to function execution.

8.3 Parameter server or coordinator failure

When a parameter server fails, schedule, join, or done raises tf.errors.UnavailableError. In that case, besides resetting the failed parameter server, the user should also restart the coordinator so that it reconnects to the workers and parameter servers, recreates the variables, and loads the checkpoint. If the coordinator fails, after the user brings it back the program automatically connects to the workers and parameter servers and continues from the checkpoint. Since the coordinator itself can become unavailable, some tooling is recommended so that training progress is not lost:

  • The user program must therefore save checkpoint files periodically and restore them at program start. If a tf.keras.optimizers.Optimizer is checkpointed, its iterations property roughly reflects the number of steps already taken after restoring from a checkpoint; this can be used to decide how many epochs and steps remain before training finishes.
  • For Model.fit, you should use the BackupAndRestore callback, which handles saving and restoring progress automatically (a sketch follows the custom-loop code below).
  • For a custom training loop, you should checkpoint the model variables periodically and load them from a checkpoint, if one exists, before training starts. If the optimizer is checkpointed, the training progress can be roughly inferred from optimizer.iterations.
checkpoint_manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(model=model, optimizer=optimizer),
    checkpoint_dir,
    max_to_keep=3)
if checkpoint_manager.latest_checkpoint:
  checkpoint = checkpoint_manager.checkpoint
  checkpoint.restore(
      checkpoint_manager.latest_checkpoint).assert_existing_objects_matched()

global_steps = int(optimizer.iterations.numpy())
starting_epoch = global_steps // steps_per_epoch

for _ in range(starting_epoch, num_epoches):
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  coordinator.join()
  checkpoint_manager.save()
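For the Model.fit path mentioned in the list above, a minimal sketch could look as follows. It assumes a model compiled under strategy.scope() and a tf.keras.utils.experimental.DatasetCreator as input; backup_dir is a hypothetical path that must survive coordinator restarts. Recent releases expose the callback as tf.keras.callbacks.BackupAndRestore (earlier ones as tf.keras.callbacks.experimental.BackupAndRestore).

# Assumed: model compiled under strategy.scope(), and
# dataset_creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn).
backup_callback = tf.keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")

model.fit(
    dataset_creator,
    epochs=num_epoches,
    steps_per_epoch=steps_per_epoch,
    callbacks=[backup_callback])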

8.4 Returning a RemoteValue

If a function is executed successfully, its RemoteValue can be fetched successfully. This is because currently the return value is copied to the coordinator immediately after the function finishes; if any worker failure happens during the copy, the function is retried on another available worker. Therefore, if you want to optimize performance, you can schedule functions that have no return value.
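A sketch of that trade-off, assuming a replica_fn and per_worker_iterator as in the earlier examples: the first variant copies the per-step result back to the coordinator after every step, while the second returns nothing and relies on variables (for example metrics created under strategy.scope()) that live on the parameter servers.

@tf.function
def step_fn_with_result(iterator):
  # The returned tensor is copied back to the coordinator when the closure
  # finishes, so it is preserved even if the worker later fails.
  return strategy.run(replica_fn, args=(next(iterator),))

@tf.function
def step_fn_no_result(iterator):
  # Nothing is returned, so no per-step copy back to the coordinator is
  # needed; losses/metrics are accumulated in parameter-server variables.
  strategy.run(replica_fn, args=(next(iterator),))

coordinator.schedule(step_fn_no_result, args=(per_worker_iterator,))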

8.5 Error reporting

Once the coordinator notices an error, such as an UnavailableError from a parameter server or an application error such as an InvalidArgument from tf.debugging.check_numerics, it cancels all pending and queued functions before re-raising the error. Fetching their corresponding RemoteValues then raises a CancelledError.

After an error has been raised, the coordinator will not raise the same error again, nor will it raise errors coming from the cancelled functions.

ClusterCoordinator assumes that all function errors are fatal; based on this assumption, its error-reporting logic is:

  • Both schedule and join can raise a non-retryable error, which is the first error seen by the coordinator from any previously scheduled function.
  • When an error is raised, there is no guarantee on how many previously scheduled functions have been executed; functions that have not been executed are discarded and marked as cancelled.
  • After an error is raised, the internal error state is cleared.

8.6 WorkerPreemptionHandler

WorkerPreemptionHandler is the main module for handling failures. It is defined as follows:

class WorkerPreemptionHandler(object):
  """Handles worker preemptions."""

  def __init__(self, server_def, cluster):
    self._server_def = server_def
    self._cluster = cluster
    self._cluster_update_lock = threading.Lock()
    self._cluster_due_for_update_or_finish = threading.Event()
    self._worker_up_cond = threading.Condition(self._cluster_update_lock)
    self._error_from_recovery = None
    self._should_preemption_thread_run = True
    self._preemption_handler_thread = threading.Thread(
        target=self._preemption_handler,
        name="WorkerPreemptionHandler",
        daemon=True)
    self._preemption_handler_thread.start()

8.6.1 Configuration

When the Cluster is created, a WorkerPreemptionHandler is configured into it.

self.failure_handler = WorkerPreemptionHandler(context.get_server_def(), self)

8.6.2 Waiting

When a closure is processed, its execution is wrapped in wait_on_failure to handle errors.

  def _process_closure(self, closure):
    """Runs a closure with preemption handling."""
    assert closure is not None
    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),  
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)

The wait_on_failure method of WorkerPreemptionHandler is as follows:

  @contextlib.contextmanager
  def wait_on_failure(self,
                      on_failure_fn=None,
                      on_transient_failure_fn=None,
                      on_recovery_fn=None,
                      worker_device_name="(unknown)"):
    """Catches worker preemption error and wait until failed workers are back.

    Args:
      on_failure_fn: an optional function to run if preemption happens.
      on_transient_failure_fn: an optional function to run if transient failure
        happens.
      on_recovery_fn: an optional function to run when a worker is recovered
        from preemption.
      worker_device_name: the device name of the worker instance that is passing
        through the failure.

    Yields:
      None.
    """
    try:
      yield
    except (errors.OpError, InputError) as e:
      # If the error is due to temporary connectivity issues between worker and
      # ps, put back closure, ignore error and do not mark worker as failure.
      if self._cluster._record_and_ignore_transient_ps_failure(e):  
        if on_transient_failure_fn:
          on_transient_failure_fn()
        return

      # Ignoring derived CancelledErrors to tolerate transient failures in
      # PS-worker communication, which initially exposed as an UnavailableError
      # and then lead to sub-function cancellation, subsequently getting
      # reported from worker to chief as CancelledError.
      # We do not mark either worker or PS as failed due to only CancelledError.
      # If there are real (non-transient) failures, they must also be reported
      # as other errors (UnavailableError most likely) in closure executions.
      if isinstance(e, errors.CancelledError) and "/job:" in str(e):
        if on_transient_failure_fn:
          on_transient_failure_fn()
        return

      # This reraises the error, if it's not considered recoverable; otherwise,
      # the following failure recovery logic run. At this time, only worker
      # unavailability is recoverable. PS unavailability as well as other
      # errors in the user function is not recoverable.
      self._validate_preemption_failure(e)

      if on_failure_fn:
        on_failure_fn()

      with self._cluster_update_lock:
        self._cluster_due_for_update_or_finish.set()
        self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
        if self._error_from_recovery:
          try:
            raise self._error_from_recovery
          finally:
            self._error_from_recovery = None

      if on_recovery_fn:
        with self.wait_on_failure(
            on_recovery_fn=on_recovery_fn,
            on_transient_failure_fn=on_transient_failure_fn,
            worker_device_name=worker_device_name):
          on_recovery_fn()

_validate_preemption_failure is defined as follows:

  def _validate_preemption_failure(self, e):
    """Validates that the given exception represents worker preemption."""

    # Only categorize the failure as a worker preemption if the cancellation
    # manager did not attempt to cancel the blocking operations.
    if _is_worker_failure(e) and (
        not self._cluster._closure_queue._cancellation_mgr.is_cancelled):  
      return
    raise e

8.6.3 The handler

WorkerPreemptionHandler has a background thread, _preemption_handler_thread.

    self._preemption_handler_thread = threading.Thread(
        target=self._preemption_handler,
        name="WorkerPreemptionHandler",
        daemon=True)
    self._preemption_handler_thread.start()


_preemption_handler performs the necessary error handling.

  def _preemption_handler(self):
    """A loop that handles preemption.

    This loop waits for signal of worker preemption and upon worker preemption,
    it waits until all workers are back and updates the cluster about the
    restarted workers.
    """
    assert self._should_preemption_thread_run
    while True:
      self._cluster_due_for_update_or_finish.wait()
      if not self._should_preemption_thread_run:
        break

      with self._cluster_update_lock:
        try:
          context.context().update_server_def(self._server_def)

          # Cluster updated successfully, clear the update signal, and notify
          # all workers that they are recovered from failure.
          self._worker_up_cond.notify_all()
          # The check for _should_preemption_thread_run is necessary since the
          #  stop  may have already set _cluster_due_for_update_or_finish.
          if self._should_preemption_thread_run:
            self._cluster_due_for_update_or_finish.clear()
        except Exception as e:  
          try:
            self._validate_preemption_failure(e)
          except Exception as ps_e: 
            # In this case, a parameter server fails. So we raise this error to
            # the caller of  wait_on_failure .
            self._error_from_recovery = ps_e
            self._worker_up_cond.notify_all()
            if self._should_preemption_thread_run:
              self._cluster_due_for_update_or_finish.clear()
          # NOTE: Since the first RPC (GetStatus) of update_server_def is
          # currently blocking by default, error should only happen if:
          # (1) More workers failed while waiting for the previous workers to
          #     come back;
          # (2) Worker failed when exchanging subsequent RPCs after the first
          #     RPC returns.
          # Consider adding backoff retry logic if we see the error logged
          # too frequently.

9. Summary

Based on the preceding code, we can now answer the questions raised at the beginning:

  • How does a worker know which devices to use? Answer: when the cluster creates the workers, each worker is assigned a device name.

  • How exactly is the user function executed? Answer: when a worker executes a Closure, it pins execution to that worker's device and then runs the specified function (self._function). self._function is the user-defined function, which can use strategy.run to dispatch the training computation to the remote worker.

  • How is the data obtained? Answer: a PerWorkerValues is created; PerWorkerValues is a container holding a list of values, one per worker, and each worker gets its data from its corresponding entry in the PerWorkerValues.

