[Source Code Analysis] TensorFlow Distributed: MirroredStrategy

Posted by 羅西的思考 on 2022-04-19

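MirroredStrategy is typically used to train on multiple GPUs within a single machine. Its main challenges are how to update the mirrored variables and how to distribute the computation. In this article we look at its overall design and at how variables are updated.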

Two GitHub repositories are worth recommending here; both are excellent learning resources:

https://github.com/yuhuiaws/ML-study

https://github.com/Jack47/hack-SysML

I also recommend 西門宇少's latest article, "Making Pipelines Parallel Along the Token Level on Transformer LMs: TeraPipe".

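Other articles in this series:

[Translation] TensorFlow Distributed: Paper "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed: Paper "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

[Source Code Analysis] TensorFlow Distributed Environment (5) --- Session

[Source Code Analysis] TensorFlow Distributed Environment (7) --- Worker Dynamic Logic

[Source Code Analysis] TensorFlow Distributed Environment (8) --- Communication Mechanism

[Translation] Distributed Training with TensorFlow

[Source Code Analysis] TensorFlow Distributed DistributedStrategy: Basics

[Source Code Analysis] TensorFlow Distributed Variables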

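1. Design & Approach

1.1 Main Logic

MirroredStrategy is TF's distributed training strategy for synchronous data parallelism across multiple GPUs on a single machine.

Its main logic is as follows: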

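  • MirroredStrategy automatically uses every GPU that TensorFlow can discover. If you only want to use some of them, specify which devices to use through the devices argument.
  • Before training starts, MirroredStrategy copies a full replica of the model to each of the N compute devices (GPUs). Every variable in the model is mirrored and placed on the corresponding GPU; these are the mirrored variables (MirroredVariable).
  • Data parallelism means: the workers receive data from a tf.data.Dataset, and once training starts, each incoming batch is split into N shards that are fed to the N compute devices respectively.
  • Synchronous means: during training, each worker runs the forward and backward passes on its own input shard, and gradients are aggregated at the end of every step. The next step starts only after all devices have updated their local variables.
  • MirroredStrategy keeps all mirrored variables updated in sync across the GPUs with an AllReduce algorithm: the devices efficiently exchange gradient data and sum it, so every device ends up with the sum of all devices' gradients, which is then used to update each GPU's local variables. By default the AllReduce algorithm is NcclAllReduce; users can switch to another AllReduce algorithm (such as HierarchicalCopyAllReduce) through the cross_device_ops argument.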
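The synchronous data-parallel step described above can be sketched in plain Python. This is a toy model, not TensorFlow code: each list index stands in for one GPU replica, the model is an assumed linear model y = w * x with squared loss, and all_reduce_sum is a hypothetical helper that plays the role of NcclAllReduce by giving every replica the sum of all gradients.

```python
# Toy sketch of one synchronous data-parallel step (not TensorFlow code).
# Assumptions: model y = w * x, loss = sum of squared errors, and each list
# index stands in for one GPU replica holding a mirrored copy of w.

def split_batch(batch, num_replicas):
    """Split a global batch into num_replicas equal, contiguous shards."""
    if len(batch) % num_replicas:
        raise ValueError("batch size must be divisible by the number of replicas")
    shard = len(batch) // num_replicas
    return [batch[i * shard:(i + 1) * shard] for i in range(num_replicas)]


def local_gradient(w, shard):
    """Gradient of sum((w * x - y) ** 2) with respect to w on one shard."""
    return sum(2 * (w * x - y) * x for x, y in shard)


def all_reduce_sum(values):
    """Hypothetical all-reduce: every replica receives the sum of all values."""
    total = sum(values)
    return [total] * len(values)


def train_step(mirrored_w, batch, lr):
    """mirrored_w holds one copy of w per replica; returns the updated copies."""
    shards = split_batch(batch, len(mirrored_w))
    grads = [local_gradient(w, s) for w, s in zip(mirrored_w, shards)]
    summed = all_reduce_sum(grads)
    # Every replica applies the same summed gradient, so the copies stay identical.
    return [w - lr * g for w, g in zip(mirrored_w, summed)]


batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
new_w = train_step([0.0, 0.0], batch, lr=0.01)
```

Because the shard gradients are summed, each replica applies exactly the gradient a single device would compute on the whole batch, which is why the mirrored copies never drift apart.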

Figure 1: MirroredStrategy mechanism

1.2 Usage

The usage code is shown below, with the values printed at runtime interleaved.

    >>> import tensorflow as tf
    >>> @tf.function
    ... def step_fn(var):
    ...
    ...   def merge_fn(strategy, value, var):
    ...     # All-reduce the value. Note that value here is a
    ...     # tf.distribute.DistributedValues.
    ...     reduced = strategy.extended.batch_reduce_to(
    ...         tf.distribute.ReduceOp.SUM, [(value, var)])[0]
    ...     strategy.extended.update(var, lambda var, value: var.assign(value),
    ...         args=(reduced,))
    ...
    ...   value = tf.identity(1.)
    ...   tf.distribute.get_replica_context().merge_call(merge_fn,
    ...     args=(value, var))
    >>>
    >>> def run(strategy):
    ...   with strategy.scope():
    ...     v = tf.Variable(0.)
    ...     strategy.run(step_fn, args=(v,))
    ...     return v
    >>>
    >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
    MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
    }
    >>> run(tf.distribute.experimental.CentralStorageStrategy(
    ...     compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
    >>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
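Why do both copies hold 2.0 under MirroredStrategy, and why does a single-replica strategy such as OneDeviceStrategy end up with 1.0? The arithmetic of merge_fn can be replayed in plain Python. This is a hypothetical simulation, not the TensorFlow API: each replica contributes tf.identity(1.), batch_reduce_to(SUM) adds the per-replica values, and update assigns the sum to every copy of the variable (one copy per device for MirroredStrategy, a single copy on the parameter device for CentralStorageStrategy).

```python
# Replays the arithmetic of merge_fn above (hypothetical simulation, no TF).

def simulate_merge_call(num_replicas, num_variable_copies, per_replica_value=1.0):
    """Return the variable copies after one strategy.run(step_fn) call."""
    # Each replica computes value = tf.identity(1.).
    values = [per_replica_value] * num_replicas
    # batch_reduce_to(ReduceOp.SUM, ...) sums the per-replica values.
    reduced = sum(values)
    # tf.Variable(0.) creates num_variable_copies copies, all starting at 0.
    copies = [0.0] * num_variable_copies
    # strategy.extended.update(var, var.assign(value)) writes the sum everywhere.
    return [reduced for _ in copies]


mirrored = simulate_merge_call(num_replicas=2, num_variable_copies=2)
central = simulate_merge_call(num_replicas=2, num_variable_copies=1)
one_device = simulate_merge_call(num_replicas=1, num_variable_copies=1)
```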

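1.3 Analysis Approach

Since we have already analyzed PyTorch's data-parallel implementation, DDP, in some depth, the focus here is on where TF and PyTorch differ and where they agree. The questions that come to mind are:

  • Does training on a single machine use multiple processes or multiple threads?
  • How is the model distributed? (This is implemented by MirroredVariable, which we analyzed in a previous article.)
  • How do the mirrored variables (MirroredVariable) present a single unified view to the outside? (Also implemented by MirroredVariable and analyzed in a previous article.)
  • How are collective communication operations (and libraries) used?

From the earlier MirroredVariable article we know that these variables are ultimately operated on through the Strategy or its Extended object.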

2. Definitions

2.1 MirroredStrategy

MirroredStrategy is defined as follows; essentially, it just creates a MirroredExtended.

@tf_export("distribute.MirroredStrategy", v1=[])  # pylint: disable=g-classes-have-attributes
class MirroredStrategy(distribute_lib.Strategy):
  """Synchronous training across multiple replicas on one machine.
  """
  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MirroredStrategy")  

2.2 MirroredExtended

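Its core member variables are:

  • devices: the devices used by this training run;
  • _collective_ops_in_use: whether the underlying collective communication ops are in use.
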
class MirroredExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of MirroredStrategy."""

  # If this is set to True, use NCCL collective ops instead of NCCL cross device
  # ops.
  _prefer_collective_ops = False

  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
    super(MirroredExtended, self).__init__(container_strategy)
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "MirroredStrategy is not supported in eager mode.")
      else:
        if TFConfigClusterResolver().cluster_spec().as_dict():
          # if you are executing in eager mode, only the single machine code
          # path is supported.
          pass  # (the statement under this branch is elided in this excerpt)
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    self._cross_device_ops = cross_device_ops
    self._collective_ops_in_use = False
    self._collective_key_base = container_strategy._collective_key_base
    self._initialize_strategy(devices)
    self._communication_options = collective_util.Options(
        implementation=collective_util.CommunicationImplementation.NCCL)

    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False

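3. Initialization

Initialization takes one of two paths:

  • Single node: initialize the single worker on this node and create the collective ops.
  • Multiple nodes: call _initialize_multi_worker to initialize the workers across the nodes.
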
def _initialize_strategy(self, devices):
  # The _initialize_strategy method is intended to be used by distribute
  # coordinator as well.
  devices = tuple(device_util.resolve(d) for d in devices)
  if _is_device_list_single_worker(devices):
    self._initialize_single_worker(devices)
    self._collective_ops = self._make_collective_ops(devices)
    if self._prefer_collective_ops and (
        isinstance(self._cross_device_ops, cross_device_ops_lib.NcclAllReduce)
        or isinstance(self._inferred_cross_device_ops,
                      cross_device_ops_lib.NcclAllReduce)):
      self._collective_ops_in_use = True
      self._inferred_cross_device_ops = None
  else:
    self._initialize_multi_worker(devices)

Whether the device list belongs to a single worker or to several is decided by _is_device_list_single_worker.

def _is_device_list_single_worker(devices):
  """Checks whether the devices list is for single or multi-worker.

  Args:
    devices: a list of device strings or tf.config.LogicalDevice objects, for
      either local or for remote devices.

  Returns:
    a boolean indicating whether these device strings are for local or for
    remote.

  Raises:
    ValueError: if device strings are not consistent.
  """
  specs = []
  for d in devices:
    name = d.name if isinstance(d, context.LogicalDevice) else d
    specs.append(tf_device.DeviceSpec.from_string(name))
  num_workers = len({(d.job, d.task, d.replica) for d in specs})
  all_local = all(d.job in (None, "localhost") for d in specs)
  any_local = any(d.job in (None, "localhost") for d in specs)

  if any_local and not all_local:
    raise ValueError("Local device string cannot have job specified other "
                     "than 'localhost'")

  if num_workers == 1 and not all_local:
    if any(d.task is None for d in specs):
      raise ValueError("Remote device string must have task specified.")

  return num_workers == 1
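The worker-counting rule can be tried out without TensorFlow. The sketch below is a simplified, hypothetical re-implementation: parse_device is a minimal stand-in for tf.DeviceSpec.from_string that only understands the job, replica and task fields, while is_single_worker applies the same checks as _is_device_list_single_worker above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DeviceFields:
    """The three fields that identify a worker in a device string."""
    job: Optional[str] = None
    replica: Optional[int] = None
    task: Optional[int] = None


def parse_device(name):
    """Minimal parser for strings such as '/job:worker/task:1/device:GPU:0'.

    A simplified stand-in for tf.DeviceSpec.from_string (assumption).
    """
    fields = {}
    for part in name.strip("/").split("/"):
        key, sep, value = part.partition(":")
        if sep:
            fields[key.lower()] = value
    return DeviceFields(
        job=fields.get("job"),
        replica=int(fields["replica"]) if "replica" in fields else None,
        task=int(fields["task"]) if "task" in fields else None,
    )


def is_single_worker(devices):
    """Same checks as _is_device_list_single_worker, on parsed fields."""
    specs = [parse_device(d) for d in devices]
    num_workers = len({(s.job, s.task, s.replica) for s in specs})
    all_local = all(s.job in (None, "localhost") for s in specs)
    any_local = any(s.job in (None, "localhost") for s in specs)
    if any_local and not all_local:
        raise ValueError("Local device string cannot have job specified other "
                         "than 'localhost'")
    if num_workers == 1 and not all_local:
        if any(s.task is None for s in specs):
            raise ValueError("Remote device string must have task specified.")
    return num_workers == 1
```

Devices without a job field, or with job "localhost", count as local; a worker is identified by its (job, task, replica) triple, so two different tasks always mean multi-worker.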

3.1 Multi-worker Initialization

This path is actually used by MultiWorkerMirroredStrategy, so we only give a brief overview here.

3.1.1 MultiWorkerMirroredStrategy

As shown below, it is built on CollectiveAllReduceExtended.

@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Initializes the object."""
    communication_options = collective_util.Options(
        implementation=communication)
    super(CollectiveAllReduceStrategyV1, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=communication_options))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MultiWorkerMirroredStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_gpu_per_worker").set(self.extended._num_gpus_per_worker)

3.1.2 CollectiveAllReduceExtended

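CollectiveAllReduceExtended extends MirroredExtended.

class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
  ...  # Class body omitted here.

Therefore, in a multi-node environment the _initialize_multi_worker path is taken:

  • It initializes workers, each of which is just a string such as "/job:worker/task:0";
  • It initializes worker_devices, a list of (worker, devices) tuples;
  • It sets _inferred_cross_device_ops: with more than one worker this must be the user-specified ReductionToOneDevice, and with a single worker it defaults to NcclAllReduce.
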
def _initialize_multi_worker(self, devices):
  """Initializes the object for multi-worker training."""
  device_dict = _group_device_list(devices)
  workers = []
  worker_devices = []
  for job in ("chief", "worker"):
    for task in range(len(device_dict.get(job, []))):
      worker = "/job:%s/task:%d" % (job, task)
      workers.append(worker)
      worker_devices.append((worker, device_dict[job][task]))

  # Setting _default_device will add a device scope in the
  # distribution.scope. We set the default device to the first worker. When
  # users specify device under distribution.scope by
  #   with tf.device("/cpu:0"):
  #     ...
  # their ops will end up on the cpu device of its first worker, e.g.
  # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
  self._default_device = workers[0]
  self._host_input_device = numpy_dataset.SingleDevice(workers[0])

  self._devices = tuple(devices)
  self._input_workers_devices = worker_devices
  self._is_multi_worker_training = True

  # How to choose the cross-device ops
  if len(workers) > 1:
    if (not isinstance(self._cross_device_ops,
                       cross_device_ops_lib.ReductionToOneDevice) or
        self._cross_device_ops._num_between_graph_workers > 1):  
      raise ValueError(
          "In-graph multi-worker training with MirroredStrategy is not "
          "supported.")
    self._inferred_cross_device_ops = self._cross_device_ops
  else:
    self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()
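The worker list built by _initialize_multi_worker, and the branch that picks _inferred_cross_device_ops, can be sketched in plain Python. group_device_list below is a hypothetical stand-in for _group_device_list (whose code is not shown here): it assumes every device string carries /job: and /task: fields and that task ids are contiguous from 0. Cross-device ops are represented by class-name strings, and the extra _num_between_graph_workers check is omitted.

```python
import re
from collections import defaultdict


def group_device_list(devices):
    """Hypothetical: map job -> list indexed by task id -> device strings."""
    by_job = defaultdict(dict)
    for d in devices:
        job = re.search(r"/job:([^/]+)", d).group(1)
        task = int(re.search(r"/task:(\d+)", d).group(1))
        by_job[job].setdefault(task, []).append(d)
    return {job: [tasks[t] for t in sorted(tasks)] for job, tasks in by_job.items()}


def build_workers(devices):
    """Mirror the loop in _initialize_multi_worker: chief first, then workers."""
    device_dict = group_device_list(devices)
    workers, worker_devices = [], []
    for job in ("chief", "worker"):
        for task in range(len(device_dict.get(job, []))):
            worker = "/job:%s/task:%d" % (job, task)
            workers.append(worker)
            worker_devices.append((worker, tuple(device_dict[job][task])))
    return workers, worker_devices


def infer_cross_device_ops(num_workers, user_choice=None):
    """user_choice is a class-name string such as 'ReductionToOneDevice'."""
    if num_workers > 1:
        if user_choice != "ReductionToOneDevice":
            raise ValueError("In-graph multi-worker training with "
                             "MirroredStrategy is not supported.")
        return user_choice
    return "NcclAllReduce"
```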

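3.2 Cross-Device Operations

Having seen how the collective ops are chosen across devices, let us first look at the cross-device ops in more detail and then return to single-worker initialization.

Essentially every distribution strategy communicates data through some collective ops and cross-device ops. For example, MirroredStrategy uses CollectiveOps to keep variables in sync, and at run time CollectiveOps automatically selects a suitable AllReduce algorithm based on the hardware configuration, the current network topology, and the tensor sizes. For instance, tensorflow/core/kernels/collective_ops.cc contains the following usage, where c is the op's current compute state (its OpKernelContext) and col_exec is the collective executor chosen for the system's configuration. All collective operations, such as all-reduce and broadcast, are executed by the collective executor.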

void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                      DoneCallback done) override {
  auto output_shape = c->input(0).shape();
  output_shape.set_dim(
      0, output_shape.dim_size(0) * col_params_->group.group_size);
  col_params_->instance.shape = output_shape;

  // Allocate output on the first pass through this function.  This must be
  // done immediately, while we're still in the executor thread.  Otherwise
  // the memory is not guaranteed to be unused by any concurrently executing
  // GPU kernel.
  if (c->mutable_output(0) == nullptr) {
    // Allocate the output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, col_params_->instance.shape, &output), done);
  }
  if (!CanProceedWithCompute(c, col_exec, done)) return;

  auto actual_done = [c, col_params = col_params_, done](const Status& s) {
    col_params->Unref();
    OP_REQUIRES_OK_ASYNC(c, s, done);
    done();
  };

  col_params_->Ref();
  col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
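The kernel above sizes its output as the input's first dimension times group_size, which is the shape of an all-gather result. A minimal pure-Python sketch of that semantics, assuming each "tensor" is a list of rows and every device contributes the same number of rows (all_gather here is a hypothetical helper, not a TensorFlow function):

```python
def all_gather(per_device_rows):
    """Concatenate every device's rows along axis 0; each device gets a copy.

    per_device_rows[i] is the list of rows held by device i.
    """
    group_size = len(per_device_rows)
    dim0 = len(per_device_rows[0])
    if any(len(rows) != dim0 for rows in per_device_rows):
        raise ValueError("all devices must contribute the same number of rows")
    # By construction len(gathered) == dim0 * group_size, the same rule as
    # output_shape.set_dim(0, dim_size(0) * group_size) in the kernel.
    gathered = [row for rows in per_device_rows for row in rows]
    return [list(gathered) for _ in range(group_size)]
```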

3.2.1 CrossDeviceOps

CrossDeviceOps is the base class for cross-device operations. Its public implementations are:

  • tf.distribute.ReductionToOneDevice
  • tf.distribute.NcclAllReduce
  • tf.distribute.HierarchicalCopyAllReduce

@tf_export("distribute.CrossDeviceOps")
class CrossDeviceOps(object):
  """Base class for cross-device reduction and broadcasting algorithms.

  The main purpose of this class is to be passed to
  tf.distribute.MirroredStrategy in order to choose among different cross
  device communication implementations. Prefer using the methods of
  tf.distribute.Strategy instead of the ones of this class.

  Implementations:
  * tf.distribute.ReductionToOneDevice
  * tf.distribute.NcclAllReduce
  * tf.distribute.HierarchicalCopyAllReduce
  """

3.2.2 ReductionToOneDevice

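ReductionToOneDevice copies the values to one device, reduces them there, and then broadcasts the reduced value to the destinations. It does not support efficient batching.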

@tf_export("distribute.ReductionToOneDevice")
class ReductionToOneDevice(CrossDeviceOps):
  """A CrossDeviceOps implementation that copies values to one device to reduce.

  This implementation always copies values to one device to reduce them, then
  broadcast reduced values to the destinations. It doesn't support efficient
  batching.
  """

  def __init__(self, reduce_to_device=None, accumulation_fn=None):
    """Initializes with a device to reduce to and a way to accumulate.

    Args:
      reduce_to_device: the intermediate device to reduce to. If None, reduce
        to the first device in destinations of the reduce method.
      accumulation_fn: a function that does accumulation.  If None,
        tf.math.add_n is used.
    """
    self.reduce_to_device = reduce_to_device
    self.accumulation_fn = accumulation_fn or math_ops.add_n
    super(ReductionToOneDevice, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    if check_destinations(destinations):
      devices = get_devices_from(destinations, self._canonicalize_devices)
    else:
      devices = get_devices_from(per_replica_value, self._canonicalize_devices)
    reduce_to_device = self.reduce_to_device or devices[0]
    reduced = _simple_reduce(per_replica_value, reduce_to_device,
                             self.accumulation_fn, reduce_op)
    return self.broadcast(reduced, destinations)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    del options  # Unused.
    if check_destinations(destinations):
      devices = get_devices_from(destinations, self._canonicalize_devices)
    else:
      devices = get_devices_from(per_replica_value, self._canonicalize_devices)
    reduce_to_device = self.reduce_to_device or devices[0]
    gathered = _simple_gather(per_replica_value, reduce_to_device, axis)
    return self.broadcast(gathered, destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    return [
        self.reduce_implementation(
            reduce_op, t, destinations=v, options=options)
        for t, v in value_destination_pairs
    ]

_simple_reduce is implemented as follows:

def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduces the value by accumulation_fn and reduce_op."""
  all_values = per_replica_value.values
  count = len(all_values)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise ValueError("reduce_op must be Reduce.SUM or Reduce.MEAN.")
  return reduced
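The reduce-then-broadcast behavior of ReductionToOneDevice and _simple_reduce can be mimicked with scalars in plain Python. This is a hypothetical sketch: devices are plain strings, reduce ops are the strings "SUM" and "MEAN", and the broadcast is modelled as a dict from destination device to value.

```python
def simple_reduce(per_replica_values, reduce_op, accumulation_fn=sum):
    """Accumulate all values, then divide by the count for MEAN."""
    reduced = accumulation_fn(per_replica_values)
    if reduce_op == "MEAN":
        reduced = reduced / len(per_replica_values)
    elif reduce_op != "SUM":
        raise ValueError("reduce_op must be SUM or MEAN.")
    return reduced


def reduce_to_one_device(per_replica_values, destinations, reduce_op,
                         reduce_to_device=None):
    """Reduce on one device, then broadcast the result to every destination."""
    # Like the real class, default to the first destination device.
    target = reduce_to_device or destinations[0]
    reduced = simple_reduce(per_replica_values, reduce_op)
    broadcast = {device: reduced for device in destinations}
    return target, broadcast
```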

3.2.3 AllReduceCrossDeviceOps

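AllReduceCrossDeviceOps is the base class of NcclAllReduce and HierarchicalCopyAllReduce. It performs an all-reduce when the source and destination devices match and none of them is a CPU; otherwise it falls back to ReductionToOneDevice.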

class AllReduceCrossDeviceOps(CrossDeviceOps):
  """All-reduce implementation of CrossDeviceOps.

  It performs all-reduce when applicable using NCCL or hierarchical copy. For
  the batch API, tensors will be repacked or aggregated for more efficient
  cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  tf.distribute.ReductionToOneDevice.
  """

  def __init__(self, all_reduce_alg="nccl", num_packs=1):
    """Initializes the object.

    Args:
      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
        "hierarchical_copy" are supported.
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.
    """
    self._all_reduce_alg = all_reduce_alg
    self._num_packs = num_packs
    self._simple_cross_replica_ops = ReductionToOneDevice()
    super(AllReduceCrossDeviceOps, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    # To use NCCL or all-reduce, source and destination devices should match,
    # and none of the devices should be CPU.
    if (_devices_match(per_replica_value, destinations) and
        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                   destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))

  def _do_batch_all_reduce(self, reduce_op, dense_values):
    """Run batch all-reduces."""

    destinations = dense_values[0]._devices  # pylint: disable=protected-access
    grouped = _group_value_by_device(dense_values)

    # device_grad_packs:
    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      reduced = (
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)

  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
    """Run batch all-reduce for sparse values."""
    # Use sparse_values as destinations to do all-reduces. It is effectively
    # an allgather under the hood but not an efficient one.
    return self._simple_cross_replica_ops.batch_reduce(
        reduce_op, zip(sparse_values, sparse_values))

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    return ReductionToOneDevice()._gather(per_replica_value, destinations, axis,  # pylint: disable=protected-access
                                          options)
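The packing step in _do_batch_all_reduce (_pack_tensors, an all-reduce per pack, then _unpack_tensors) can be illustrated with flat Python lists. pack, unpack and batch_all_reduce below are hypothetical helpers, not the TensorFlow functions: they assume every tensor is a flat list of numbers, that each device holds tensors of the same sizes in the same order, and that the all-reduce is an element-wise sum. As in the docstring, num_packs=0 means no packing.

```python
def pack(tensors, num_packs):
    """Concatenate flat tensors and split the result into num_packs chunks."""
    sizes = [len(t) for t in tensors]
    if num_packs == 0:
        # No packing: each tensor is all-reduced on its own.
        return [list(t) for t in tensors], (sizes, False)
    flat = [x for t in tensors for x in t]
    chunk = -(-len(flat) // num_packs)  # ceiling division
    packs = [flat[i:i + chunk] for i in range(0, len(flat), chunk)]
    return packs, (sizes, True)


def unpack(packs, packer):
    """Invert pack(): restore the original tensor boundaries."""
    sizes, packed = packer
    if not packed:
        return [list(p) for p in packs]
    flat = [x for p in packs for x in p]
    out, start = [], 0
    for n in sizes:
        out.append(flat[start:start + n])
        start += n
    return out


def batch_all_reduce(per_device_tensors, num_packs):
    """per_device_tensors[d] is the list of flat tensors on device d."""
    packed = [pack(ts, num_packs) for ts in per_device_tensors]
    packs_by_device = [p for p, _ in packed]
    packer = packed[0][1]
    # Element-wise sum of matching packs across devices (the all-reduce).
    reduced = [[sum(vals) for vals in zip(*pack_group)]
               for pack_group in zip(*packs_by_device)]
    # Every device receives the same reduced packs and unpacks them locally.
    return [unpack(reduced, packer) for _ in per_device_tensors]
```

Fewer, larger packs mean fewer collective launches; the trade-off the source comment mentions is that the split count affects how well the aggregation work is balanced.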

3.2.4 NcclAllReduce

NcclAllReduce uses NVIDIA NCCL to perform the all-reduce.

@tf_export("distribute.NcclAllReduce")
class NcclAllReduce(AllReduceCrossDeviceOps):
  """NCCL all-reduce implementation of CrossDeviceOps.

  It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
  repacked or aggregated for more efficient cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  tf.distribute.ReductionToOneDevice.
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if num_packs is negative.
    """
    if num_packs < 0:
      raise ValueError(
          "NCCL all-reduce requires num_packs >= 0, but {} is specified".format(
              num_packs))
    super(NcclAllReduce, self).__init__(
        all_reduce_alg="nccl", num_packs=num_packs)
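NcclAllReduce hands the actual reduction to NCCL, which selects its own algorithm internally; the source above does not show which. For intuition, here is the classic ring all-reduce (a reduce-scatter phase followed by an all-gather phase), a common algorithm for this job, simulated in pure Python. It is an illustrative sketch, not NCCL's code: ring_all_reduce is a hypothetical name, and each device's buffer is assumed to be a list of numbers of equal length.

```python
def ring_all_reduce(buffers):
    """Sum equal-length buffers across n simulated devices arranged in a ring.

    Returns one reduced buffer per device; the inputs are not modified.
    """
    n = len(buffers)
    data = [list(b) for b in buffers]
    length = len(data[0])
    # Split every buffer into n contiguous chunks (some may be empty).
    bounds = [(c * length // n, (c + 1) * length // n) for c in range(n)]

    def ring_pass(chunk_for_device, accumulate):
        # All devices "send" simultaneously: snapshot first, then receive.
        sends = []
        for i in range(n):
            c = chunk_for_device(i)
            lo, hi = bounds[c]
            sends.append((c, data[i][lo:hi]))
        for i, (c, chunk) in enumerate(sends):
            dst = (i + 1) % n
            lo, hi = bounds[c]
            if accumulate:
                for k in range(lo, hi):
                    data[dst][k] += chunk[k - lo]
            else:
                data[dst][lo:hi] = chunk

    # Phase 1, reduce-scatter: after n - 1 steps device i holds the
    # complete sum of chunk (i + 1) % n.
    for step in range(n - 1):
        ring_pass(lambda i: (i - step) % n, accumulate=True)
    # Phase 2, all-gather: circulate the completed chunks around the ring.
    for step in range(n - 1):
        ring_pass(lambda i: (i + 1 - step) % n, accumulate=False)
    return data
```

Each device sends and receives about 2 * (n - 1) / n of its buffer in total, independent of n, which is why ring-style algorithms scale well in bandwidth.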

3.2.5 HierarchicalCopyAllReduce

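HierarchicalCopyAllReduce performs the all-reduce with a hierarchical algorithm: it reduces the data to one GPU along the edges of some hierarchy, then broadcasts the result back to each GPU along the same path. For the batch API, tensors are repacked or aggregated for more efficient cross-device transport. This reduction was built for the Nvidia DGX-1 and assumes the GPUs are connected the way they are on a DGX-1 machine. If your GPUs are interconnected differently, it is likely to be slower than tf.distribute.ReductionToOneDevice.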

@tf_export("distribute.HierarchicalCopyAllReduce")
class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
  """Hierarchical copy all-reduce implementation of CrossDeviceOps.

  It reduces to one GPU along edges in some hierarchy and broadcasts back to
  each GPU along the same path. For the batch API, tensors will be repacked or
  aggregated for more efficient cross-device transportation.

  This is a reduction created for Nvidia DGX-1 which assumes GPUs connects like
  that on DGX-1 machine. If you have different GPU inter-connections, it is
  likely that it would be slower than tf.distribute.ReductionToOneDevice.

  For reduces that are not all-reduce, it falls back to
  tf.distribute.ReductionToOneDevice.
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError if num_packs is negative.
    """
    super(HierarchicalCopyAllReduce, self).__init__(
        all_reduce_alg="hierarchical_copy",
        num_packs=num_packs)
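The "reduce up a hierarchy, then broadcast back along the same path" idea can be sketched with a two-level hierarchy. This is a simplified, hypothetical model, not the DGX-1-specific topology that aggregate_gradients_using_hierarchical_copy assumes: GPUs are split into equal groups, each group reduces onto its first GPU, the group leaders reduce onto GPU 0, and the result is copied back along the recorded edges in reverse order.

```python
def hierarchical_all_reduce(values, group_size):
    """values[i] is the gradient (a number) on GPU i.

    Returns (per-GPU results, list of (src, dst) copy edges used on the way up).
    """
    n = len(values)
    if n % group_size:
        raise ValueError("number of GPUs must be divisible by group_size")
    groups = [list(range(s, s + group_size)) for s in range(0, n, group_size)]
    edges = []
    # Level 1: reduce inside each group onto its first GPU (the group leader).
    partial = {}
    for g in groups:
        leader = g[0]
        partial[leader] = values[leader]
        for gpu in g[1:]:
            partial[leader] += values[gpu]
            edges.append((gpu, leader))
    # Level 2: reduce the group leaders onto the root (first leader).
    root = groups[0][0]
    total = partial[root]
    for g in groups[1:]:
        total += partial[g[0]]
        edges.append((g[0], root))
    # Broadcast: walk the same edges in reverse, copying dst's value to src.
    result = [None] * n
    result[root] = total
    for src, dst in reversed(edges):
        result[src] = result[dst]
    return result, edges
```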

3.2.6 CollectiveAllReduce

CollectiveAllReduce performs the all-reduce with collective ops; this is TensorFlow's own collective implementation.
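Before reading the class, it may help to see its data flow in miniature. The sketch below is hypothetical plain Python, not TensorFlow code: group_by_size is a greedy stand-in for cross_device_utils.group_by_size (starting a new pack once the current one reaches bytes_per_pack is an assumption about the real helper's policy), and all_reduce_per_replica_values mimics how _all_reduce_per_replica_values transposes per-replica values into per-device lists, lets each device's all-reduce sum the matching entries (dividing by the group size for MEAN), and regroups the outputs with zip(*outputs_by_device). In _all_reduce the dense values are also reversed before packing, so gradients computed last are sent first.

```python
def group_by_size(sizes, bytes_per_pack):
    """Greedily group value indices into packs of roughly bytes_per_pack bytes.

    sizes[i] is the byte size of value i; bytes_per_pack == 0 means one pack.
    """
    if bytes_per_pack == 0:
        return [list(range(len(sizes)))]
    packs, current = [], 0
    for i, size in enumerate(sizes):
        # Start a new pack once the current one has reached the target size.
        if not packs or current >= bytes_per_pack:
            packs.append([])
            current = 0
        packs[-1].append(i)
        current += size
    return packs


def all_reduce_per_replica_values(per_replica_values, reduce_op="SUM"):
    """per_replica_values[k][d] is the k-th value as seen on device d."""
    num_devices = len(per_replica_values[0])
    group_size = num_devices  # single worker: group size == local device count
    # Transpose into one list per device, as the real method does.
    values_by_device = [[pr[d] for pr in per_replica_values]
                        for d in range(num_devices)]

    def all_reduce_on_device(d):
        # Stand-in for the collective launched on device d: entry k is summed
        # with entry k of every other device.
        out = []
        for k in range(len(values_by_device[d])):
            total = sum(values_by_device[e][k] for e in range(num_devices))
            out.append(total / group_size if reduce_op == "MEAN" else total)
        return out

    outputs_by_device = [all_reduce_on_device(d) for d in range(num_devices)]
    # zip(*...) regroups results into one "Mirrored" tuple per value.
    return [tuple(values) for values in zip(*outputs_by_device)]
```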

class CollectiveAllReduce(CrossDeviceOps):
  """All-reduce cross device ops using collective ops.

  In the between-graph replicated training, it will still do all-reduces across
  all workers and then put results on the right destinations.
  """

  def __init__(self,
               devices,
               group_size,
               collective_keys=None,
               canonicalize_devices=True):
    """Initializes the object.

    Args:
      devices: a list of device strings to run collectives on.
      group_size: the global group size. For between-graph replicated training
        it's the total number of devices across all workers.
      collective_keys: an optional CollectiveKey object.
      canonicalize_devices: Whether to canonicalize devices for workers or not.
    """
    if group_size % len(devices) > 0:
      raise ValueError("group_size must be divisible by the number of devices.")

    self._group_size = group_size
    self._collective_keys = (collective_keys or
                             cross_device_utils.CollectiveKeys())
    # This lock guards all collective launches, i.e. calls to
    # cross_device_utils.build_collectve_*.
    #
    # In a multi threaded eager program we need to ensure different groups of
    # collectives don't interleave each other, otherwise there could be
    # deadlocks. E.g. if two user threads both are launching collectives:
    #   user-thread-0  device0                 device1
    #   user-thread-1          device0 device1
    # In eager mode, we use one thread per device to launch collective ops, so
    # the above launch sequences end up with the following queues:
    #   device-0  collective-0  collective-1
    #   device-1  collective-1  collective-0
    # This deadlocks since neither collective is able to finish.
    self._lock = threading.Lock()

    if canonicalize_devices:
      self._devices = tuple(device_util.canonicalize(d) for d in devices)
    else:
      self._devices = tuple(
          device_util.canonicalize_without_job_and_task(d) for d in devices)
    group_key = self._collective_keys.get_group_key(self._devices)
    self._launchers = []
    # Whether to only use NCCL for batched all-reduce when NCCL is requested.
    # This is because of the lack of mechanism to order NCCL operations
    # deterministically.
    self._limited_nccl = False
    for device in self._devices:
      launcher = cross_device_utils.CollectiveReplicaLauncher(
          group_key, group_size, self._collective_keys, device)
      self._launchers.append(launcher)
      if not launcher.can_order_nccl():
        self._limited_nccl = True

    self._pool = multiprocessing.pool.ThreadPool(len(self._devices))

    super(CollectiveAllReduce, self).__init__()
    self._canonicalize_devices = canonicalize_devices

  @property
  def _num_between_graph_workers(self):
    # Currently we only support equal number of devices on each worker.
    return self._group_size / len(self._devices)

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """Implements CrossDeviceOps.all_reduce."""
    flat_values = nest.flatten(value)

    implementation = options.implementation.value
    # If NCCL launches can't be ordered (self._limited_nccl == True), we only
    # use NCCL when batch_size > 1, hoping that there's only one batched
    # all-reduce, which is the gradient aggregation in optimizer. For TF 2.x,
    # NCCL launches are always ordered.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        len(flat_values) == 1):
      implementation = CommunicationImplementation.AUTO.value

    launcher = self._launchers[replica_id]
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(flat_values))
    dense_results = []
    sparse_results = []

    if dense_values:
      # Reverse the lists so that there's better chance that values follows
      # the order in which they are calculated (e.g. when they're gradients), so
      # as to overlap calculation with communication. However, this may not be
      # optimal for cases like gradients of complicated non-sequential models.
      #
      # Note that we reverse the list before packing so that the first pack
      # won't be too small, since it's more likely for first few packs to have
      # long queuing time due to concurrent intense computation.
      #
      # TODO(b/147393503): explore solutions for optimal ordering.
      dense_values.reverse()
      packs = cross_device_utils.group_by_size(dense_values,
                                               options.bytes_per_pack)

      dense_results = launcher.batch_all_reduce(packs, implementation,
                                                options.timeout_seconds)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(dense_results):
          with ops.device(self._devices[replica_id]):
            dense_results[i] = v / self._group_size
      dense_results.reverse()

    if sparse_values:

      for indexed_slice in sparse_values:
        sparse_results.append(
            launcher.all_reduce_indexed_slices(indexed_slice, implementation,
                                               options.timeout_seconds))

      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(sparse_results):
          with ops.device(self._devices[replica_id]):
            sparse_results[i] = ops.IndexedSlices(
                values=sparse_results[i].values / self._group_size,
                indices=sparse_results[i].indices,
                dense_shape=sparse_results[i].dense_shape)

    flat_results = cross_device_utils.stitch_values(
        ((dense_results, dense_indices), (sparse_results, sparse_indices)))
    return nest.pack_sequence_as(value, flat_results)

  def _all_reduce_per_replica_values(self, reduce_op, per_replica_values,
                                     options):
    """All reduce a list of per_replica_value."""
    values_by_device = [[] for _ in self._devices]
    num_devices = len(self._devices)
    for per_replica in per_replica_values:
      for i in range(num_devices):
        values_by_device[i].append(per_replica.values[i])

    if context.executing_eagerly():

      def thread_fn(device_id):
        with context.eager_mode():
          return self._all_reduce(reduce_op, values_by_device[device_id],
                                  device_id, options)

      with self._lock:
        outputs_by_device = self._pool.map(thread_fn, list(range(num_devices)))
    else:
      outputs_by_device = []
      with self._lock:
        for i in range(num_devices):
          outputs_by_device.append(
              self._all_reduce(reduce_op, values_by_device[i], i, options))

    result = []
    for values in zip(*outputs_by_device):
      result.append(
          distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
    return result

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    values_util.mark_as_unsaveable()
    all_reduced = self._all_reduce_per_replica_values(reduce_op,
                                                      [per_replica_value],
                                                      options)[0]
    devices = get_devices_from(destinations, self._canonicalize_devices)

    if _devices_match(per_replica_value, destinations,
                      self._canonicalize_devices):
      return all_reduced

    # Convert all_reduced to a Mirrored object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_reduced, value_lib.Mirrored):
      all_reduced = value_lib.Mirrored([all_reduced])

    # If we got this far, the destination devices do not match the all-reduce
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_reduced.values):
      for d in devices:
        with ops.device(d):
          for v in all_reduced.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    values_util.mark_as_unsaveable()
    all_devices_match = _all_devices_match(value_destination_pairs,
                                           self._canonicalize_devices)
    if all_devices_match:
      return self._all_reduce_per_replica_values(
          reduce_op, [v[0] for v in value_destination_pairs], options)
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
    values_util.mark_as_unsaveable()
    devices = get_devices_from(destinations, self._canonicalize_devices)

    if _devices_match(per_replica_value, destinations,
                      self._canonicalize_devices):
      return all_gathered

    # Convert all_gathered to a Mirrored object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_gathered, value_lib.Mirrored):
      all_gathered = value_lib.Mirrored([all_gathered])

    # If we got this far, the destination devices do not match the all-gather
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_gathered.values):
      for d in devices:
        with ops.device(d):
          for v in all_gathered.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            index.append(array_ops.identity(all_gathered._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def _batch_all_gather(self, per_replica_values, axis, options):
    """all gather multiple per-replica-values."""
    batch_size = len(per_replica_values)
    # Pass options.implementation to the runtime as a communication
    # implementation hint.
    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    def compute_gathered_values():
      gathered_values = []
      with self._lock, ops.name_scope("allgather"):
        for per_replica in per_replica_values:
          outputs = []
          for i in range(len(self._devices)):
            outputs.append(self._launchers[i].all_gather(
                per_replica.values[i], axis, implementation,
                options.timeout_seconds))
          gathered_values.append(outputs)
      return gathered_values

    if context.executing_eagerly():
      gathered_values = def_function.function(compute_gathered_values)()
    else:
      gathered_values = compute_gathered_values()

    mirrored = []
    for value in gathered_values:
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored
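The device-mapping step shared by _reduce_implementation and _gather_implementation can be modeled in plain Python. This is a minimal sketch under a few assumptions:

- devices are plain strings and values are ordinary Python objects instead of tensors;
- the device-match test is a simple set comparison, as in _devices_match;
- the first component plays the role of the Mirrored primary.

The function name map_to_destinations is hypothetical. It reproduces the for-else fallback to the primary component, but it does not create identity ops or control dependencies.

```python
def map_to_destinations(components, destinations):
    """Hypothetical model of mapping all-reduced/all-gathered components to destinations.

    components: (device, value) pairs, one per all-reduce device; the first
        pair stands in for the Mirrored primary component.
    destinations: destination device strings.
    Returns one value per destination device.
    """
    if {d for d, _ in components} == set(destinations):
        # Same device set: the components can be returned unchanged.
        return [v for _, v in components]
    primary = components[0][1]
    index = []
    for d in destinations:
        for device, value in components:
            if device == d:
                index.append(value)  # reuse the component already on d
                break
        else:
            # for-else: only reached when no component lives on d,
            # so the primary component is used for that device.
            index.append(primary)
    return index


comps = [("GPU:0", "a"), ("GPU:1", "b")]
print(map_to_destinations(comps, ["GPU:1", "CPU:0"]))
```

For the destinations ["GPU:1", "CPU:0"], GPU:1 reuses its own component and CPU:0, which has none, receives the primary.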

The overall logic is shown below. There are many implementations to choose from, and which one fits depends on the specific situation:

Figure 2: Cross-device operations
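Two details of _batch_all_gather can be restated in plain Python:

- The implementation hint is downgraded from NCCL to AUTO when only one value is gathered.
- An all-gather leaves every replica holding the concatenation of all replicas' inputs along the gather axis.

This sketch makes three assumptions. Impl is a hypothetical enum standing in for CommunicationImplementation. Python lists stand in for tensors. Only axis 0 is modeled.

```python
from enum import Enum


class Impl(Enum):
    """Hypothetical stand-in for CommunicationImplementation."""
    AUTO = "AUTO"
    RING = "RING"
    NCCL = "NCCL"


def effective_implementation(requested, batch_size):
    """Same rule as _batch_all_gather: use NCCL only when batch_size > 1."""
    if requested == Impl.NCCL and batch_size == 1:
        return Impl.AUTO.value
    return requested.value


def all_gather(per_replica_values):
    """Every replica receives all replicas' lists concatenated in replica order."""
    gathered = [x for values in per_replica_values for x in values]
    return [list(gathered) for _ in per_replica_values]


print(effective_implementation(Impl.NCCL, 1), all_gather([[1, 2], [3]]))
```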

3.3 Single-node initialization

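Here we still focus on single-node initialization. The code is shown below; it mainly does two things:

  • initializes a single worker;
  • creates the collective ops through _make_collective_ops.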
if _is_device_list_single_worker(devices):
  self._initialize_single_worker(devices)
  self._collective_ops = self._make_collective_ops(devices)  # create the collective ops
  if self._prefer_collective_ops and (
      isinstance(self._cross_device_ops, cross_device_ops_lib.NcclAllReduce)
      or isinstance(self._inferred_cross_device_ops,
                    cross_device_ops_lib.NcclAllReduce)):
    self._collective_ops_in_use = True
    self._inferred_cross_device_ops = None

3.3.1 Initializing a single worker

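The key logic here is:

  • First, get the devices used for this training run, _devices, which look roughly like ('/replica:0/task:0/device:GPU:0', '/replica:0/task:0/device:GPU:1').
  • Then get the input-side devices, _input_workers_devices, which look roughly like (('/replica:0/task:0/device:CPU:0', ('/replica:0/task:0/device:GPU:0', '/replica:0/task:0/device:GPU:1')),). This is a single (host CPU, compute devices) pair, and it is later used to create InputWorkers.
  • Get _inferred_cross_device_ops, the cross-device ops inferred from the available conditions. They are only inferred when the user did not pass cross_device_ops.
  • Get _default_device, the default device, from the DeviceSpec of the input device. It is set only when the spec names a job other than "localhost", so that entering strategy.scope() uses the correct default device. A DeviceSpec describes where state is stored and where computation happens. With DeviceSpec you can parse a device specification string to validate it, and then merge specifications or compose them programmatically.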
def _initialize_single_worker(self, devices):
  """Initializes the object for single-worker training."""
  self._devices = tuple(device_util.canonicalize(d) for d in devices)
  self._input_workers_devices = (
      (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)

  self._inferred_cross_device_ops = None if self._cross_device_ops else (
      cross_device_ops_lib.select_cross_device_ops(devices))  # infer the cross-device ops
  
  self._host_input_device = numpy_dataset.SingleDevice(
      self._input_workers_devices[0][0])
  
  self._is_multi_worker_training = False
  
  device_spec = tf_device.DeviceSpec.from_string(
      self._input_workers_devices[0][0])
  # Ensures when we enter strategy.scope() we use the correct default device
  if device_spec.job is not None and device_spec.job != "localhost":
    self._default_device = "/job:%s/replica:%d/task:%d" % (
        device_spec.job, device_spec.replica, device_spec.task)

How exactly are the cross-device ops inferred? This is done by select_cross_device_ops.

def select_cross_device_ops(devices, session_config=None):
  """Find the best CrossDeviceOps locally given a tf.compat.v1.ConfigProto.

  Args:
    devices: a list of devices passed to tf.distribute.Strategy.
    session_config: a tf.compat.v1.ConfigProto or None. If None, it will
      make decision based on all logical devices.

  Returns:
    A subclass of CrossDeviceOps.
  """
  requested_devices = set(device_util.canonicalize(d) for d in devices)
  if ops.executing_eagerly_outside_functions():
    logical_gpus = context.context().list_logical_devices(device_type="GPU")
    physical_gpus = context.context().list_physical_devices(device_type="GPU")
    if len(logical_gpus) != len(physical_gpus):
      return ReductionToOneDevice()
    machine_devices = context.context().list_logical_devices()
  else:
    machine_devices = device_lib.list_local_devices(
        session_config=session_config)
    
  using_devices = set()
  for d in machine_devices:
    if device_util.canonicalize(d.name) in requested_devices:
      using_devices.add(d.name)

  if any("gpu" not in d.lower() for d in requested_devices):
    return ReductionToOneDevice()

  if kernels.get_registered_kernels_for_op("NcclAllReduce"):
    return NcclAllReduce(num_packs=1)
  else:
    return ReductionToOneDevice()
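The decision order of the eager branch of select_cross_device_ops, as listed above, can be restated as a pure function. This is a sketch with assumptions:

- GPU counts are passed in as plain integers;
- kernel registration is a boolean flag;
- class names are returned as strings;
- the function name is hypothetical.

It only encodes the branches shown in the listing.

```python
def select_cross_device_ops_sketch(requested_devices, num_logical_gpus,
                                   num_physical_gpus, nccl_kernel_registered):
    """Hypothetical restatement of the eager branch of select_cross_device_ops."""
    if num_logical_gpus != num_physical_gpus:
        # Logical and physical GPU counts differ (e.g. virtual GPUs).
        return "ReductionToOneDevice"
    if any("gpu" not in d.lower() for d in requested_devices):
        # At least one requested device is not a GPU.
        return "ReductionToOneDevice"
    if nccl_kernel_registered:
        return "NcclAllReduce(num_packs=1)"
    return "ReductionToOneDevice"


gpus = ["/job:localhost/replica:0/task:0/device:GPU:0",
        "/job:localhost/replica:0/task:0/device:GPU:1"]
print(select_cross_device_ops_sketch(gpus, 2, 2, True))
```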

3.3.2 Creating the collective ops

The _make_collective_ops method builds the collective ops.

def _make_collective_ops(self, devices):
  self._collective_keys = cross_device_utils.CollectiveKeys(
      group_key_start=1 + self._collective_key_base)  
  return cross_device_ops_lib.CollectiveAllReduce(
      devices=self._devices,
      group_size=len(self._devices),
      collective_keys=self._collective_keys)

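CollectiveAllReduce performs cross-device all-reduce using collective communication.

3.3.3 Telling them apart

There are now three member variables related to collective communication, so let's sort them out.

  • self._collective_ops: the collective ops; in practice this is a CollectiveAllReduce.
  • self._inferred_cross_device_ops: the cross-device ops inferred from the device configuration; in practice either ReductionToOneDevice or NcclAllReduce (and None when the user passed cross_device_ops).
  • self._cross_device_ops: the configuration argument passed in. To override cross-device communication, users can supply an instance of tf.distribute.CrossDeviceOps through the cross_device_ops argument, for example: mirrored_strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()). Besides the default tf.distribute.NcclAllReduce, the current options are tf.distribute.HierarchicalCopyAllReduce and tf.distribute.ReductionToOneDevice.

Let's see how these three members are used. For example, _batch_reduce_to calls _get_cross_device_ops.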

def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
  cross_device_ops = None
  for value, _ in value_destination_pairs:
    if cross_device_ops is None:
      cross_device_ops = self._get_cross_device_ops(value)  # used here
    elif cross_device_ops is not self._get_cross_device_ops(value):
      raise ValueError("inputs to batch_reduce_to must be either all on "
                       "the host or all on the compute devices")
  return cross_device_ops.batch_reduce(
      reduce_op,
      value_destination_pairs,
      options=self._communication_options.merge(options))
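The consistency check in _batch_reduce_to compares the chosen CrossDeviceOps objects by identity (is not), not by equality. A minimal sketch, assuming a hypothetical pick_common_ops helper and a caller-supplied get_ops function standing in for _get_cross_device_ops:

```python
def pick_common_ops(values, get_ops):
    """Hypothetical sketch: every value must resolve to the very same ops object."""
    chosen = None
    for value in values:
        ops = get_ops(value)
        if chosen is None:
            chosen = ops
        elif chosen is not ops:
            raise ValueError("inputs to batch_reduce_to must be either all on "
                             "the host or all on the compute devices")
    return chosen


nccl, reduction_to_one = object(), object()
print(pick_common_ops([1.0, 2.0], lambda v: nccl) is nccl)
```

Because the check is by identity, two separately constructed ops objects count as different, even if they are of the same class.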

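_get_cross_device_ops chooses which collective implementation to use based on the configuration and the current situation. It relies on _use_merge_call, which is listed first.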

def _use_merge_call(self):
  # We currently only disable merge_call when XLA is used to compile the fn
  # passed to strategy.run and all devices are GPU.
  return not control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph()) or not all(
            [_is_gpu_device(d) for d in self._devices])

def _get_cross_device_ops(self, value):
  if not self._use_merge_call():
    return self._collective_ops

  # _collective_ops_in_use is set when _prefer_collective_ops is set and either
  # _cross_device_ops or _inferred_cross_device_ops is an NcclAllReduce.
  if self._collective_ops_in_use:
    if isinstance(value, values.DistributedValues):
      value_int32 = True in {
          dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values
      }
    else:
      value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32
    if value_int32:
      return cross_device_ops_lib.ReductionToOneDevice()
    else:
      return self._collective_ops

  return self._cross_device_ops or self._inferred_cross_device_ops

The concrete logic is as follows:
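The decision in _use_merge_call and _get_cross_device_ops can be restated as plain functions. This is a sketch that makes three simplifications:

- dtypes are passed as name strings;
- ops objects are represented by their class names;
- a device counts as a GPU when its name contains "gpu".

Both function names are hypothetical.

```python
def use_merge_call_sketch(in_xla_context, devices):
    """merge_call is disabled only when compiled with XLA and all devices are GPUs."""
    return not in_xla_context or not all("gpu" in d.lower() for d in devices)


def get_cross_device_ops_sketch(use_merge_call, collective_ops_in_use, value_dtypes,
                                cross_device_ops, inferred_cross_device_ops,
                                collective_ops="CollectiveAllReduce"):
    """Hypothetical restatement of MirroredExtended._get_cross_device_ops."""
    if not use_merge_call:
        return collective_ops
    if collective_ops_in_use:
        # int32 values are routed to ReductionToOneDevice instead of collectives.
        if "int32" in value_dtypes:
            return "ReductionToOneDevice"
        return collective_ops
    # A user-supplied cross_device_ops takes precedence over the inferred one.
    return cross_device_ops or inferred_cross_device_ops


print(get_cross_device_ops_sketch(True, False, ["float32"], None, "NcclAllReduce"))
```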

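4. Updating distributed variables

Next we look at how distributed variables are updated. For reasons of space, the goal here is only to walk through the overall flow; interested readers can dig deeper on their own.

4.1 Example

A distributed variable is a variable created on multiple devices; mirrored variables and SyncOnRead variables are two examples. The following code operates on a distributed variable: it first calls reduce_to and then update.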

    >>> @tf.function
    ... def step_fn(var):
    ...
    ...   def merge_fn(strategy, value, var):
    ...     # All-reduce the value. Note that value here is a
    ...     # tf.distribute.DistributedValues.
    ...     reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM,
    ...         value, destinations=var)
    ...     strategy.extended.update(var, lambda var, value: var.assign(value),
    ...         args=(reduced,))
    ...
    ...   value = tf.identity(1.)
    ...   tf.distribute.get_replica_context().merge_call(merge_fn,
    ...     args=(value, var))
    >>>
    >>> def run(strategy):
    ...   with strategy.scope():
    ...     v = tf.Variable(0.)
    ...     strategy.run(step_fn, args=(v,))
    ...     return v
    >>>
    >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
    MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
    }
    >>> run(tf.distribute.experimental.CentralStorageStrategy(
    ...     compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
    >>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
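The printed values follow from simple arithmetic:

- each replica contributes tf.identity(1.);
- merge_fn SUM-reduces those contributions;
- update assigns the sum to every component of the variable.

The plain-Python model below uses a hypothetical simulate_step function. It assumes two replicas for MirroredStrategy and CentralStorageStrategy and one replica for OneDeviceStrategy. The mirrored variable has one component per GPU, while the other two strategies keep a single variable.

```python
def simulate_step(num_replicas, num_components):
    """Hypothetical model of step_fn: SUM of 1.0 per replica, assigned to every component."""
    per_replica = [1.0] * num_replicas   # value = tf.identity(1.) on each replica
    reduced = sum(per_replica)           # reduce_to(ReduceOp.SUM, value, destinations=var)
    return [reduced] * num_components    # update(var, assign, args=(reduced,))


print(simulate_step(2, 2), simulate_step(2, 1), simulate_step(1, 1))
```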

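4.2 Reduction

First, the reduce_to operation.

StrategyExtendedV2

The call first arrives at StrategyExtendedV2. reduce_to aggregates tf.distribute.DistributedValues and distributed variables, and supports both dense values and tf.IndexedSlices. This API can currently only be called in a cross-replica context. Other variants for reducing values across replicas are:

  • tf.distribute.StrategyExtended.batch_reduce_to: the batched version of this API.
  • tf.distribute.ReplicaContext.all_reduce: the counterpart of this API in replica context. It supports both batched and non-batched all-reduce.
  • tf.distribute.Strategy.reduce: a more convenient API for reducing to the host in cross-replica context.

destinations specifies where to reduce the value to, for example "GPU:0". You can also pass in a tensor, in which case the destination is that tensor's device. For an all-reduce, pass the same object as both value and destinations.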

def reduce_to(self, reduce_op, value, destinations, options=None):
  """Combine (via e.g. sum or mean) values across replicas.

  reduce_to aggregates tf.distribute.DistributedValues and distributed
  variables. It supports both dense values and tf.IndexedSlices.

  This API currently can only be called in cross-replica context. Other
  variants to reduce values across replicas are:
  * tf.distribute.StrategyExtended.batch_reduce_to: the batch version of
    this API.
  * tf.distribute.ReplicaContext.all_reduce: the counterpart of this API
    in replica context. It supports both batched and non-batched all-reduce.
  * tf.distribute.Strategy.reduce: a more convenient method to reduce
    to the host in cross-replica context.

  destinations specifies where to reduce the value to, e.g. "GPU:0". You can
  also pass in a Tensor, and the destinations will be the device of that
  tensor. For all-reduce, pass the same to value and destinations.

  It can be used in tf.distribute.ReplicaContext.merge_call to write code
  that works for all tf.distribute.Strategy.

  Args:
    reduce_op: a tf.distribute.ReduceOp value specifying how values should
      be combined. Allows using string representation of the enum such as
      "SUM", "MEAN".
    value: a tf.distribute.DistributedValues, or a tf.Tensor like object.
    destinations: a tf.distribute.DistributedValues, a tf.Variable, a
      tf.Tensor alike object, or a device string. It specifies the devices
      to reduce to. To perform an all-reduce, pass the same to value and
      destinations. Note that if it's a tf.Variable, the value is reduced
      to the devices of that variable, and this method doesn't update the
      variable.
    options: a tf.distribute.experimental.CommunicationOptions. Options to
      perform collective operations. This overrides the default options if the
      tf.distribute.Strategy takes one in the constructor. See
      tf.distribute.experimental.CommunicationOptions for details of the
      options.

  Returns:
    A tensor or value reduced to destinations.
  """
  if options is None:
    options = collective_util.Options()
  _require_cross_replica_or_default_context_extended(self)
  if isinstance(reduce_op, six.string_types):
    reduce_op = reduce_util.ReduceOp(reduce_op.upper())

  return self._reduce_to(reduce_op, value, destinations, options)

MirroredExtended

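There are several execution paths here:

  • A mirrored value reduced with MEAN is returned unchanged.
  • A value that is not a DistributedValues goes through reduce_non_distributed_value.
  • When collective ops are in use and the destinations do not match the value's devices (or include a CPU), ReductionToOneDevice is used.
  • Otherwise, _get_cross_device_ops supplies the cross-device ops that perform the reduction.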

def _reduce_to(self, reduce_op, value, destinations, options):
  if (distribute_utils.is_mirrored(value) and
      reduce_op == reduce_util.ReduceOp.MEAN):
    return value

  def get_values(value):
    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case value would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    if self._use_merge_call() and self._collective_ops_in_use and ((
        not cross_device_ops_lib._devices_match(value, destinations) or  
        any("cpu" in d.lower()
            for d in cross_device_ops_lib.get_devices_from(destinations)))):
      return cross_device_ops_lib.ReductionToOneDevice().reduce(
          reduce_op, value, destinations)
    
    return self._get_cross_device_ops(value).reduce(
        reduce_op,
        value,
        destinations=destinations,
        options=self._communication_options.merge(options))

  return nest.map_structure(get_values, value)
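The branches of _reduce_to can be modeled without TensorFlow. This is a simplified sketch with several assumptions:

- PerReplica and Mirrored are hypothetical list subclasses standing in for the real DistributedValues types.
- Reduce ops are the strings "SUM" and "MEAN".
- Destinations and nest.map_structure are ignored.
- reduce_non_distributed follows our understanding of cross_device_ops_lib.reduce_non_distributed_value only roughly. The real function broadcasts to the destinations in the single-replica case, and real reductions go through a CrossDeviceOps object.

```python
class PerReplica(list):
    """Hypothetical stand-in for a PerReplica value: one element per replica."""


class Mirrored(list):
    """Hypothetical stand-in for a Mirrored value: identical elements per replica."""


def reduce_non_distributed(reduce_op, value, num_replicas):
    """Rough model of reducing a value that is not a DistributedValues."""
    if value == 0:
        return 0                     # a plain 0 stays 0
    if reduce_op == "MEAN" or num_replicas == 1:
        return value                 # same value on every replica
    raise ValueError("A non-DistributedValues value cannot be reduced with SUM")


def reduce_to_sketch(reduce_op, value, num_replicas):
    if isinstance(value, Mirrored) and reduce_op == "MEAN":
        return value                 # mean of identical copies is the value itself
    if not isinstance(value, (PerReplica, Mirrored)):
        return reduce_non_distributed(reduce_op, value, num_replicas)
    total = sum(value)               # stands in for CrossDeviceOps.reduce
    return total if reduce_op == "SUM" else total / len(value)


print(reduce_to_sketch("SUM", PerReplica([1.0, 2.0]), 2))
```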

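4.3 Update

Next, the update operation.

StrategyExtendedV2

update runs fn to update var using inputs mirrored to the same devices. tf.distribute.StrategyExtended.update takes a distributed variable var to be updated, an update function fn, and args and kwargs for fn. It then applies fn to each component variable of var, passing the corresponding values from args and kwargs.

Neither args nor kwargs may contain per-replica values. If they contain mirrored values, those values are unwrapped before fn is called. For example, fn can be assign_add and args can be a mirrored DistributedValues in which each component holds the value to add to the mirrored variable var. Calling update then calls assign_add on each component variable of var with the tensor on that component's device.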

def update(self, var, fn, args=(), kwargs=None, group=True):
  """Run fn to update var using inputs mirrored to the same devices.

  tf.distribute.StrategyExtended.update takes a distributed variable var
  to be updated, an update function fn, and args and kwargs for fn. It
  applies fn to each component variable of var and passes corresponding
  values from args and kwargs. Neither args nor kwargs may contain
  per-replica values. If they contain mirrored values, they will be unwrapped
  before calling fn. For example, fn can be assign_add and args can be
  a mirrored DistributedValues where each component contains the value to be
  added to this mirrored variable var. Calling update will call
  assign_add on each component variable of var with the corresponding
  tensor value on that device.

  Example usage:

  ```python
  strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])  # With 2 devices
  with strategy.scope():
    v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
  def update_fn(v):
    return v.assign(1.0)
  result = strategy.extended.update(v, update_fn)
  # result is
  # Mirrored:{
  #  0: tf.Tensor(1.0, shape=(), dtype=float32),
  #  1: tf.Tensor(1.0, shape=(), dtype=float32)
  # }
  ```

  If var is mirrored across multiple devices, then this method implements
  logic as following:

  ```python
  results = {}
  for device, v in var:
    with tf.device(device):
      # args and kwargs will be unwrapped if they are mirrored.
      results[device] = fn(v, *args, **kwargs)
  return merged(results)
  ```

  Otherwise, this method returns fn(var, *args, **kwargs) colocated with
  var.

  Args:
    var: Variable, possibly mirrored to multiple devices, to operate on.
    fn: Function to call. Should take the variable as the first argument.
    args: Tuple or list. Additional positional arguments to pass to fn().
    kwargs: Dict with keyword arguments to pass to fn().
    group: Boolean. Defaults to True. If False, the return value will be
      unwrapped.

  Returns:
    By default, the merged return value of fn across all replicas.  The
    merged result has dependencies to make sure that if it is evaluated at
    all, the side effects (updates) will happen on every replica. If instead
    "group=False" is specified, this function will return a nest of lists
    where each list has an element per replica, and the caller is responsible
    for ensuring all elements are executed.
  """
  if kwargs is None:
    kwargs = {}
  replica_context = distribution_strategy_context.get_replica_context()

  if (replica_context is None or replica_context is
      distribution_strategy_context._get_default_replica_context()):
    fn = autograph.tf_convert(
        fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
    with self._container_strategy().scope():
      return self._update(var, fn, args, kwargs, group)  # dispatches to the subclass implementation
  else:
    return self._replica_ctx_update(
        var, fn, args=args, kwargs=kwargs, group=group)

MirroredExtended

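MirroredExtended iterates over the component variables of var. For each one, it calls fn on that component's device inside an UpdateContext and collects the result into a list. It then calls distribute_utils.update_regroup to regroup the results.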

def _update(self, var, fn, args, kwargs, group):

  assert isinstance(var, values.DistributedVariable)
  updates = []
  for i, v in enumerate(var.values):  # iterate over var's component variables
    name = "update_%d" % i
    with ops.device(v.device), \
         distribute_lib.UpdateContext(i), \
         ops.name_scope(name):
      # If args and kwargs are not mirrored, the value is returned as is.
      updates.append(
          fn(v, *distribute_utils.select_replica(i, args),
             **distribute_utils.select_replica(i, kwargs)))
  return distribute_utils.update_regroup(self, updates, group)
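The loop in MirroredExtended._update, together with the unwrapping of mirrored arguments, can be sketched in plain Python. This sketch leaves out devices, UpdateContext, and name scopes. Mirrored is a hypothetical tuple subclass. select_replica here is a simplified stand-in for distribute_utils.select_replica that only recurses into lists, tuples, and dicts.

```python
class Mirrored(tuple):
    """Hypothetical stand-in for a Mirrored value: one component per replica."""


def select_replica(i, structure):
    """Unwrap Mirrored values to their i-th component; leave other values as is."""
    if isinstance(structure, Mirrored):
        return structure[i]
    if isinstance(structure, dict):
        return {k: select_replica(i, v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(select_replica(i, v) for v in structure)
    return structure


def update_sketch(var_components, fn, args=(), kwargs=None):
    """Apply fn to every component, passing that replica's slice of args/kwargs."""
    kwargs = kwargs or {}
    updates = []
    for i, v in enumerate(var_components):
        updates.append(fn(v, *select_replica(i, args), **select_replica(i, kwargs)))
    return updates


# assign_add-like update with a mirrored argument holding a per-device delta
print(update_sketch([5.0, 5.0], lambda v, d: v + d, args=(Mirrored((1.0, 2.0)),)))
```

Plain (non-mirrored) arguments are passed unchanged to every component, which matches the comment in _update that such values are "returned as is".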

distribute_utils

distribute_utils.update_regroup performs the regroup. For reasons of space we won't go into it in depth here; interested readers can explore it on their own.

def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute."""
  if not group:
    regrouped = regroup(updates, values_lib.Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  

  def _make_grouped_mirrored(values):
    """Convert per-replica list values into Mirrored type with grouping."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be performing
    # the same computation.
    if not all(tensor_util.is_tf_type(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as values, but
    # that have a dependency on g.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        with_dep.append(array_ops.identity(v))

    return values_lib.Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)

The logic is as follows:

Figure: Update operation
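At its core, regroup (which update_regroup builds on) is a transpose. It turns a list with one nested structure per replica into a single nested structure whose leaves hold one value per replica. This is a simplified sketch. It assumes a hypothetical Mirrored tuple subclass and supports only tuples and dicts as containers. The real regroup handles more container types and special cases. The control dependencies that _make_grouped_mirrored adds (control_flow_ops.group plus identity ops) are not modeled.

```python
class Mirrored(tuple):
    """Hypothetical stand-in for values_lib.Mirrored."""


def regroup(per_replica, wrap=Mirrored):
    """Turn [structure_replica0, structure_replica1, ...] into one structure of wrapped leaves."""
    first = per_replica[0]
    if isinstance(first, tuple) and not isinstance(first, Mirrored):
        return tuple(regroup([r[k] for r in per_replica], wrap) for k in range(len(first)))
    if isinstance(first, dict):
        return {k: regroup([r[k] for r in per_replica], wrap) for k in first}
    return wrap(per_replica)  # leaf: one value per replica


print(regroup([(1, {"a": 2}), (3, {"a": 4})]))
```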
