[Source Code Analysis] TensorFlow Distributed: ParameterServerStrategy V1

By 羅西的思考, published 2022-05-08

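This article looks at ParameterServerStrategy, specifically the first version of the code (V1). It is worth analyzing because many companies in industry still use it, and its internal mechanics are relatively clear and easy to follow.

Two GitHub repositories that I recommend as very good learning material: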

https://github.com/yuhuiaws/ML-study

https://github.com/Jack47/hack-SysML

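I also recommend 西門宇少's latest article on TeraPipe, which parallelizes the pipeline at the token level for Transformer LMs.

Other articles in this series:

[Translation] TensorFlow Distributed, papers: "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed, papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1): Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2): Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3): Worker Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (4): WorkerCache

[Source Code Analysis] TensorFlow Distributed Environment (5): Session

[Source Code Analysis] TensorFlow Distributed Environment (7): Worker Dynamic Logic

[Source Code Analysis] TensorFlow Distributed Environment (8): Communication Mechanism

[Translation] Distributed Training with TensorFlow

[Source Code Analysis] TensorFlow Distributed: DistributedStrategy Basics

[Source Code Analysis] TensorFlow: Distributed Variables

[Source Code Analysis] TensorFlow Distributed: MirroredStrategy

[Source Code Analysis] TensorFlow Distributed: MirroredStrategy, Distributing Computation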

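1. Approach

Parameter server training is a common data-parallel method for scaling machine learning models across multiple machines. A parameter server training cluster consists of workers and parameter servers. Variables are created on the parameter servers, and workers read and update them at every step. By default, workers read and update these variables independently, without synchronizing with one another, which is why this configuration is called asynchronous training.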
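To make "asynchronous" concrete, here is a toy simulation in plain Python (no TensorFlow). Everything in it is hypothetical: the objective f(w) = (w - 3)^2, the function names, and the single scalar variable are illustration only. Each worker reads the shared variable, computes a gradient on the possibly stale value, and pushes its update without waiting for the others.

```python
# Toy simulation of asynchronous parameter-server training on f(w) = (w - 3)^2.
# All names are hypothetical; this is not how TensorFlow implements it.

def gradient(w):
    """Gradient of f(w) = (w - 3)^2."""
    return 2.0 * (w - 3.0)


def train_async(num_workers=2, rounds=50, lr=0.1):
    ps = {"w": 0.0}  # the variable lives on the "parameter server"
    for _ in range(rounds):
        # Every worker reads the variable before anyone writes back, so all
        # but the first update in a round are computed from a stale value.
        snapshots = [ps["w"] for _ in range(num_workers)]
        for w_seen in snapshots:
            # Each worker pushes its own update: no barrier, no averaging
            # with the gradients of the other workers.
            ps["w"] -= lr * gradient(w_seen)
    return ps["w"]


final_w = train_async()
```

Despite the stale reads, the toy problem still converges to the minimum at w = 3; real workloads trade some of this staleness for higher throughput.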

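TensorFlow supports two ways of implementing a parameter server: building a parameter server cluster with the low-level API, and using ParameterServerStrategy from tf.distribute.Strategy. The main job of ParameterServerStrategyV1 is to place variables on the ps tasks and computation on the workers. We will study it from several angles:

  • How it connects to the cluster.
  • How it obtains data.
  • How it creates variables.
  • How it runs.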

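1.1 Overall logic

ParameterServerStrategyV1 is an asynchronous, multi-worker parameter server tf.distribute strategy. It requires two roles: workers and parameter servers. Variables and the updates to them are assigned to parameter servers; all other operations are assigned to workers.

When each worker has more than one GPU, operations are replicated on all GPUs, but variables are not; every worker shares a common view of which parameter server each variable is assigned to. By default, ParameterServerStrategyV1 uses TFConfigClusterResolver to detect the multi-worker configuration, which requires a 'TF_CONFIG' environment variable, and 'TF_CONFIG' must contain a cluster spec.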
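As an illustration of what such a 'TF_CONFIG' might contain, the sketch below sets the variable and reads it back with the standard json module. The host names, ports, and task index are made up; only the overall shape (a "cluster" dictionary plus a "task" entry with "type" and "index") is what the resolver expects.

```python
import json
import os

# Hypothetical hosts and ports, for illustration only.
tf_config = {
    "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
    },
    # This process is the second worker (index 1).
    "task": {"type": "worker", "index": 1},
}
os.environ["TF_CONFIG"] = json.dumps(tf_config)

# Reading it back yields the pieces that the strategy needs later on.
parsed = json.loads(os.environ["TF_CONFIG"])
cluster_spec = parsed["cluster"]
task_type = parsed["task"]["type"]
task_id = parsed["task"]["index"]
```

Every process in the cluster gets the same "cluster" section and its own "task" section.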

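This class assumes that every worker independently runs the same code, while the parameter servers run a standard server. This means that although each worker synchronously computes a single gradient update across all of its GPUs, updates between workers proceed asynchronously. Even with only a CPU or a single GPU, you should call "call_for_each_replica(fn, ...)" for any operation that could potentially be replicated across replicas (i.e. multiple GPUs). When defining "fn", keep the following in mind:

  1. It is generally not recommended to open a device scope inside the strategy's scope. A device scope (i.e. a call to tf.device) will be merged with, or will override, the device for operations, but it does not change the device for variables.
  2. It is also not recommended to open a colocation scope (i.e. a call to tf.compat.v1.colocate_with) inside the strategy's scope. To colocate variables, use strategy.extended.colocate_vars_with instead. Colocating ops may create device assignment conflicts.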

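Note: this strategy works only with the Estimator API. When you create a "RunConfig", pass an instance of this strategy to it (the TensorFlow docstring names the "experimental_distribute" argument; "RunConfig" also accepts the strategy directly as "train_distribute"). That "RunConfig" instance should then be passed to an "Estimator" instance, on which you call "train_and_evaluate".

1.2 Usage

A usage example of ParameterServerStrategy: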

  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(
      train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator,...)

1.3 Definition

There is little to study in the definition and initialization of ParameterServerStrategyV1: it mainly delegates initialization to ParameterServerStrategyExtended. An excerpt:

@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional cluster_resolver.

    Args:
      cluster_resolver: Optional
        tf.distribute.cluster_resolver.ClusterResolver object. Defaults to a
        tf.distribute.cluster_resolver.TFConfigClusterResolver.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyV1, self).__init__(
        ParameterServerStrategyExtended(
            self, cluster_resolver=cluster_resolver))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
      "ParameterServerStrategy")    

2. ParameterServerStrategyExtended

ParameterServerStrategyExtended derives from distribute_lib.StrategyExtendedV1 and provides the additional APIs for distribution-aware algorithms.

class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

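2.1 Initialization

This part obtains the cluster information. _initialize_strategy chooses between local and multi-worker initialization depending on the cluster spec; we only look at the multi-worker case.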

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

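_initialize_multi_worker performs a series of configuration steps:

  • Get the number of GPUs.

  • Read information from the cluster configuration.

  • Set the worker device and input device names.

  • Build the list of compute devices.

  • Set the device assignment policy.

  • Build the list of parameter server devices.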

  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.

    Args:
      cluster_resolver: a descendant of ClusterResolver object.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
    # Get the number of GPUs
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus

    # Read information from the cluster configuration
    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    # Set the worker device and the input host device
    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    # Define compute devices which is a list of device strings and one for each
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on CPU.
    
    # Build the list of compute devices
    if num_gpus > 0:
      compute_devices = tuple(
        "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)

    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from replica_device_setter are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.

    # Device assignment policy: decides which device each variable goes on
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,  # number of parameter server tasks
        worker_device=self._worker_device,  # worker device
        merge_devices=True,
        cluster=cluster_spec)

    # The _parameter_devices is needed for the parameter_devices property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.

    # Build the list of parameter server devices
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = self._worker_device
    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
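The device strings produced by this method can be reproduced in plain Python. The sketch below is a simplification under stated assumptions: build_devices is a hypothetical helper, it takes the cluster spec as an ordinary dictionary, and it skips the device_util.canonicalize step that the real code applies to the compute devices.

```python
def build_devices(cluster_spec, task_type, task_id, num_gpus):
    """Hypothetical helper mirroring the string formatting shown above."""
    worker_device = "/job:%s/task:%d" % (task_type, task_id)

    # One compute device per replica: every local GPU, or the worker itself
    # (CPU) when there are no GPUs.
    if num_gpus > 0:
        compute_devices = tuple(
            "%s/device:GPU:%d" % (worker_device, i) for i in range(num_gpus))
    else:
        compute_devices = (worker_device,)

    # One parameter device per task of the "ps" job.
    num_ps_replicas = len(cluster_spec.get("ps", []))
    parameter_devices = tuple(
        "/job:ps/task:{}".format(i) for i in range(num_ps_replicas))
    return worker_device, compute_devices, parameter_devices


cluster = {"ps": ["ps0:2222", "ps1:2222"],
           "worker": ["worker0:2222", "worker1:2222"]}
worker, compute, params = build_devices(cluster, "worker", 1, num_gpus=2)
```

Here worker is "/job:worker/task:1", compute holds one GPU device string per local GPU, and params lists both ps tasks.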

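2.2 Device assignment

Next, let us see how devices are assigned. At this stage, assigning a device just means attaching a device name to the ops in the computation graph; when the graph is actually run later, the system performs the concrete placement based on that name.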

2.2.1 replica_device_setter

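replica_device_setter returns a device function (a placement policy, in effect). When the graph for a replica is being built, this function supplies the information that decides which device each op should be assigned to. Device functions are used together with with tf.device(device_function). As Operations are constructed, they are automatically mapped to the device the function returns. Device constraints are added from the innermost context first, working outwards. If 'cluster' is 'None' and 'ps_tasks' is 0, the returned function is a no-op (the code returns None); otherwise the value of 'ps_tasks' is derived from 'cluster'. If 'ps_tasks' is non-zero, variables are subsequently placed on ps_device; otherwise they are placed on worker_device.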

@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0,
                          ps_device="/job:ps",
                          worker_device="/job:worker",
                          merge_devices=True,
                          cluster=None,
                          ps_ops=None,
                          ps_strategy=None):
  """Return a device function to use when building a Graph for replicas.

  Device Functions are used in with tf.device(device_function): statement to
  automatically assign devices to Operation objects as they are constructed,
  Device constraints are added from the inner-most context first, working
  outwards. The merging behavior adds constraints to fields that are yet unset
  by a more inner context. Currently the fields are (job, task, cpu/gpu).

  If cluster is None, and ps_tasks is 0, the returned function is a no-op.
  Otherwise, the value of ps_tasks is derived from cluster.

  Args:
    ps_tasks: Number of tasks in the ps job.  Ignored if cluster is
      provided.
    ps_device: String.  Device of the ps job.  If empty no ps job is used.
      Defaults to ps.
    worker_device: String.  Device of the worker job.  If empty no worker
      job is used.
    merge_devices: Boolean. If True, merges or only sets a device if the
      device constraint is completely unset. merges device specification rather
      than overriding them.
    cluster: ClusterDef proto or ClusterSpec.
    ps_ops: List of strings representing Operation types that need to be
      placed on ps devices.  If None, defaults to STANDARD_PS_OPS.
    ps_strategy: A callable invoked for every ps Operation (i.e. matched by
      ps_ops), that takes the Operation and returns the ps task index to
      use.  If None, defaults to a round-robin strategy across all ps
      devices.

  Returns:
    A function to pass to tf.device().

  Raises:
    TypeError if cluster is not a dictionary or ClusterDef protocol buffer,
    or if ps_strategy is provided but not a callable.
  """
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_dict()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_dict()
    # Get ps_job_name from ps_device by stripping "/job:".
    ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None

  if ps_ops is None:
    ps_ops = list(STANDARD_PS_OPS)

  if ps_strategy is None:
    ps_strategy = _RoundRobinStrategy(ps_tasks)

  chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
                                  merge_devices, ps_ops, ps_strategy)
  return chooser.device_function

2.2.2 _RoundRobinStrategy

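By default, only variable ops are placed on ps tasks, and the placement strategy distributes them across the ps tasks in round-robin order. Other strategies can be used instead, for example tf.contrib.training.GreedyLoadBalancingStrategy.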

# To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker
# jobs on hosts worker0, worker1 and worker2.
cluster_spec = {
  "ps": ["ps0:2222","ps1:2222"],
  "worker": ["worker0:2222","worker1:2222","worker2:2222"]}
with tf.device(tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
  # Build your graph
  v1 = tf.Variable(...)  # assigned to /job:ps/task:0
  v2 = tf.Variable(...)  # assigned to /job:ps/task:1
  v3 = tf.Variable(...)  # assigned to /job:ps/task:0
# Run compute

_RoundRobinStrategy is defined as follows:

class _RoundRobinStrategy(object):
  """Returns the next ps task index for placement in round-robin order.

  This class is not to be used directly by users.  See instead
  replica_device_setter() below.
  """

  def __init__(self, num_tasks):
    """Create a new _RoundRobinStrategy.

    Args:
      num_tasks: Number of ps tasks to cycle among.
    """
    self._num_tasks = num_tasks
    self._next_task = 0

  def __call__(self, unused_op):
    """Choose a ps task index for the given Operation.

    Args:
      unused_op: An Operation to be placed on ps.

    Returns:
      The next ps task index to use for the Operation. Returns the next
      index, in the range [offset, offset + num_tasks).
    """
    task = self._next_task
    self._next_task = (self._next_task + 1) % self._num_tasks
    return task
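The round-robin behaviour is easy to check in isolation. The following standalone sketch re-creates the same counter logic in a hypothetical RoundRobin class (it is not imported from TensorFlow) and assigns five variable names to two ps tasks.

```python
class RoundRobin:
    """Hypothetical stand-in with the same counter logic as _RoundRobinStrategy."""

    def __init__(self, num_tasks):
        self._num_tasks = num_tasks
        self._next_task = 0

    def __call__(self, unused_op):
        # Hand out the current index, then advance and wrap around.
        task = self._next_task
        self._next_task = (self._next_task + 1) % self._num_tasks
        return task


strategy = RoundRobin(num_tasks=2)
placements = ["/job:ps/task:%d" % strategy(name)
              for name in ["v1", "v2", "v3", "v4", "v5"]]
```

The placements alternate between task 0 and task 1, matching the v1, v2, v3 comments in the cluster example earlier in this section.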

2.2.3 _ReplicaDeviceChooser

replica_device_setter returns _ReplicaDeviceChooser.device_function, which uses _ps_strategy to produce the device name. Based on _ps_tasks, it decides whether a variable is placed on ps_device or on worker_device.

class _ReplicaDeviceChooser(object):
  """Class to choose devices for Ops in a replicated training setup.

  This class is not to be used directly by users.  See instead
  replica_device_setter() below.
  """

  def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops,
               ps_strategy):
    """Create a new _ReplicaDeviceChooser.

    Args:
      ps_tasks: Number of tasks in the ps job.
      ps_device: String.  Name of the ps job.
      worker_device: String.  Name of the worker job.
      merge_devices: Boolean. Set to True to allow merging of device specs.
      ps_ops: List of strings representing Operation types that need to be
        placed on ps devices.
      ps_strategy: A callable invoked for every ps Operation (i.e. matched by
        ps_ops), that takes the Operation and returns the ps task index to
        use.
    """
    self._ps_tasks = ps_tasks
    self._ps_device = ps_device
    self._worker_device = worker_device
    self._merge_devices = merge_devices
    self._ps_ops = ps_ops
    self._ps_strategy = ps_strategy

  def device_function(self, op):
    """Choose a device for op.

    Args:
      op: an Operation.

    Returns:
      The device to use for the Operation.
    """
    # If we don't return early here, either merge_devices is True, or op.device
    # is empty (in which case merging is a no-op). So we can always merge below.
    if not self._merge_devices and op.device:
      return op.device

    current_device = pydev.DeviceSpec.from_string(op.device or "")

    # The ps_device will be used for specified ops (ps_ops) whenever it is
    # present and ps_tasks is non-zero. However, its task number will only be
    # set (using ps_strategy) if there is a job field in ps_device that won't be
    # changed by the job field (if present) in current_device.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
      ps_device = pydev.DeviceSpec.from_string(self._ps_device)

      current_job, ps_job = current_device.job, ps_device.job
      if ps_job and (not current_job or current_job == ps_job):
        # The placement strategy is applied here
        ps_device = ps_device.replace(task=self._ps_strategy(op))

      ps_device = ps_device.make_merged_spec(current_device)
      return ps_device.to_string()

    worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
    worker_device = worker_device.make_merged_spec(current_device)
    return worker_device.to_string()
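The core decision in device_function can be sketched without TensorFlow. The version below is deliberately simplified and all of its names are hypothetical: ops are represented only by their type string, PS_OPS lists just a few op types for illustration, and any device already set on the op is ignored, so the merge step (make_merged_spec) is left out.

```python
import itertools

# A few variable-like op types, for illustration only.
PS_OPS = ("Variable", "VariableV2", "VarHandleOp")


def make_device_function(ps_tasks, ps_job="ps", worker_device="/job:worker"):
    """Hypothetical, simplified counterpart of replica_device_setter."""
    next_task = itertools.cycle(range(ps_tasks)) if ps_tasks else None

    def device_function(op_type):
        # Variable-like ops go to the ps job, one task after another.
        if ps_tasks and op_type in PS_OPS:
            return "/job:%s/task:%d" % (ps_job, next(next_task))
        # Everything else stays on the worker.
        return worker_device

    return device_function


device_fn = make_device_function(ps_tasks=2, worker_device="/job:worker/task:0")
ops = [("VarHandleOp", "w"), ("MatMul", "matmul"),
       ("VarHandleOp", "b"), ("Add", "add")]
placement = {name: device_fn(op_type) for op_type, name in ops}
```

In this sketch the two variables land on different ps tasks, while MatMul and Add stay on the worker.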

The device-related logic is summarized in Figure 1.

Figure 1. Device assignment

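After initialization, ParameterServerStrategyExtended looks as follows:

[Figure: ParameterServerStrategyExtended after initialization]

3. Data

Next, let us see how the training data is obtained. distribute_datasets_from_function calls the base class's distribute_datasets_from_function, so we need to look at StrategyBase.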

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
        "InputReplicationMode.PER_REPLICA "
        "is only supported in "
        "experimental_distribute_datasets_from_function "
        "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    super(ParameterServerStrategyV1, self).distribute_datasets_from_function(
        dataset_fn=dataset_fn, options=options)

3.1 StrategyBase

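distribute_datasets_from_function distributes tf.data.Dataset instances created by calling 'dataset_fn'. The dataset_fn argument supplied by the user is an input function that takes an InputContext argument and returns a tf.data.Dataset instance. The dataset returned by dataset_fn should already be batched with the per-replica batch size (i.e. the global batch size divided by the number of replicas in sync) and sharded. tf.distribute.Strategy.distribute_datasets_from_function itself does not batch or shard the dataset.

dataset_fn is called on the CPU device of each worker and produces a dataset there; each replica on that worker dequeues one batch of input per step (i.e. if a worker has two replicas, two batches are dequeued from the Dataset every step). This method serves several purposes. First, it lets you specify your own batching and sharding logic. (By contrast, tf.distribute.Strategy.experimental_distribute_dataset does the batching and sharding for you.) For example, when experimental_distribute_dataset is unable to shard the input files, this method can be used to shard the dataset manually (avoiding the slow fallback behavior in experimental_distribute_dataset). When the dataset is infinite, sharding can be done by creating dataset replicas that differ only in their random seed. In addition, dataset_fn should use the tf.distribute.InputContext instance to obtain information about batching and input sharding.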
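The batching and sharding contract described above can be illustrated without TensorFlow. In this sketch SketchInputContext is a hypothetical stand-in that carries the same three fields as tf.distribute.InputContext, and a plain Python list of batches plays the role of the dataset; a real dataset_fn would return a tf.data.Dataset instead.

```python
from dataclasses import dataclass


@dataclass
class SketchInputContext:
    """Hypothetical stand-in for tf.distribute.InputContext."""
    num_input_pipelines: int
    input_pipeline_id: int
    num_replicas_in_sync: int

    def get_per_replica_batch_size(self, global_batch_size):
        if global_batch_size % self.num_replicas_in_sync != 0:
            raise ValueError("global batch size must be divisible by the "
                             "number of replicas in sync")
        return global_batch_size // self.num_replicas_in_sync


def dataset_fn(input_context, global_batch_size=8, data=range(32)):
    # Batch with the per-replica batch size, not the global one.
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    # Shard: this input pipeline keeps every n-th element, starting at its id.
    shard = [x for i, x in enumerate(data)
             if i % input_context.num_input_pipelines
             == input_context.input_pipeline_id]
    return [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]


ctx = SketchInputContext(num_input_pipelines=2, input_pipeline_id=1,
                         num_replicas_in_sync=4)
batches = dataset_fn(ctx)
```

With a global batch size of 8 and 4 replicas in sync, each batch holds 2 elements, and pipeline 1 of 2 sees only the odd-indexed elements.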

It is called like this:

def per_worker_dataset_fn():
    return strategy.distribute_datasets_from_function(
        lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

Here we find that distribute_datasets_from_function in turn goes back to the derived class's _distribute_datasets_from_function method.

def distribute_datasets_from_function(self, dataset_fn, options=None):
    return self._extended._distribute_datasets_from_function(dataset_fn, options)    

3.2 _distribute_datasets_from_function

_distribute_datasets_from_function builds an InputContext and uses it to obtain the data.

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options), [input_context],
        self._container_strategy(),
        options=options)
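To see what input_pipeline_id and num_input_pipelines look like for a concrete cluster, here is a simplified model of the two multi_worker_util helpers. It reflects my reading of their behaviour for the "chief" and "worker" task types only (the chief gets id 0 and workers come after it); the function names are hypothetical, and validation and the evaluator case are left out.

```python
def sketch_id_in_cluster(cluster_spec, task_type, task_id):
    """Simplified model: the chief is id 0, workers are numbered after it."""
    if task_type == "chief":
        return 0
    if task_type == "worker":
        return task_id + len(cluster_spec.get("chief", []))
    raise ValueError("no id for task_type %r in this sketch" % task_type)


def sketch_worker_count(cluster_spec):
    """Simplified model: chief and worker tasks all feed input."""
    return len(cluster_spec.get("chief", [])) + len(cluster_spec.get("worker", []))


# Hypothetical cluster: one chief, two workers, two parameter servers.
cluster = {"chief": ["chief0:2222"],
           "worker": ["worker0:2222", "worker1:2222"],
           "ps": ["ps0:2222", "ps1:2222"]}

pipeline_ids = [sketch_id_in_cluster(cluster, "chief", 0),
                sketch_id_in_cluster(cluster, "worker", 0),
                sketch_id_in_cluster(cluster, "worker", 1)]
num_pipelines = sketch_worker_count(cluster)
```

Under these assumptions, every task that consumes input gets a distinct id in the range [0, num_input_pipelines), which is what dataset_fn needs in order to shard.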

3.3 InputLib

This code lives in tensorflow/python/distribute/input_lib.py; its main job is to obtain an iterator.

def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None):
  """Returns a distributed dataset from the given input function.

  This is a common function that is used by all strategies to return a
  distributed dataset. The distributed dataset instance returned is different
  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
  instances returned differ from each other in the APIs supported by each of
  them.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
        iterators should be created.
    input_contexts: A list of InputContext instances to be passed to call(s)
        to dataset_fn. Length and order should match worker order in
        worker_device_pairs.
    strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
    options: Default is None. tf.distribute.InputOptions used to control
        options on how this dataset is distributed.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if options.experimental_replication_mode and
    options.experimental_place_dataset_on_device are not consistent
  """
  if tf2.enabled():
    return DistributedDatasetsFromFunction(input_workers, strategy,
                                           input_contexts, dataset_fn, options)
  else:
    return DistributedDatasetsFromFunctionV1(input_workers, strategy,
                                             input_contexts, dataset_fn,
                                             options)

DistributedDatasetsFromFunctionV1 returns a DistributedIteratorV1. Once we have the iterator, we can pull data from the dataset.

class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
  """Inputs created from dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already initialized iterators. Hence we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use iter() instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # initialize on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "make_initializable_iterator() instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers, True,
                                             self._options)
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before passing to a multi device function, add a sync point
    # here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

4. Scope and variables

4.1 StrategyBase

scope simply calls the base class method.

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

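The scope method of StrategyBase returns a context manager that uses the current strategy to create distributed variables. Entering Strategy.scope has the following effects:

  • "strategy" becomes the "current" strategy in the global context. Inside this scope, tf.distribute.get_strategy() returns this strategy; outside it, it returns the default no-op strategy.
  • Entering the scope also enters the "cross-replica context".
  • Variable creation inside "scope" is intercepted by the strategy. Each strategy defines how it wants to affect variable creation. Synchronous strategies such as 'MirroredStrategy', 'TPUStrategy' and 'MultiWorkerMirroredStrategy' create variables replicated on every replica, whereas 'ParameterServerStrategy' creates variables on the parameter servers. This is done with a custom tf.variable_creator_scope.
  • Some strategies also enter a default device scope: "MultiWorkerMirroredStrategy", for example, enters a default device scope of "/CPU:0" on each worker.
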
  def scope(self):
    """Context manager to make the strategy current and distribute variables.

    This method returns a context manager, and is used as follows:

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> # Variable created inside scope:
    >>> with strategy.scope():
    ...   mirrored_variable = tf.Variable(1.)
    >>> mirrored_variable
    MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
    }
    >>> # Variable created outside scope:
    >>> regular_variable = tf.Variable(1.)
    >>> regular_variable
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

    Returns:
      A context manager.
    """
    return self._extended._scope(self)  

Since this calls into extended, let us continue the analysis there.

4.2 StrategyExtendedV2

_scope configures how variables are created, how variables are fetched, and how the variable scope is obtained. It returns a _CurrentDistributionContext to the user; when the user creates a variable inside it, creator_with_resource_vars calls the derived strategy's _create_variable to create the variable.

  def _scope(self, strategy):
    """Implementation of tf.distribute.Strategy.scope()."""

    def creator_with_resource_vars(next_creator, **kwargs):
      """Variable creator to use in _CurrentDistributionContext."""
      _require_strategy_scope_extended(self)
      kwargs["use_resource"] = True
      kwargs["distribute_strategy"] = strategy

      # Unwrap initial_value if it is a CheckpointInitialValue to avoid
      # dereferencing a Tensor that is without a name. We still need to
      # propagate the metadata it's holding.
      if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
        checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
        kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
      elif isinstance(kwargs["initial_value"],
                      trackable.CheckpointInitialValueCallable):
        checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
      elif (isinstance(kwargs["initial_value"], functools.partial) and
            isinstance(kwargs["initial_value"].func,
                       trackable.CheckpointInitialValueCallable)):
        # Some libraries (e.g, Keras) create partial function out of initializer
        # to bind shape/dtype, for example:
        #  initial_val = functools.partial(initializer, shape, dtype=dtype)
        # Therefore to get the restore_uid we need to examine the "func" of
        # the partial function.
        checkpoint_restore_uid = kwargs[
          "initial_value"].func.checkpoint_position.restore_uid
      else:
        checkpoint_restore_uid = None

      # Call the derived strategy's _create_variable here
      created = self._create_variable(next_creator, **kwargs)

      if checkpoint_restore_uid is not None:
        # pylint: disable=protected-access
        # Let the checkpointing infrastructure know that the variable was
        # already restored so it doesn't waste memory loading the value again.
        # In this case of CheckpointInitialValueCallable this may already be
        # done by the final variable creator, but it doesn't hurt to do it
        # again.
        created._maybe_initialize_trackable()
        created._update_uid = checkpoint_restore_uid
      return created

    def distributed_getter(getter, *args, **kwargs):
      return getter(*args, **kwargs)

    return _CurrentDistributionContext(
        strategy,
        variable_scope.variable_creator_scope(creator_with_resource_vars),
        variable_scope.variable_scope(
            variable_scope.get_variable_scope(),
            custom_getter=distributed_getter), self._default_device)

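4.3 Creating variables

As described above, creator_with_resource_vars calls the derived strategy's _create_variable to create variables. Here we look at how the parameter server strategy handles this. During initialization, self._variable_device was configured, so the strategy knows how to assign variables to devices. The later code contains with ops.device(self._variable_device), which places the variables created inside that scope on self._variable_device.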

self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,  # number of parameter server tasks
        worker_device=self._worker_device,  # worker device
        merge_devices=True,
        cluster=cluster_spec)

Variables are created as follows:

  def _create_variable(self, next_creator, **kwargs):
    
    # Build the variable creator
    var_creator = self._create_var_creator(next_creator, **kwargs)

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):  # replica_device_setter is used here
        return var_creator(**kwargs)

The actual creation goes through _create_var_creator. The key point is that it calls ps_values.AggregatingVariable to produce the variable.

  def _create_var_creator(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: " + aggregation +
                         " for variable: " + kwargs["name"])

      def var_creator(**kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)
        
        # Wrap the variable in an AggregatingVariable
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
    
      return var_creator
    else:
      return next_creator

4.4 PS variables

AggregatingVariable is a wrapper around a variable, so that operations on the variable are routed through the strategy. Only part of the code is shown here.

# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy.
class AggregatingVariable(variables_lib.Variable, core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      
      # Check whether we are in cross-replica context
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        # Update through the strategy
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                "use_locking": use_locking,
                "name": name,
                "read_value": read_value
              })
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  # Most of the code is omitted.

The overall logic is shown below: the first sequence of operations creates the variable, and the second sequence operates on it.

Figure 2. Creating variables
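The wrapper idea behind AggregatingVariable can be pictured without TensorFlow at all. The following is only a minimal sketch under stated assumptions: ToyVariable, ToyStrategy and ToyAggregatingVariable are hypothetical names invented for illustration, the aggregation is hard-coded to a sum, and only the replica-context path of _assign_func is imitated. It shows two things from the listing above: unknown attributes are forwarded to the wrapped variable through __getattr__, and an assign operation is not applied directly but handed to the strategy's update.

```python
# Toy model only: none of these classes are TensorFlow APIs.
class ToyVariable:
    """Stands in for the underlying variable living on a parameter server."""

    def __init__(self, value, name):
        self.value = value
        self.name = name

    def assign_add(self, delta):
        self.value += delta
        return self.value


class ToyStrategy:
    """Hypothetical strategy that records every update routed through it."""

    def __init__(self):
        self.update_log = []

    def update(self, wrapper, fn, args):
        self.update_log.append(wrapper.name)
        # Unwrap the container and apply fn to the real variable.
        return fn(wrapper.get(), *args)


class ToyAggregatingVariable:
    def __init__(self, strategy, v):
        self._strategy = strategy
        self._v = v

    def get(self):
        return self._v

    def __getattr__(self, name):
        # Only reached when normal lookup fails: forward to the wrapped variable.
        return getattr(self._v, name)

    def assign_add(self, per_replica_deltas):
        # Imitates the replica-context path: aggregate (here a plain sum),
        # then ask the strategy to perform a single update.
        reduced = sum(per_replica_deltas)
        return self._strategy.update(
            self, lambda var, d: var.assign_add(d), args=(reduced,))


strategy = ToyStrategy()
w = ToyAggregatingVariable(strategy, ToyVariable(1.0, "w"))
w.assign_add([0.5, 0.25])    # two replicas contribute deltas
print(w.name, w.value)       # both attributes come from the wrapped variable
print(strategy.update_log)   # the update went through the strategy
```

The point of the design is that callers keep using the familiar variable interface, while the strategy stays in control of where and how the update is applied.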

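5. Running

Next we look at how ParameterServerStrategyV1 runs.

5.1 Base Class

ParameterServerStrategyV1 actually calls the run method of its base class StrategyV1, which is defined in tensorflow/python/distribute/distribute_lib.py. We analyzed it in an earlier article; it is listed again here for completeness.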

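This method is the primary way to distribute a computation with a tf.distribute object. It invokes fn on each replica. If args or kwargs contain tf.distribute.DistributedValues, then when fn is executed on a particular replica, it is executed with the component of the tf.distribute.DistributedValues that corresponds to that replica.

Examples of tf.distribute.DistributedValues are the values produced by a tf.distribute.DistributedDataset, which is obtained from tf.distribute.Strategy.experimental_distribute_dataset or tf.distribute.Strategy.distribute_datasets_from_function.

fn is called in replica context, so fn can call tf.distribute.get_replica_context() to access members such as all_reduce. All arguments in args or kwargs can be a nested structure of tensors, for example a list of tensors, in which case args and kwargs are passed to the fn invoked on each replica. Alternatively, args or kwargs can be tf.distribute.DistributedValues containing tensors or composite tensors (that is, tf.compat.v1.TensorInfo.CompositeTensor), in which case each fn call gets the component of the tf.distribute.DistributedValues corresponding to its replica.

Important: depending on the implementation of tf.distribute.Strategy and on whether eager execution is enabled, fn may be called one or more times. If fn is annotated with tf.function, or tf.distribute.Strategy.run is called inside a tf.function (eager execution is disabled inside a tf.function by default), fn is called once per replica to generate a TensorFlow graph, which is then reused for execution with new inputs.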
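The "each replica gets its own component" rule described above can be illustrated with a small pure-Python sketch. Everything here is an assumption made for illustration: ToyPerReplica, select_for_replica and toy_run are hypothetical names, not TensorFlow APIs, and replicas are simply iterated in a loop rather than placed on devices.

```python
# Toy model only: ToyPerReplica and toy_run are illustrative, not TensorFlow APIs.
class ToyPerReplica:
    """Holds one component per replica, loosely like a DistributedValues."""

    def __init__(self, values):
        self.values = tuple(values)


def select_for_replica(replica_id, structure):
    """Return the view of `structure` that replica `replica_id` should see."""
    if isinstance(structure, ToyPerReplica):
        return structure.values[replica_id]
    if isinstance(structure, (list, tuple)):
        return type(structure)(
            select_for_replica(replica_id, s) for s in structure)
    if isinstance(structure, dict):
        return {k: select_for_replica(replica_id, v)
                for k, v in structure.items()}
    return structure  # plain values are shared by every replica


def toy_run(fn, num_replicas, args=(), kwargs=None):
    kwargs = kwargs or {}
    results = []
    for replica_id in range(num_replicas):
        r_args = select_for_replica(replica_id, tuple(args))
        r_kwargs = select_for_replica(replica_id, kwargs)
        results.append(fn(*r_args, **r_kwargs))
    return ToyPerReplica(results)


def step(batch, scale):
    return sum(batch) * scale


batches = ToyPerReplica([[1, 2], [3, 4]])   # one batch per replica
out = toy_run(step, num_replicas=2, args=(batches,), kwargs={"scale": 10})
print(out.values)
```

Here the per-replica batch is unwrapped for each call, while the plain scale argument is passed unchanged to every replica.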

Inside the run method, the main work is a call to call_for_each_replica.

  def run(self, fn, args=(), kwargs=None, options=None):
    """Invokes fn on each replica, with the given arguments.
    """
    del options

    if not isinstance(args, (list, tuple)):
      raise ValueError(
        "positional args must be a list or tuple, got {}".format(type(args)))

    with self.scope():
      # tf.distribute supports Eager functions, so AutoGraph should not be
      # applied when the caller is also in Eager mode.
      fn = autograph.tf_convert(
          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)

Extend

Execution then reaches StrategyExtendedV2, whose call_for_each_replica actually calls the _call_for_each_replica of the derived class.

  def call_for_each_replica(self, fn, args=(), kwargs=None):
    """Run fn once per replica.

    fn may call tf.get_replica_context() to access methods such as
    replica_id_in_sync_group and merge_call().

    merge_call() is used to communicate between the replicas and
    re-enter the cross-replica context. All replicas pause their execution
    having encountered a merge_call() call. After that the
    merge_fn-function is executed. Its results are then unwrapped and
    given back to each replica call. After that execution resumes until
    fn is complete or encounters another merge_call().  Example:

    ```python
    # Called once in "cross-replica" context.
    def merge_fn(distribution, three_plus_replica_id):
      # sum the values across replicas
      return sum(distribution.experimental_local_results(three_plus_replica_id))

    # Called once per replica in distribution, in a "replica" context.
    def fn(three):
      replica_ctx = tf.get_replica_context()
      v = three + replica_ctx.replica_id_in_sync_group
      # Computes the sum of the v values across all replicas.
      s = replica_ctx.merge_call(merge_fn, args=(v,))
      return s + v

    with distribution.scope():
      # in "cross-replica" context
      ...
      merged_results = distribution.run(fn, args=[3])
      # merged_results has the values from every replica execution of fn.
      # This statement prints a list:
      print(distribution.experimental_local_results(merged_results))
    ```

    Args:
      fn: function to run (will be run once per replica).
      args: Tuple or list with positional arguments for fn.
      kwargs: Dict with keyword arguments for fn.

    Returns:
      Merged return value of fn across all replicas.
    """
    _require_cross_replica_or_default_context_extended(self)
    if kwargs is None:
      kwargs = {}
    with self._container_strategy().scope():
      return self._call_for_each_replica(fn, args, kwargs)
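The pause, merge and resume behaviour described in the docstring can be imitated in a few lines of plain Python. This is only a sketch under assumptions: a generator's yield stands in for replica_ctx.merge_call, toy_call_for_each_replica is a hypothetical helper, replica_id is passed as an ordinary argument instead of being read from a replica context, and only a single merge point per fn is supported. It makes no claim about how TensorFlow actually schedules replicas.

```python
# Toy model only: a generator-based imitation of merge_call, not TensorFlow code.
def fn(three, replica_id):
    v = three + replica_id
    # "yield" plays the role of replica_ctx.merge_call(merge_fn, args=(v,)):
    # the replica pauses here and later receives the merged result.
    s = yield v
    return s + v


def merge_fn(per_replica_values):
    # Runs once, in the toy "cross-replica" context.
    return sum(per_replica_values)


def toy_call_for_each_replica(fn, num_replicas, args=()):
    replicas = [fn(*args, replica_id=i) for i in range(num_replicas)]
    # Run every replica up to its merge point and collect what it handed over.
    pending = [next(r) for r in replicas]
    merged = merge_fn(pending)
    results = []
    for r in replicas:
        try:
            r.send(merged)          # resume the replica with the merged value
        except StopIteration as done:
            results.append(done.value)
    return results


results = toy_call_for_each_replica(fn, num_replicas=2, args=(3,))
print(results)
```

With two replicas, the values handed to merge_fn are 3 and 4, every replica receives their sum 7, and each then finishes with its own s + v.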

5.2 Derived Class

The _call_for_each_replica of the derived class ParameterServerStrategyExtended is as follows:

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)
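The call chain run -> call_for_each_replica -> _call_for_each_replica is a template-method arrangement: the base class owns the public entry point and the derived class supplies the actual execution. The sketch below only mirrors that shape. All class names are hypothetical, and the list comprehension is a stand-in for the real work done by mirrored_run, which is not reproduced here.

```python
# Toy model only: class names echo the article, but the bodies are made up.
class ToyExtendedBase:
    def call_for_each_replica(self, fn, args=(), kwargs=None):
        # Public entry point: normalise arguments, then defer to the subclass.
        if kwargs is None:
            kwargs = {}
        return self._call_for_each_replica(fn, args, kwargs)

    def _call_for_each_replica(self, fn, args, kwargs):
        raise NotImplementedError("must be implemented in descendants")


class ToyParameterServerExtended(ToyExtendedBase):
    def __init__(self, num_gpus):
        self._num_gpus = num_gpus

    def _call_for_each_replica(self, fn, args, kwargs):
        # Stand-in for the delegated execution: one call per local replica.
        return [fn(*args, **kwargs) for _ in range(self._num_gpus)]


class ToyStrategy:
    def __init__(self, extended):
        self._extended = extended

    def run(self, fn, args=(), kwargs=None):
        if not isinstance(args, (list, tuple)):
            raise ValueError("positional args must be a list or tuple")
        return self._extended.call_for_each_replica(
            fn, args=args, kwargs=kwargs)


strategy = ToyStrategy(ToyParameterServerExtended(num_gpus=2))
outputs = strategy.run(lambda x, y=0: x + y, args=(1,), kwargs={"y": 2})
print(outputs)
```

Swapping in a different subclass of ToyExtendedBase changes how fn is executed without touching run, which is the same separation the article's listings show.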

The mirrored_run part was analyzed in an earlier article and is not repeated here. The overall logic is as follows:

Figure 3. Running

Or, viewed from another angle, as shown in the figure below:

[Figure: the execution flow viewed from another angle]

0xFF References

https://www.youtube.com/watch?v=B2Tpv_N7wkg&ab_channel=TensorFlow