[Source Code Analysis] Distributed Variables in TensorFlow

Posted by 羅西的思考 on 2022-04-14


In TensorFlow, distributed variables are variables created on multiple devices; mirrored variables and SyncOnRead variables are two examples. This article analyzes distributed variables, guided by the following questions:

  • How does variable creation end up calling into the Strategy?
  • How is a MirroredVariable generated?
  • How are tensors distributed onto the individual devices?
  • How is a single unified view presented to the outside?
  • How are the variables kept consistent with each other?

As always, two recommendations:

[TensorFlow Internals](https://github.com/horance-liu/tensorflow-internals): although it does not analyze the latest code, anyone interested in the internal mechanisms of TF should read it; it is well worth the time.
https://home.cnblogs.com/u/deep-learning-stacks/ 西門宇少: not only TensorFlow, this account also covers the frontiers of many other fields.

Other articles in this series:

[Translation] TensorFlow distributed papers: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow distributed papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow distributed environment (1) --- Overall architecture

[Source Code Analysis] TensorFlow distributed environment (2) --- Master static logic

[Source Code Analysis] TensorFlow distributed environment (3) --- Worker static logic

[Source Code Analysis] TensorFlow distributed environment (4) --- WorkerCache

[Source Code Analysis] TensorFlow distributed environment (5) --- Session

[Source Code Analysis] TensorFlow distributed environment (7) --- Worker dynamic logic

[Source Code Analysis] TensorFlow distributed environment (8) --- Communication mechanism

[Translation] Distributed training with TensorFlow

[Source Code Analysis] TensorFlow distributed DistributedStrategy basics

1. MirroredVariable

tf.distribute.MirroredStrategy supports synchronous distributed training on multiple GPUs of a single machine. The strategy creates one replica per GPU device, and every variable in the model is mirrored across all replicas. Together these copies form a single conceptual variable called a MirroredVariable, and they are kept in sync with each other by applying identical updates.

Figure 1 MirroredVariable

A concrete usage example:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# Variable created inside scope:
with strategy.scope():
  mirrored_variable = tf.Variable(1.)

# Variable created outside scope:
regular_variable = tf.Variable(1.)

The printed results are as follows:

>>> mirrored_variable
  MirroredVariable:{
    0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
  }

>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>  

Alternatively, see the example in tensorflow/python/module/module_test.py.

def test_supports_distributed_variables(self):
  mirrored = distributed_values.MirroredVariable(
      None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
  tpu = tpu_values.TPUMirroredVariable(
      strategy=None, values=[variables.Variable(42.)], aggregation=None)
  aggregating = ps_values.AggregatingVariable(
      strategy=None, v=variables.Variable(1.), aggregation=None)

  m = module.Module()
  m.a = mirrored

1.1 Definition

The comment of MirroredVariable states its purpose: it holds a map from replicas to variables whose values are kept in sync. It adds no new member variables; it only implements a few member functions.

class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""

  def _update_replica(self, update_fn, value, **kwargs):
    return _on_write_update_replica(self, update_fn, value, **kwargs)

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed(): # non-distributed case
      # directly return the local value
      return self._primary.scatter_update(*args, **kwargs)
    # otherwise handle it in a distributed fashion
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))

Take scatter_update as an example: in the non-distributed case it delegates directly to _primary, otherwise it calls the base-class method. In addition, _update_replica calls _on_write_update_replica to synchronize the replicas during an update, and _on_write_update_replica in turn performs the update according to the current context. It is defined in tensorflow/python/distribute/values.py.

def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Updates variables with ON_WRITE synchronization in replica context."""
  if var.aggregation == vs.VariableAggregation.NONE:
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)

  if not ds_context.get_strategy().extended._use_merge_call():
    # merge_call-free path: aggregate the value in replica context and update
    # every replica copy through replica_context._update.
    aggregated_value = apply_aggregation_replica_context(
        value, var.aggregation, var)
    values_util.mark_as_unsaveable()

    return ds_context.get_replica_context()._update(
        var,
        update_fn,
        args=(aggregated_value,),
        kwargs=kwargs,
        group=True)

  else:

    def merge_fn(strategy, value, **kwargs):
      """Aggregate values and update all variables in cross replica context."""
      v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
      return var._update_cross_replica(update_fn, v, **kwargs)

    return ds_context.get_replica_context().merge_call(
        merge_fn, args=(value,), kwargs=kwargs)
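
To see this ON_WRITE update path end to end, here is a minimal sketch (assuming two GPUs, as in the other examples of this article): each replica proposes a different value, the proposals are aggregated according to the variable's aggregation setting, and the same result is written to every copy.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)

@tf.function
def replica_fn():
  ctx = tf.distribute.get_replica_context()
  # Replica 0 proposes 0.0 and replica 1 proposes 1.0. Because aggregation is
  # MEAN (not NONE), _on_write_update_replica aggregates the proposals and
  # applies the same update to every copy, so both copies end up with 0.5.
  v.assign(tf.cast(ctx.replica_id_in_sync_group, tf.float32))

strategy.run(replica_fn)
print(v.values)  # both component variables read 0.5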

Looking only at these member functions, it is hard to form a clear picture of MirroredVariable; we need to start from its class hierarchy.

1.2 Related classes

1.2.1 Class hierarchy

The MirroredVariable class hierarchy is shown below; we analyze the classes one by one and then summarize.

Figure 2 MirroredVariable class hierarchy

1.2.2 DistributedValues

Let us start with DistributedValues.

Figure 3 DistributedValues

Distributed values are represented by the base class tf.distribute.DistributedValues. The tf.distribute.DistributedValues concept is suited to representing values on multiple devices; it holds a map from replica IDs to values.

tf.distribute.DistributedValues contains one value per replica. Depending on the subclass, these values may be synchronized on update, synchronized on demand, or never synchronized. tf.distribute.DistributedValues can be reduced to obtain a single value across replicas as input to tf.distribute.Strategy.run, or the per-replica values can be inspected with tf.distribute.Strategy.experimental_local_results.

DistributedValues is a base class and should not be instantiated directly. Instead, instances of its subclasses are created inside a distribution strategy, either by iterating a tf.distribute.DistributedDataset or through tf.distribute.Strategy.run.

Two representative kinds of tf.distribute.DistributedValues are "PerReplica" and "Mirrored" values.

  • "PerReplica" values live on the worker devices, with a different value per replica. They are produced by iterating the distributed datasets returned by tf.distribute.Strategy.experimental_distribute_dataset and tf.distribute.Strategy.distribute_datasets_from_function, and they are also the typical result returned by tf.distribute.Strategy.run.

  • "Mirrored" values are like "PerReplica" values except that the value is the same on every replica. A "Mirrored" value can therefore be read safely in cross-replica context by using the value on any replica.

Definition

DistributedValues has two important member variables, _values and _primary. The initial variables are stored in the _values tuple, and the first element of that tuple serves as _primary.

Because the derived classes rely on them, we look at a few member functions of DistributedValues:

  • _get_on_device_or_primary returns the value belonging to the current replica, or falls back to the value of _primary.
  • _get_cross_replica: returns the cross-replica value; this is left for derived classes to implement.
  • _get: if a replica id is available it returns the local value, otherwise it calls _get_cross_replica to return the cross-replica value.

The concept diagram is as follows:

Figure 4 DistributedValues

The code of DistributedValues is:

@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
  """Base class for representing distributed values.

  A subclass instance of  tf.distribute.DistributedValues  is created when
  creating variables within a distribution strategy, iterating a
   tf.distribute.DistributedDataset  or through  tf.distribute.Strategy.run .
  This base class should never be instantiated directly.
   tf.distribute.DistributedValues  contains a value per replica. Depending on
  the subclass, the values could either be synced on update, synced on demand,
  or never synced.

   tf.distribute.DistributedValues  can be reduced to obtain single value across
  replicas, as input into  tf.distribute.Strategy.run  or the per-replica values
  inspected using  tf.distribute.Strategy.experimental_local_results .
  """

  def __init__(self, values):
    """Should only be called by subclass __init__."""
    self._values = tuple(values)

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica() # return the cross-replica value
    else:
      return self._values[replica_id] # return the local value

  def _get_cross_replica(self):
    raise NotImplementedError(
        "DistributedValues._get_cross_replica should be implemented by "
        "sub-classes which support cross-replica accesses.")

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    # get the current replica id
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None: # no replica id, so look at the local devices
      # Try to find a value on the current device.
      # current_device is a string
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values: # iterate over the component values
        if device_util.canonicalize(value.device) == current_device:
          return value # return the matching value
      return self._primary # fall back to _primary
    else:
      # return the value that belongs to the current replica
      return self._values[replica_id]

  @property
  def _primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def _devices(self):
    return tuple(v.device for v in self._values)

The code above makes heavy use of get_current_replica_id_as_int. This function is defined in tensorflow/python/distribute/values_util.py and returns the current replica id.

def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or  None ."""
  replica_context = ds_context.get_replica_context()
  if replica_context:
    replica_id = replica_context._replica_id
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
  else:
    replica_id = distribute_lib.get_update_replica_id()
  return replica_id

Usage

Here are some usage examples taken from the source code; all of them use MirroredStrategy to obtain DistributedValues.

# 1. Created from a  tf.distribute.DistributedDataset :
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)

# 2. Returned by  run :
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def run():
   ctx = tf.distribute.get_replica_context()
   return ctx.replica_id_in_sync_group
distributed_values = strategy.run(run)

# 3. As input into  run :
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
@tf.function
def run(input):
   return input + 1.0
updated_value = strategy.run(run, args=(distributed_values,))

# 4. Reduce value:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                 distributed_values,
                                 axis = 0)

# 5. Inspect local replica values:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
per_replica_values = strategy.experimental_local_results(distributed_values)
print(per_replica_values)

# Output
#  (<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
#   <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)

1.2.3 DistributedDelegate

Next, let us look at DistributedDelegate.

Figure 5 DistributedDelegate

DistributedDelegate adds computation on top of DistributedValues. Concretely, _get_as_operand calls the _get method of the base class DistributedValues to obtain the value, and the arithmetic operators then work on that value.

Figure 6 How computation works

DistributedDelegate is defined as follows (most of the code is omitted).

class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __getattr__(self, name):
    # The '_use_resource_variables' and the attrs starts with '_self' are used
    # for restoring the saved_model proto, and '_attribute_sentinel' is used for
    # Layer tracking. At the point these attrs are queried, the variable has not
    # been initialized. Thus it should not query those of the underlying
    # components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      return super(DistributedDelegate, self).__getattr__(name)

    # This allows copy.copy(DistributedDelegate). When copying an object,
    # copy.copy doesn't invoke its __init__ method, instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__getstate__" in case the object implements its custom
    # copying. Since DistributedDelegate doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access "_values" attributes,
    # but that doesn't exist either because this is an empty object, and again
    # __getattr__ is invoked, leading to an infinite recursion.
    if name == "_values":
      raise AttributeError()

    # TODO(priyag): This needs to be made robust against pitfalls from mix use
    # __getattr__ and @property. See b/120402273.
    return getattr(self._get(), name)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

  def _get_as_operand(self):
    """Returns the value for operations for the current device.

    Some implementations, e.g.  TPUMirroredVariable , are not able to return the
    value type within a replica context. They can, however, return a value that
    can be used by the operations below.
    """
    return self._get()

  def __add__(self, o):
    return self._get_as_operand() + o

  def __radd__(self, o):
    return o + self._get_as_operand()

  def __sub__(self, o):
    return self._get_as_operand() - o

  def __rsub__(self, o):
    return o - self._get_as_operand()

  # Most of the code is omitted
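
The operator overloads above are what make a distributed value usable like an ordinary tensor. A small sketch (assuming two GPUs) of how that looks from user code:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(3.0)  # a MirroredVariable, hence also a DistributedDelegate

# __add__ resolves the variable via _get_as_operand() -> _get(), so plain
# arithmetic works in cross-replica context...
print(v + 1.0)  # tf.Tensor(4.0, ...)

# ...and in replica context, where _get() picks the local replica's copy.
@tf.function
def step():
  return 2.0 * v

print(strategy.experimental_local_results(strategy.run(step)))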

1.2.4 PerReplica

PerReplica holds a map from replicas to values that are not synchronized.

class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
  """Holds a map from replica to unsynchronized values."""

  @property
  def _type_spec(self):
    return PerReplicaSpec(
        *(type_spec.type_spec_from_value(v) for v in self._values))

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

1.2.5 Mirrored

Next we come to Mirrored.

Figure 7 Mirrored

Mirrored represents a variable created on multiple devices that is kept in sync by applying identical updates to every replica. Mirrored variables are created with tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...). They are normally used only for synchronous training.

Recall that DistributedValues stores a map from replicas to values that are kept in sync, but leaves _get_cross_replica unimplemented. Mirrored is meant to be usable directly in cross-replica mode, so it implements _get_cross_replica, which simply calls _get_on_device_or_primary of the base class DistributedValues (see the corresponding section above): it returns the value for the current replica, or the value of _primary.

The concept diagram is as follows:

Figure 8 How Mirrored computes

Mirrored is defined as follows:

# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
class Mirrored(DistributedDelegate):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    return self._get_on_device_or_primary() # call the DistributedValues base-class method

  def _as_graph_element(self):
    obj = self._get() # call the DistributedValues base-class method
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj

1.2.6 Policy

Next we look at the variable policies.

Figure 9 Variable policies

VariablePolicy

VariablePolicy is the base class of the variable policies; it defines the synchronization and aggregation policy of a distributed variable. Given the synchronization and aggregation arguments set on a tf.Variable created inside a tf.distribute scope, tf.distribute creates an appropriate policy object and assigns it to the distributed variable. All variable operations are delegated to that policy object.

class VariablePolicy(object):
  """Policy defining synchronization and aggregation of a distributed variable.

  Given  synchronization  and  aggregation  parameters set on a  tf.Variable 
  during variable creation within  tf.distribute  scope,  tf.distribute  creates
  an appropriate policy object and assigns it to the distributed variable. All
  variable operations are delegated to the respective policy object.
  """

  def __init__(self, aggregation):
    self._aggregation = aggregation

  def value(self):
    raise NotImplementedError(
        "VariablePolicy.value should be overriden by sub-classes.")

  def _is_mirrored(self):
    raise NotImplementedError(
        "VariablePolicy._is_mirrored should be overriden by sub-classes.")

  def _as_graph_element(self, _):
    raise NotImplementedError(
        "VariablePolicy._as_graph_element should be overriden by sub-classes.")

  def _get_cross_replica(self, var):
    raise NotImplementedError(
        "VariablePolicy._get_cross_replica should be overriden by sub-classes.")

  def _update_replica(self, var, update_fn, value, **kwargs):
    raise NotImplementedError(
        "VariablePolicy._update_replica should be overriden by sub-classes.")

OnReadPolicy

OnReadPolicy implements the read-side policy (for ON_READ synchronization). For example, its member function _get_cross_replica calls var.distribute_strategy.reduce to produce the value.

class OnReadPolicy(VariablePolicy):
  """Policy defined for  tf.VariableSynchronization.ON_READ  synchronization.

  This policy is created when  synchronization  is set to
   tf.VariableSynchronization.ON_READ  and  aggregation  is set to any of the
  values allowed by the  tf.VariableAggregation  enum such as  NONE ,  SUM ,
   MEAN  or  ONLY_FIRST_REPLICA when creating a  tf.Variable  in  tf.distribute 
  scope.
  """

  def _is_mirrored(self):
    return False

  def value(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return var._get_replica(0).value()  
        return var._get_cross_replica()  
      else:
        return var._get_on_device_or_primary().value()  

  def _as_graph_element(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(var._get_cross_replica())  
    return var._get()._as_graph_element()  

  def _get_cross_replica(self, var):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return var._get_replica(0)  # read from the first replica
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable() # mark as unsaveable
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      # let distribute_strategy perform the reduction
      return var.distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          var,
          axis=None)

  def _update_replica(self, var, update_fn, value, **kwargs):
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Adds a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_add(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)
    
  # Most of the code is omitted

OnWritePolicy

The OnWritePolicy class implements the write-side policy. It mainly relies on var._get_on_device_or_primary() to perform its operations; for example _get_cross_replica simply calls var._get_on_device_or_primary(). It also uses the basic operations defined in values_util.

class OnWritePolicy(VariablePolicy):
  """Policy defined for  tf.VariableSynchronization.ON_WRITE  synchronization.

  This policy is created when the following  synchronization  and  aggregation 
  parameters are specified when creating a  tf.Variable  in  tf.distribute 
  scope and  synchronization  is equal to  tf.VariableSynchronization.ON_WRITE 
  or  tf.VariableSynchronization.AUTO .
  """

  def _is_mirrored(self):
    return True

  def value(self, var):
    return var._get_on_device_or_primary().value()  

  def _as_graph_element(self, var):
    return var._get_on_device_or_primary()._as_graph_element()  

  def _get_cross_replica(self, var):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(var._get_on_device_or_primary())  

  # calls update_fn and _on_write_update_replica to perform the update
  def _update_replica(self, var, update_fn, value, **kwargs):
    if var.aggregation == variables_lib.VariableAggregation.NONE:
      return update_fn(var._get_on_device_or_primary(), value, **kwargs)  
    return _on_write_update_replica(var, update_fn, value, **kwargs)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    return values_util.on_write_assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    # delegate the work to values_util
    return values_util.on_write_assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  # this is discussed again later
  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_update(
        var, sparse_delta, use_locking=use_locking, name=name)

  def get_saveable(self, var, primary_var, name):
    """Saveable ops for AUTO variables."""
    return values_util.get_on_write_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    return values_util.get_on_write_restore_ops(var, tensor)

  # Most of the code is omitted

values_util

Both policies above use on_write_assign_add, which is defined in tensorflow/python/distribute/values_util.py.

def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
  return var._update(  
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)

OnWritePolicy also uses scatter_update from values_util, which again calls back into var._update.

def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update( 
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
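
Putting the pieces together, here is a short sketch (assuming two GPUs) of what an ON_WRITE update looks like from the outside: an assignment made in cross-replica context flows through OnWritePolicy and values_util into var._update and is applied to every copy.

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(1.0)

# assign_add in cross-replica context goes through OnWritePolicy ->
# values_util.on_write_assign_add -> var._update, which applies the same
# update on every replica copy.
v.assign_add(1.0)
print(v.values)  # both copies read 2.0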

1.2.7 DistributedVariable

Following the class relationships, we finally arrive at DistributedVariable, which is where most of the functionality of MirroredVariable actually lives.

Figure 10 DistributedVariable

DistributedVariable holds the map from replicas to variables. For MirroredVariable, self._policy is an OnWritePolicy, and updating the variable is done through this _policy.

class DistributedVariable(DistributedDelegate, variables_lib.Variable,
                          core.Tensor):
  """Holds a map from replica to variables."""

  def __init__(self, strategy, values, aggregation, var_policy=None):
    if (aggregation == variables_lib.VariableAggregation.MEAN and
        not values[0].dtype.is_floating):
      raise ValueError(
          "creating distributed tf.Variable with aggregation=MEAN and a "
          "non-floating dtype is not supported, please use a different "
          "aggregation or dtype")
    self._distribute_strategy = strategy
    self._aggregation = aggregation
    super(DistributedVariable, self).__init__(values)
    self._common_name = self._primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access

    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    if ops.executing_eagerly_outside_functions() and getattr(
        strategy, "_enable_packed_variable_in_eager_mode", False):
      name = "%s/packed/" % self._common_name
      self._packed_var = packed.PackedDistributedVariable(values, name=name)
    else:
      self._packed_var = None

    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use  __getattr__  which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a  DistributedVariable 's initializer is composed of the
    # initializers of the components variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire  DistributedVariable .
    self._initializer_op = None
    # Set a VariablePolicy which decides how we replicate/aggregate the given
    # variable.
    self._policy = var_policy

How an operation is handled depends on the concrete situation, but eventually everything comes down to strategy or strategy.extended.

Reading

When reading, _get_cross_replica is called, which internally delegates to the Policy; the Policy in turn performs the read (for OnReadPolicy this means a reduction via distribute_strategy, for OnWritePolicy it returns an identity of the local copy).

def _get_cross_replica(self):
  if values_util.is_saving_non_distributed(): 
    return self._primary # non-distributed saving: return the primary directly
  if self._policy:
    # return the cross-replica value via the policy
    return self._policy._get_cross_replica(self)  

  raise NotImplementedError(
      "DistributedVariable._get_cross_replica requires a valid "
      "VariablePolicy. Please set the policy via the  var_policy  argument "
      "in the constructor, or override this method in sub-classes which "
      "support cross-replica accesses.")

This is illustrated below:

Figure 11 DistributedVariable read

scatter_update

scatter_update likewise delegates the update to _policy.

def scatter_update(self, sparse_delta, use_locking=False, name=None):
  if values_util.is_saving_non_distributed():
    return self._primary.scatter_update(sparse_delta, use_locking, name)
  if self._policy:
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)
  return values_util.scatter_update(
      self, sparse_delta, use_locking=use_locking, name=name)

As discussed in the OnWritePolicy section above, scatter_update eventually calls back into the _update method of the DistributedVariable itself.

def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update(  
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

var._update has several execution paths; we only analyze some of them.

def _update(self, update_fn, value, **kwargs):
  """Applies updates depending on the context.

  The method calls  _update_replica  in replica context,
   _update_cross_replica  in cross replica context, and  update_fn  in update
  context.

  If  read_value  is True, the method returns the updated Variable. If
   read_value  is False, the method returns the update  tf.Operation .

  Args:
    update_fn: A callable to pass to  strategy.extended.update  to update the
      variable. It should have the same signature as  Variable.assign() .
    value: value to be passed to  update_fn .
    **kwargs: keyword arguments to  update_fn .

  Returns:
    Updated variable or  tf.Operation .

  """
  if values_util.is_saving_non_distributed():
    return update_fn(self._primary, value, **kwargs) # non-distributed case

  with ds_context.enter_or_assert_strategy(self.distribute_strategy):
    if ds_context.in_cross_replica_context():
      update_replica_id = distribute_lib.get_update_replica_id()
      if update_replica_id is not None:
        replica_value = self._get_replica(update_replica_id)
        return update_fn(replica_value, value, **kwargs)
      return self._update_cross_replica(update_fn, value, **kwargs) # cross-replica update
    else:
      values_util.assert_replica_context(self.distribute_strategy)
      return self._update_replica(update_fn, value, **kwargs)

It then calls _update_cross_replica to perform the cross-replica update.

def _update_cross_replica(self, update_fn, value, **kwargs):
  """Applies updates across replicas.

  Args:
    update_fn: A callable to pass to  strategy.extended.update  to update the
      variable. It should has the same signature as  Variable.assign() .
    value: value to be passed to  update_fn .
    **kwargs: remaining arguments to  update_fn .

  Returns:
    Updated variable or  tf.Operation .
  """
  values_util.mark_as_unsaveable()
  return self.distribute_strategy.extended.update(
      self, update_fn, args=(value,), kwargs=kwargs, group=True)

The flow is shown below:

Figure 12 DistributedVariable update

1.2.8 Saving

Next let us see how a MirroredVariable is saved. As shown below, _saveable_factory uses _MirroredSaveable to implement saving.

class MirroredVariable(DistributedVariable, Mirrored):

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to  SaveableObject  factories.
    """

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

_MirroredSaveable defines how to save and restore a MirroredVariable.

class _MirroredSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    # this calls into get_on_write_saveable
    tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
                                                     primary_variable, name)
    super(_MirroredSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor)

The code of get_on_write_saveable is:

def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""
  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    if context.executing_eagerly() and not primary_var.is_initialized():
      # A SaveSpec tensor value of  None  indicates that the variable is
      # uninitialized.
      return None
    strategy = var.distribute_strategy
    return strategy.extended.read_var(var) # read the tensor

  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)

  return tensor, [spec]

In tensorflow/python/distribute/mirrored_strategy.py, read_var reads the value across replicas.

def read_var(self, replica_local_var):
  """Read the aggregate value of a replica-local variable."""
  if distribute_utils.is_sync_on_read(replica_local_var):
    return replica_local_var._get_cross_replica()
  return array_ops.identity(replica_local_var._get())

1.2.9 Summary

With the above analysis, we arrive at the annotated MirroredVariable class hierarchy below; much of its functionality ultimately rests on tf.distribute.Strategy.

Figure 13 Annotated MirroredVariable class hierarchy

1.3 Building variables

A variable created under MirroredStrategy is a MirroredVariable. If no devices are specified in the strategy's constructor arguments, it uses all available GPUs; if no GPU is found, it uses the available CPUs. Note that TensorFlow treats all CPUs of a machine as a single device and uses threads internally for parallelism. Let us now see how a MirroredVariable is built.
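
A short sketch of the default device selection just described: no device list is passed, so the strategy mirrors across whatever GPUs are visible, or falls back to the CPU (treated as one device).

strategy = tf.distribute.MirroredStrategy()   # no devices given
print("replicas:", strategy.num_replicas_in_sync)

with strategy.scope():
  w = tf.Variable(tf.zeros([4, 2]))  # routed through _create_variable

print(type(w).__name__)  # MirroredVariable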

1.3.1 StrategyBase

First, the following code in tensorflow/python/distribute/distribute_lib.py shows that for scope it is again _extended that does the work.

def scope(self):
  """Returns a context manager selecting this Strategy as current.

  Inside a  with strategy.scope():  code block, this thread
  will use a variable creator set by  strategy , and will
  enter its "cross-replica context".

  Returns:
    A context manager.
  """
  return self._extended._scope(self)  

1.3.2 StrategyExtendedV2

This brings us to StrategyExtendedV2. StrategyExtendedV2 uses creator_with_resource_vars to provide the mechanism for creating variables; creator_with_resource_vars in turn calls the derived class's _create_variable to create the variable.

def _scope(self, strategy):
  """Implementation of tf.distribute.Strategy.scope()."""

  def creator_with_resource_vars(next_creator, **kwargs):
    """Variable creator to use in  _CurrentDistributionContext ."""
    _require_strategy_scope_extended(self)
    kwargs["use_resource"] = True
    kwargs["distribute_strategy"] = strategy

    # Unwrap  initial_value  if it is a  CheckpointInitialValue  to avoid
    # dereferencing a  Tensor  that is without a  name . We still need to
    # propagate the metadata it's holding.
    if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
      checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
      kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
    elif isinstance(kwargs["initial_value"],
                    trackable.CheckpointInitialValueCallable):
      checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
    elif (isinstance(kwargs["initial_value"], functools.partial) and
          isinstance(kwargs["initial_value"].func,
                     trackable.CheckpointInitialValueCallable)):
      # Some libraries (e.g, Keras) create partial function out of initializer
      # to bind shape/dtype, for example:
      #  initial_val = functools.partial(initializer, shape, dtype=dtype)
      # Therefore to get the restore_uid we need to examine the "func" of
      # the partial function.
      checkpoint_restore_uid = kwargs[
          "initial_value"].func.checkpoint_position.restore_uid
    else:
      checkpoint_restore_uid = None

    created = self._create_variable(next_creator, **kwargs)

    if checkpoint_restore_uid is not None:
      # Let the checkpointing infrastructure know that the variable was
      # already restored so it doesn't waste memory loading the value again.
      # In this case of CheckpointInitialValueCallable this may already be
      # done by the final variable creator, but it doesn't hurt to do it
      # again.
      created._maybe_initialize_trackable()
      created._update_uid = checkpoint_restore_uid
    return created

  def distributed_getter(getter, *args, **kwargs):
    return getter(*args, **kwargs)

  # creator_with_resource_vars is used here
  return _CurrentDistributionContext(
      strategy,
      variable_scope.variable_creator_scope(creator_with_resource_vars), # configures how variables are created
      variable_scope.variable_scope(
          variable_scope.get_variable_scope(),
          custom_getter=distributed_getter), self._default_device)
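
The scope machinery above ultimately rests on tf.variable_creator_scope. A stripped-down sketch of that mechanism in isolation (logging_creator is a hypothetical name used only for illustration):

def logging_creator(next_creator, **kwargs):
  # Every tf.Variable created inside the scope is routed through this hook,
  # just like creator_with_resource_vars above.
  print("creating variable:", kwargs.get("name"))
  return next_creator(**kwargs)

with tf.variable_creator_scope(logging_creator):
  x = tf.Variable(1.0, name="x")  # prints "creating variable: x"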

The logic is as follows: entering the scope triggers a series of operations and returns a _CurrentDistributionContext, which itself performs another series of operations that we look at next.

Figure 14 How variables are created

1.3.3 _CurrentDistributionContext

_CurrentDistributionContext maintains the strategy-related information, sets up the various scopes, and returns the strategy.

class _CurrentDistributionContext(object):
  """Context manager setting the current  tf.distribute.Strategy .

  Also: overrides the variable creator and optionally the current device.
  """

  def __init__(self,
               strategy,
               var_creator_scope,
               var_scope=None,
               resource_creator_scope=None,
               default_device=None):
    self._context = distribution_strategy_context._CrossReplicaThreadMode( 
        strategy)
    self._var_creator_scope = var_creator_scope
    self._var_scope = var_scope
    self._resource_creator_scope = resource_creator_scope
    if default_device:
      self._device_scope = ops.device(default_device)
    else:
      self._device_scope = None
    self._same_scope_again_count = 0

  def __enter__(self):
    # Allow this scope to be entered if this strategy is already in scope.
    if distribution_strategy_context.has_strategy():
      _require_cross_replica_or_default_context_extended(
          self._context.strategy.extended)
      self._same_scope_again_count += 1
    else:
      _push_per_thread_mode(self._context)
      if self._var_scope:
        self._var_scope.__enter__()
      self._var_creator_scope.__enter__()
      if self._resource_creator_scope:
        nest.map_structure(lambda scope: scope.__enter__(),
                           self._resource_creator_scope)
      if self._device_scope:
        self._device_scope.__enter__()
    return self._context.strategy

  def __exit__(self, exception_type, exception_value, traceback):
    if self._same_scope_again_count > 0:
      self._same_scope_again_count -= 1
      return
    if self._device_scope:
      try:
        self._device_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Device scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of  with  scope."),
            e)

    try:
      self._var_creator_scope.__exit__(
          exception_type, exception_value, traceback)
    except RuntimeError as e:
      six.raise_from(
          RuntimeError("Variable creator scope nesting error: move call to "
                       "tf.distribute.set_strategy() out of  with  scope."),
          e)

    if self._resource_creator_scope:
      try:
        if isinstance(self._resource_creator_scope, list):
          reversed_resource_creator_scope = self._resource_creator_scope[::-1]
          nest.map_structure(
              lambda scope: scope.__exit__(exception_type, exception_value,  
                                           traceback),
              reversed_resource_creator_scope)

        else:
          self._resource_creator_scope.__exit__(exception_type, exception_value,
                                                traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Resource creator scope nesting error: move call "
                         "to tf.distribute.set_strategy() out of  with  "
                         "scope."), e)

    if self._var_scope:
      try:
        self._var_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Variable scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of  with  scope."),
            e)
    _pop_per_thread_mode()

1.3.4 MirroredStrategy

With the analysis above we know that when a Strategy is used, it is the Strategy's _create_variable that eventually produces the variable.

_create_variable carries out the actual work. It uses self._devices and then calls distribute_utils.create_mirrored_variable, which uses _real_mirrored_creator, VARIABLE_CLASS_MAPPING and create_mirrored_variable to create the variables. _real_mirrored_creator configures the concrete variable names, which later calls use to decide which device each variable should be placed on. The first device keeps the original name, while subsequent devices append /replica_<replica id> to the original variable name, so that they can be distinguished from the original variable. The value of the original variable is then assigned to these replica variables.

def _create_variable(self, next_creator, **kwargs):
  """Create a mirrored variable. See  DistributionStrategy.scope ."""
  colocate_with = kwargs.pop("colocate_with", None)
  if colocate_with is None:
    devices = self._devices
  elif isinstance(colocate_with, numpy_dataset.SingleDevice):
    with ops.device(colocate_with.device):
      return next_creator(**kwargs)
  else:
    devices = colocate_with._devices  

  def _real_mirrored_creator(**kwargs):  
    value_list = []
    for i, d in enumerate(devices):
      with ops.device(d):
        kwargs["initial_value"] = self._get_variable_creator_initial_value(
            replica_id=i,
            device=d,
            primary_var=value_list[0] if value_list else None,
            **kwargs)
        if i > 0:
          # Give replicas meaningful distinct names:
          var0name = value_list[0].name.split(":")[0]
          # We append a / to variable names created on replicas with id > 0 to
          # ensure that we ignore the name scope and instead use the given
          # name as the absolute name of the variable.
          kwargs["name"] = "%s/replica_%d/" % (var0name, i)
        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          # Don't record operations (e.g. other variable reads) during
          # variable creation.
          with tape.stop_recording():
            v = next_creator(**kwargs)
        assert not isinstance(v, values.DistributedVariable)
        value_list.append(v)
    return value_list

  return distribute_utils.create_mirrored_variable(
      self._container_strategy(), _real_mirrored_creator,
      distribute_utils.VARIABLE_CLASS_MAPPING,
      distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)

VARIABLE_CLASS_MAPPING determines which class of variable to create. VARIABLE_POLICY_MAPPING determines which policy handles read/write synchronization.

# The following mapping indicates the policy that you must use for a given
# variable  synchronization  and  aggregation  pair.
# OnWritePolicy is used for:
# (synchronization=Auto, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable, # this is the case we focus on
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

1.3.5 distribute_utils

create_mirrored_variable in tensorflow/python/distribute/distribute_utils.py creates the concrete variables. For our example, class_mapping is values_lib.MirroredVariable.

def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  with tape.stop_recording():
    # build the list of mirrored component variables
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumer like
    #  global_variables_initializer  wouldn't complain, as it groups all
    # variables' initializers thus all variables have to have initializers.
    for v in value_list:
      if hasattr(v, "_initializer_op") and v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
    if use_var_policy:
      # look up the policy, get the class, and create the variable
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result

The final construction logic is shown below: the _var_creator_scope member of _CurrentDistributionContext points to creator_with_resource_vars. When a variable is created, creator_with_resource_vars is invoked layer by layer, finally producing a MirroredVariable.

Figure 15 Creating variables

1.4 Summary

So far we can answer the earlier questions as follows:

  • How does creation end up calling into the Strategy?
    • Reads and writes of the variable ultimately land on strategy or strategy.extended.
  • How is a MirroredVariable generated?
    • Inside the scope the user obtains a context; the context provides the variable-creation method, so any variable the user creates inside that context is naturally a MirroredVariable.
  • How are tensors distributed onto the individual devices?
    • When a Strategy is used, the Strategy's _create_variable generates the variables; _create_variable eventually calls _real_mirrored_creator.
    • _real_mirrored_creator configures the concrete variable names, which later calls use to decide which device each variable should be placed on. The first device keeps the original name, while subsequent devices append /replica_<replica id> to the original variable name, so that they can be distinguished from the original variable.
    • During placement, the variables are then assigned to the corresponding devices according to these names.
  • How is a single unified view presented to the outside?
    • Inside the context the user only sees the MirroredVariable; it hides the component variables and provides a unified view. For example, a read calls _get_cross_replica, which delegates to the Policy, and the Policy uses distribute_strategy to complete the read.
  • How are the variables kept consistent with each other?
    • As we saw in the scatter_update analysis, updating a variable goes through strategy.extended, where the copies are kept consistent by, for example, All-Reduce; we analyze this in detail in later articles.

To illustrate: suppose a MirroredVariable A internally consists of three tensors. Each worker believes it is updating MirroredVariable A, while in fact each updates a different component variable, and the components are kept consistent with each other by, for example, All-Reduce.

Figure 16 How updates happen
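
The usual synchronous-training pattern that keeps the copies identical looks roughly like the sketch below (assuming two GPUs and using a Keras SGD optimizer purely for illustration): the gradients are all-reduced, so every copy of the variable receives exactly the same update.

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  w = tf.Variable(2.0)
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(x):
  with tf.GradientTape() as tape:
    loss = w * x
  grads = tape.gradient(loss, [w])
  # apply_gradients runs in replica context; the strategy all-reduces the
  # gradients first, so each copy of w receives an identical update.
  opt.apply_gradients(zip(grads, [w]))

strategy.run(train_step, args=(tf.constant(1.0),))
print(w.values)  # the copies stay identical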

2. ShardedVariable

In machine-learning training, if a variable is too large to fit on a single device (for example a large embedding), it may need to be sharded across multiple devices. In TensorFlow, the concept corresponding to this idea is ShardedVariable.

Figure 17 ShardedVariable

Variable sharding means splitting one variable into several smaller variables, which are called shards. A ShardedVariable can be regarded as a container, and the "variables" inside the container should be treated as shards. The ShardedVariable class maintains a list of smaller variables that can be stored independently on different devices (for example, on several parameter servers), and it is responsible for saving and restoring these variables as if they were a single larger variable. Variable sharding is useful for spreading the network load of accessing these shards, and also for distributing the computation and storage of an ordinary variable across several parameter servers.

Figure 18 The ShardedVariable container

A ShardedVariable object can be saved with a given number of shards and then restored from the checkpoint with a different number of shards. The SavedModel can be consumed by programs such as the TF serving APIs, but tf.saved_model.load is not supported. Because a ShardedVariable can be saved and then restored to a different number of shards depending on the restore environment (for example, the TF serving APIs restore it to a single shard for serving efficiency), when using a ShardedVariable inside a tf.function one should generally not assume that it has the same number of shards at save time and at load time.
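
A sketch of the typical way a ShardedVariable comes into existence (not runnable as-is: cluster_resolver is assumed to describe an existing ps/worker cluster): under ParameterServerStrategy with a variable_partitioner, a large variable created in the scope becomes a ShardedVariable whose shards live on different parameter servers.

partitioner = tf.distribute.experimental.partitioners.MinSizePartitioner(
    min_shard_bytes=256 << 10, max_shards=2)
# cluster_resolver is assumed to already describe a ps/worker cluster.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver, variable_partitioner=partitioner)
with strategy.scope():
  # Created as a ShardedVariable; each shard is placed on a parameter server.
  embedding = tf.Variable(tf.random.uniform([100000, 64]))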

2.1 Questions

For ShardedVariable we again use a few questions to guide the analysis:

  • How are the parameters stored on the parameter servers?
  • How are the parameters stored in shards?
  • How is the computation (the ops that apply gradients to the parameters) placed on the parameter servers? (analyzed in later chapters)
  • Does the coordinator assign computation randomly? (analyzed in later chapters)

2.2 Definition

The definition of ShardedVariable itself does not contain much; the essence lives in the base class ShardedVariableMixin, which we analyze shortly.

Figure 19 ShardedVariable definition

The definition is as follows:

class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for  Variables  that should be treated as shards.
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      if operator == '__getitem__':
        continue

      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to  ops.Tensor ."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)

2.3 How partitioning works

One essential part of ShardedVariable is partitioning, so let us explore how it works. Note that ShardedVariable only supports partitioning along the first dimension.

2.3.1 Base class

The base class Partitioner does not contain much; derived classes need to implement __call__.

@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitiners inherit from this class.

  Partitioners should implement a  __call__  method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given  shape  and returns the partition results.
    # See docstring of  __call__  method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given  shape  and returns the partition results.

    Examples of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions) # [2, 1]
    ```

    Args:
      shape: a  tf.TensorShape , the shape to partition.
      dtype: a  tf.dtypes.Dtype  indicating the type of the partition value.
      axis: The axis to partition along.  Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where i-th value correponds to i-th axis.
    """
    raise NotImplementedError

2.3.2 Fixed-shards partitioner

FixedShardsPartitioner splits the variable into a fixed number of shards. The docstring contains a usage example: for that example, with axis = 0, min(self._num_shards, shape.dims[axis].value) = min(2, 10) = 2, so the variable is split into two shards.

@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

  """

  def __init__(self, num_shards):
    """Creates a new  FixedShardsPartitioner .

    Args:
      num_shards:  int , number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result
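
A quick standalone check of the logic above (no strategy involved): for a [10, 3] variable and two shards, the partitioner asks for two partitions along axis 0 and one along axis 1, i.e. two [5, 3] shards.

partitioner = tf.distribute.experimental.partitioners.FixedShardsPartitioner(
    num_shards=2)
print(partitioner(tf.TensorShape([10, 3]), tf.float32))  # [2, 1]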

2.3.3 Minimum-size partitioner

MinSizePartitioner allocates a minimum size per shard. This partitioner ensures that each shard has at least min_shard_bytes and tries to allocate as many shards as possible, i.e. it keeps the shard size as small as possible. The maximum number of such shards (an upper bound) is given by max_shards.

@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least  min_shard_bytes , and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
   max_shards .

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new  MinSizePartitioner .

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)

min_max_variable_partitioner implements the actual logic. It returns a partitioner that partitions a variable of the given shape and dtype such that each partition holds at least a min_slice_size slice of the variable. The maximum number of such partitions (an upper bound) is given by max_partitions.

@tf_export(v1=["min_max_variable_partitioner"])
def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Partitioner to allocate minimum size per slice.

  Returns a partitioner that partitions the variable of given shape and dtype
  such that each partition has a minimum of  min_slice_size  slice of the
  variable. The maximum number of such partitions (upper bound) is given by
   max_partitions .

  Args:
    max_partitions: Upper bound on the number of partitions. Defaults to 1.
    axis: Axis along which to partition the variable. Defaults to 0.
    min_slice_size: Minimum size of the variable slice per partition. Defaults
      to 256K.
    bytes_per_string_element: If the  Variable  is of type string, this provides
      an estimate of how large each scalar in the  Variable  is.

  Returns:
    A partition function usable as the  partitioner  argument to
     variable_scope  and  get_variable .

  """
  def _partitioner(shape, dtype):
    """Partitioner that partitions list for a variable of given shape and type.

    Ex: Consider partitioning a variable of type float32 with
      shape=[1024, 1024].
      If  max_partitions  >= 16, this function would return
        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
      If  max_partitions  < 16, this function would return
        [ max_partitions , 1].

    Args:
      shape: Shape of the variable.
      dtype: Type of the variable.

    Returns:
      List of partitions for each axis (currently only one axis can be
      partitioned).

    Raises:
      ValueError: If axis to partition along does not exist for the variable.
    """
    if axis >= len(shape):
      raise ValueError("Can not partition variable along axis %d when shape is "
                       "only %s" % (axis, shape))
    if dtype.base_dtype == dtypes.string:
      bytes_per_element = bytes_per_string_element
    else:
      bytes_per_element = dtype.size
    total_size_bytes = shape.num_elements() * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # We can not partition the variable beyond what its shape or
    #  max_partitions  allows.
    partitions_list[axis] = max(1, min(shape.dims[axis].value,
                                       max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
  return _partitioner
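
Working through the docstring example above as a quick check: a float32 variable of shape [1024, 1024] occupies 1024 * 1024 * 4 bytes, so with min_slice_size = 256K there is room for (1024 * 1024 * 4) / (256 * 1024) = 16 slices, and with max_partitions >= 16 the result is [16, 1].

partitioner = tf.compat.v1.min_max_variable_partitioner(
    max_partitions=16, axis=0, min_slice_size=256 << 10)
print(partitioner(tf.TensorShape([1024, 1024]), tf.float32))  # [16, 1]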

2.3.4 Maximum-size partitioner

MaxSizePartitioner keeps each shard below max_shard_bytes and tries to allocate as few shards as possible, i.e. it keeps the shards as large as possible. If the partitioner hits the max_shards limit, each shard may end up larger than max_shard_bytes. By default max_shards is None, meaning the number of shards is not limited.

@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below  max_shard_bytes .

  This partitioner ensures each shard has at most  max_shard_bytes , and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the  max_shards  limit, then each shard may end up
  larger than  max_shard_bytes . By default  max_shards  equals  None  and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new  MaxSizePartitioner .

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards in  int  created taking
        precedence over  max_shard_bytes .
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('max_shard_bytes must be positive, got: %r' %
                       max_shard_bytes)
    if max_shards and max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)

variable_axis_size_partitioner 是具體的實現。此分割槽器將沿一個軸切分變數,試圖將最大分片的大小保持在 max_shard_bytes 以下。如果分割槽器達到了 max_shards 限制,那麼每個分片最終可能會大於 max_shard_bytes。預設情況下 max_shards 等於 None,即不限制分片的數量。

max_shard_bytes 的一個合理值是 (64 << 20) - 1,即接近 64MB,這樣可以保證不超過 protobuf 的位元組數限制。

@tf_export(v1=["variable_axis_size_partitioner"])
def variable_axis_size_partitioner(
    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
  """Get a partitioner for VariableScope to keep shards below  max_shard_bytes .

  This partitioner will shard a Variable along one axis, attempting to keep
  the maximum shard size below  max_shard_bytes .  In practice, this is not
  always possible when sharding along only one axis.  When this happens,
  this axis is sharded as much as possible (i.e., every dimension becomes
  a separate shard).

  If the partitioner hits the  max_shards  limit, then each shard may end up
  larger than  max_shard_bytes . By default  max_shards  equals  None  and no
  limit on the number of shards is enforced.

  One reasonable value for  max_shard_bytes  is  (64 << 20) - 1 , or almost
   64MB , to keep below the protobuf byte limit.

  Args:
    max_shard_bytes: The maximum size any given shard is allowed to be.
    axis: The axis to partition along.  Default: outermost axis.
    bytes_per_string_element: If the  Variable  is of type string, this provides
      an estimate of how large each scalar in the  Variable  is.
    max_shards: The maximum number of shards in int created taking precedence
      over  max_shard_bytes .

  Returns:
    A partition function usable as the  partitioner  argument to
     variable_scope  and  get_variable .

  Raises:
    ValueError: If any of the byte counts are non-positive.
  """

  def _partitioner(shape, dtype):
    """Partitioner that partitions shards to have max_shard_bytes total size.

    Args:
      shape: A  TensorShape .
      dtype: A  DType .

    Returns:
      A tuple representing how much to slice each axis in shape.

    Raises:
      ValueError: If shape is not a fully defined  TensorShape  or dtype is not
        a  DType .
    """
    if dtype.base_dtype == dtypes.string:
      element_size = bytes_per_string_element
    else:
      element_size = dtype.size

    partitions = [1] * shape.ndims
    bytes_per_slice = 1.0 * (
        shape.num_elements() / shape.dims[axis].value) * element_size
    # How many slices can we fit on one shard of size at most max_shard_bytes?
    # At least one slice is required.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards do we need for axis given that each shard fits
    # slices_per_shard slices from a total of shape[axis] slices?
    axis_shards = int(math.ceil(
        1.0 * shape.dims[axis].value / slices_per_shard))
    if max_shards:
      axis_shards = min(max_shards, axis_shards)

    partitions[axis] = axis_shards

    return partitions

  return _partitioner
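
按照上面的公式手工推算一個例子(示意性質,假設 max_shard_bytes 取 1MB):對於 shape=[1024, 1024] 的 float32 變數,每個 slice(即每一行)佔 1024 * 4 = 4096 位元組,每個分片最多容納 1048576 // 4096 = 256 個 slice,因此第 0 維會被切成 ceil(1024 / 256) = 4 個分片,即 [4, 1]。

import tensorflow as tf

partitioner = tf.compat.v1.variable_axis_size_partitioner(max_shard_bytes=1 << 20)
print(partitioner(tf.TensorShape([1024, 1024]), tf.float32))  # 預期輸出 [4, 1]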

2.4 ShardedVariableMixin

前面提到了,ShardedVariableMixin 是核心所在,我們接下來就分析一下。ShardedVariableMixin 主要成員變數是:

  • _variables : 分割槽的變數。

  • _var_offsets : 分割槽變數在 ShardedVariableMixin 對應的偏移,就是把 _variables 看成是一個整體,然後用 offset 在其中查詢對應的資料。

  • _shape : ShardedVariableMixin 的 shape。

  • _name : ShardedVariableMixin 的名字。

class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  def __init__(self,
               variables: Sequence[variables_lib.Variable],
               name='ShardedVariable'):
    """Treats  variables  as shards of a larger Variable.

    Args:
      variables: A list of  ResourceVariable s that comprise this sharded
        variable. Variables should not be shared between different
         ShardedVariableMixin  objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    var_dtypes = {v.dtype for v in variables}
    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    # 計算整體形狀
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])
    
    # 計算每個分割槽在整體之中的偏移
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name)

2.4.1 使用

我們用如下示例看看如何使用。

variables = [
  tf.Variable(np.array([[3, 2]]), shape=(1, 2), dtype=tf.float32,),
  tf.Variable(np.array([[3, 2], [0, 1]]),  shape=(2, 2), dtype=tf.float32),
  tf.Variable(np.array([[3, 2]]),  shape=(1, 2), dtype=tf.float32)
]
sharded_variable = ShardedVariableMixin(variables)

sharded_variable 內部成員變數列印如下。可以看到,_var_offsets 就是把所有引數分割槽看作一個整體之後,各分割槽在整體之中的起始偏移,從而可以據此找到對應的分割槽。

_shape = {TensorShape: 2} (4, 2)
_var_offsets = {list: 3} [[0, 0], [1, 0], [3, 0]]
first_dim = {int} 4

比如上面例子之中,三個變數整體打包之後就是如下所示,使用者可以使用 offset 在這裡查詢資料。

[[3,2],[3,2],[0,1],[3,2]]

我們再用另一個圖例看看。假設引數有4個分割槽,則具體如下:

圖 20 分割槽

如果變數都放在引數伺服器上,則具體如下。

圖 21 分割槽與引數伺服器

2.4.2 獲取分割槽

我們接下來看看如何獲取分割槽。就是從 sharded variable 之中把指定部分作為一個張量取出。具體邏輯是:分析傳入的 spec, 根據 spec 的內容對 sharded variable 進行處理,獲得一個引數分割槽。

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to  Tensor.__getitem__ . Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing of
        the sharded variable.

    Returns:
      The appropriate slice of tensor based on  slice_spec .

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If  spec_spec  contains Tensor.
    """

    # 拿到分割槽 spec
    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      # 如果是 slice 型別,則解析分割槽
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
        
      # 在引數分割槽之中遍歷,用offset來提取資料
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (s > self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]
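
沿用 2.4.1 中構建的 sharded_variable(整體 shape 為 (4, 2))做一個切片示意(假設在 eager 模式下執行):切片 [1:3] 會被分解到第二個分割槽之上,取出其中的兩行。

print(sharded_variable[1:3])
# 預期輸出(近似):
# tf.Tensor(
# [[3. 2.]
#  [0. 1.]], shape=(2, 2), dtype=float32)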

Spec 一般來說是什麼樣式?下面的示例講述得比較清晰。

    For example, given component variables:
      v0 = [0, 1, 2]
      v1 = [3, 4, 5]
      v2 = [6, 7, 8, 9]

    If  slice_spec  is slice(start=None, stop=None, step=None), we will have:
      v0[returned[0]] = [0, 1, 2]
      v1[returned[1]] = [3, 4, 5]
      v2[returned[2]] = [6, 7, 8, 9]
    If  slice_spec  is slice(start=2, stop=8, step=3), we will have:
      v0[returned[0]] = [2]
      v1[returned[1]] = [5]
      returned[2] == None
    If  slice_spec  is slice(start=9, stop=3, step=-2), we will have:
      returned[0] == None
      v1[returned[1]] = [5]
      v2[returned[2]] = [9, 7]

獲取/解析 spec 的程式碼具體如下:

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

     ShardedVariable  only supports first dimension partitioning, thus
     slice_spec  must be for first dimension.

    Args:
      slice_spec: A python  slice  object that specifies the global slicing.

    Returns:
      A list of python  slice  objects or None specifying the local slicing for
      each component variable. None means no slicing.

    """
    result = []
    # Normalize start, end and stop.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor ( cur ) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by using whatever is smaller between global slice end and variable range
    # end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result

2.4.3 Embedding

接下來我們看看嵌入(embedding)的查詢。可以發現,這裡只是在呼叫時補充了對應的 partition_strategy、name、validate_indices、max_norm 等引數,然後轉發給 embedding_ops.embedding_lookup,這裡的分割槽策略是 'mod'。

# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)

流程來到 embedding_lookup(tensorflow/python/ops/embedding_ops.py),我們需要繼續看 _embedding_lookup_and_transform。

@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given  ids  from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
   params .  It is a generalization of  tf.gather , where  params  is
  interpreted as a partitioning of a large embedding tensor.   params  may be
  a  PartitionedVariable  as returned by using  tf.compat.v1.get_variable() 
  with a partitioner.

  If  len(params) > 1 , each element  id  of  ids  is partitioned between
  the elements of  params  according to the  partition_strategy .
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first  (max_id + 1) % len(params)  partitions will
  be assigned one more id.

  If the input ids are ragged tensors, partition variables are not supported and
  the partition strategy and the max_norm are ignored.
  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape  shape(ids) + shape(params)[1:] .

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
       PartitionedVariable , created by partitioning along dimension 0. Each
      element must be appropriately sized for the given  partition_strategy .
    ids: A  Tensor  or a 'RaggedTensor' with type  int32  or  int64  containing
      the ids to be looked up in  params .
    partition_strategy: A string specifying the partitioning strategy, relevant
      if  len(params) > 1 . Currently  "div"  and  "mod"  are supported. Default
      is  "mod" .
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in  indices  are always validated to be within range.  If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not  None , each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A  Tensor  or a 'RaggedTensor', depending on the input, with the same type
    as the tensors in  params .

  Raises:
    ValueError: If  params  is empty.
  """
  if isinstance(ids, ragged_tensor.RaggedTensor):
    return embedding_lookup_ragged(params, ids,
                                   partition_strategy=partition_strategy,
                                   max_norm=max_norm,
                                   name=name)

  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)

_embedding_lookup_and_transform 是具體執行分割槽查詢的程式碼。我們先用例項演示一下兩種分割槽策略,並在列表之後給出一段驗證用的小示例。

  • 如果 partition_strategy 是 "mod",我們將每個 id 分配給分割槽 p = id % len(params)。例如,
    13 個 id 被分到 5 個分割槽中,結果如下:[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]
  • 如果 partition_strategy 是 "div",我們會以連續的方式將 id 分配給各分割槽。同樣是 13 個 id 分到 5 個分割槽,結果如下:[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
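
下面用一段純 Python 小程式驗證上述兩種分配結果(示意性質,不依賴 TensorFlow;div 分支直接複用了原始碼中 ids_per_partition 與 extras 的計算方式):

ids = list(range(13))
num_partitions = 5

# "mod" 策略:p = id % num_partitions
mod_buckets = [[i for i in ids if i % num_partitions == p]
               for p in range(num_partitions)]
assert mod_buckets == [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]

# "div" 策略:前 extras 個分割槽各多分到一個 id
ids_per_partition, extras = divmod(len(ids), num_partitions)  # 2, 3

def div_assign(i):
  return max(i // (ids_per_partition + 1), (i - extras) // ids_per_partition)

div_buckets = [[i for i in ids if div_assign(i) == p]
               for p in range(num_partitions)]
assert div_buckets == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]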

具體程式碼如下:

def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the  transform_fn  argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single  Tensor 
  argument of the same type as the  params  tensor and should return a
   Tensor . The shape of the argument will be the same as  params  except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding. If
      max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If  params  is empty.
  """

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    # 省略程式碼
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

  # 省略其他程式碼

如何使用 embedding?我們從註釋之中提取使用方法如下,這裡構建了一個 ShardedVariable,模型通過 embedding_lookup 來對此變數進行操作。

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)

如果用圖例表示,則如下圖所示:worker 會在兩個引數伺服器上並行操作來提取 embedding。

圖 22 處理 embedding

2.5 構建

關於 ShardedVariable 的構建,我們直接看 ParameterServerStrategyV2 之中的構建過程。

2.5.1 變數分片

要啟用變數分片,你可以在構建 ParameterServerStrategy 物件時傳入一個 variable_partitioner。每次建立變數時,variable_partitioner 都會被呼叫,並且需要沿變數的每個維度返回分片數量。系統提供了一些開箱即用的 variable_partitioner,比如 tf.distribute.experimental.partitioners.MinSizePartitioner。建議使用基於大小(size-based)的分割槽器,如 tf.distribute.experimental.partitioners.MinSizePartitioner,以避免對小變數進行分割槽,因為那樣可能對模型訓練速度產生負面影響。

當傳入 variable_partitioner 的時候,如果你直接在 strategy.scope() 下建立一個變數,它將成為一個具有 variables 屬性(property)的容器型別,此屬性提供了對分片列表的訪問。在大多數情況下,這個容器會通過連線(concatenating)所有分片自動轉換為一個張量,因此它可以當作一個普通變數使用。另一方面,一些 TensorFlow 方法(如 tf.nn.embedding_lookup)為這種容器型別提供了高效的實現,可以避免自動連線。下面用一個小示例演示這種容器的拼接行為。
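
這個小示例不依賴 ParameterServerStrategy,直接從內部模組匯入 ShardedVariable 來示意容器型別的拼接行為(內部路徑僅作演示之用):

import tensorflow as tf
from tensorflow.python.distribute.sharded_variable import ShardedVariable

sv = ShardedVariable([tf.Variable([[1.0], [2.0]]), tf.Variable([[3.0]])])
print(len(sv.variables))         # 2,即兩個分片
print(tf.convert_to_tensor(sv))  # 自動拼接成 shape=(3, 1) 的張量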

2.5.2 初始化

在 ParameterServerStrategyV2Extended 初始化時候,會把傳入的 variable_partitioner 設定到 _variable_partitioner 之中,也會配置引數伺服器數目和 worker 數目。

class ParameterServerStrategyV2Extended(
    parameter_server_strategy.ParameterServerStrategyExtended):
  """Extended class for ParameterServerStrategyV2.

  Please see  tf.distribute.StrategyExtended  doc for more information.
  """

  def __init__(self, container_strategy, cluster_resolver,
               variable_partitioner):
    """Initialization of ParameterServerStrategyV2Extended."""
    super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
    self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
    self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get(
        "worker", []))
    self._variable_count = 0

    self._variable_partitioner = variable_partitioner
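
用一個假設的叢集配置示意上面兩個數目是如何得到的(cluster_spec().as_dict() 返回的字典形如下例):

cluster_dict = {"ps": ["ps0:2222", "ps1:2222"],
                "worker": ["w0:2222", "w1:2222", "w2:2222"]}
num_ps = len(cluster_dict.get("ps", []))           # 對應 _num_ps == 2
num_workers = len(cluster_dict.get("worker", []))  # 對應 _num_workers == 3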

2.5.3 構建

我們接下來看看建立過程,也就是如何把變數分片到不同引數伺服器上。具體思路是:

  • 沒有配置分割槽生成器的話,就用 RR 策略(_create_variable_round_robin)把變數分配到引數伺服器之上。
  • 如果配置了分割槽生成器,則做如下操作:
    • 對 rank-0 不做分割槽。
    • 通過 _variable_partitioner 得到分割槽數目。
    • 分割槽數目不能超過第一維的大小,超過時就取第一維的大小。
    • 計算張量 offset。
    • 生成很多小張量。
    • 使用 _create_variable_round_robin 構建小張量列表。
    • 用小張量列表來生成 ShardedVariable。

具體程式碼如下:

  def _create_variable(self, next_creator, **kwargs):
    """Implements StrategyExtendedV2._create_variable.

    Creates a  Variable  or a  ShardedVariable . A  ShardedVariable  will be
    created if satisfying all the following criteria:
      1.  self._variable_partitioner  results in more than one partition on the
         first axis.
      2. variable's rank is greater than 0.
      3. variable is not colocated with another variable.
    Otherwise a  Variable  will be created.

    Args:
      next_creator: See  variable_scope.variable_creator_scope ; the next
        creator in the chain.
      **kwargs: Passed through to the next creator.

    Returns:
      A  Variable  or  ShardedVariable .
    """

    var_creator = self._create_var_creator(next_creator, **kwargs)
    if "colocate_with" in kwargs:  # Never partition colocated_with variables.
      colocate_with = kwargs["colocate_with"]
      # Clear the variable scope to avoid possible conflicts between device
      # scope and colocation scope.
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          var = var_creator(**kwargs)
          return var

    # 沒有配置分割槽生成器的話,就用 RR 策略把變數分配到引數伺服器之上
    if self._variable_partitioner is None:
      return self._create_variable_round_robin(var_creator, **kwargs)

  # 下面是配置了分割槽生成器
    name = kwargs.get("name", None)
    initial_value = kwargs.get("initial_value", None)

    # Two cases where initial_value can be a callable:
    #   1. initial_value is passed as a callable, e.g, an  initializer  class.
    #   2. restoring from checkpoint, initial_value is a
    #     "CheckpointInitialValueCallable".
    init_from_fn = callable(initial_value)

    dtype = kwargs.get("dtype", None)
    shape = kwargs.get("shape", None)
    if init_from_fn and (shape is None or dtype is None):
      init_from_fn = False
      initial_value = initial_value()
    if not init_from_fn:
      # The initial_value is created on coordinator, it will need to be sent to
      # ps for variable initialization, which can be inefficient and can
      # potentially hit the 2GB limit on protobuf serialization.
      initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
      dtype = initial_value.dtype
      shape = initial_value.shape
    else:
      shape = tensor_shape.as_shape(shape)

    # rank-0 不做分割槽
    if shape.rank == 0:  # Skip partitioning rank-0 variable.
      return self._create_variable_round_robin(var_creator, **kwargs)

    # 得到分割槽數目
    num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
    if num_partitions[0] == 1:  # no partition
      return self._create_variable_round_robin(var_creator, **kwargs)

    # 分割槽數目不能超過第一維的大小,超過時就取第一維的大小
    # Use "div" partition strategy to partition the variable.
    num_partitions = min(num_partitions[0], shape[0])
    base = shape[0] // num_partitions
    
    # 計算 offset
    extra = shape[0] % num_partitions
    # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
    # offsets: [0, 3, 6, 8, 10]
    offsets = []
    for i in range(num_partitions):
      if i == 0:
        offsets.append(0)
      else:
        prev_shard_size = base + (1 if i - 1 < extra else 0)
        offsets.append(offsets[i - 1] + prev_shard_size)
    offsets.append(shape[0])

    def init_shard_fn(shard_index):
      if not init_from_fn:
        return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
    
      partition_shape = (offsets[shard_index + 1] -
                         offsets[shard_index],) + shape[1:]
      partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
      arg_spec = tf_inspect.getfullargspec(initial_value)
      if ("shard_info" not in arg_spec.args and
          "shard_info" not in arg_spec.kwonlyargs):
        try:
          value = initial_value(
              partition_shape=partition_shape,
              partition_offset=partition_offset)
        except (TypeError, ValueError):
          # TypeError: Initializer doesn't accept kwargs
          # ValueError: Initializer doesn't accept partition kwargs
          # In both cases we go ahead creating the full value and then slice.
          value = initial_value()

        if value.shape == partition_shape:
          # Initializer supports partition: value is the partition value.
          return value
        else:
          # Initializer doesn't support partition: value is the full value
          # and needs to be sliced to get the partition value.
          return value[offsets[shard_index]:offsets[shard_index + 1]]
      else:
        # For compatibility with  CheckpointInitialValueCallable .
        return initial_value(
            shard_info=trackable.ShardInfo(
                shape=tensor_shape.as_shape(partition_shape),
                offset=partition_offset))

    # 生成很多小張量
    var_list = []
    for i in range(num_partitions):
      kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
      kwargs["initial_value"] = lambda: init_shard_fn(i) # 初始化
      if name is not None:
        kwargs["name"] = "{}/part_{}".format(name, i)
      # 使用 _create_variable_round_robin 得到張量如何分配  
      var_list.append(self._create_variable_round_robin(var_creator, **kwargs))

    #用小張量列表來生成 ShardedVariable
    result = sharded_variable.ShardedVariable(var_list)
    return result

上面邏輯之中,兩個分支都使用了 _create_variable_round_robin,其使用 RR(round-robin)策略決定變數的具體放置(placement):其實就是給變數指定對應的裝置名字,後續做佈局操作的時候,就按照裝置名字進行操作。

  def _create_variable_round_robin(self, next_creator, **kwargs):
    # Clear the colocation scope to avoid possible conflicts between device
    # scope and colocation scope.
    with ops.colocate_with(None, ignore_existing=True):
      # Explicitly set CPU:0 device for PS in case create variable is called
      # inside replica_fn and worker has with GPU:0 scope.
      with ops.device("/job:ps/task:%d/device:CPU:0" %
                      (self._variable_count % self._num_ps)):
        var = next_creator(**kwargs)
        logging.debug(
            "Creating variable (name:%s, shape:%r) on "
            "/job:ps/task:%d/device:CPU:0",
            var.name, var.shape, (self._variable_count % self._num_ps))
        self._variable_count += 1
        return var
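
用一小段程式碼示意這種 RR 放置的效果(假設 _num_ps 為 3,依次建立 5 個變數):

num_ps = 3
for variable_count in range(5):
  print("/job:ps/task:%d/device:CPU:0" % (variable_count % num_ps))
# 依次輸出 task:0、task:1、task:2、task:0、task:1,即變數在各個引數伺服器之間輪流放置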

_create_variable_round_robin 的引數 next_creator 一般來說是如下方法,這裡使用了 AggregatingVariable 和 CachingVariable 來構建變數列表 var_list,然後才是利用 var_list 構建 ShardedVariable。我們主要介紹 AggregatingVariable。

  def _create_var_creator(self, next_creator, **kwargs):
    aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

    def var_creator(**kwargs):
      """Create an AggregatingVariable."""
      # Create and wrap the variable.
      v = next_creator(**kwargs)
      wrapped_v = ps_values.CachingVariable(v)
      wrapped = ps_values.AggregatingVariable(self._container_strategy(),
                                              wrapped_v, aggregation)
      return wrapped

    if self._num_replicas_in_sync > 1:
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: " + aggregation +
                         " for variable: " + kwargs["name"])
      return var_creator
    else:
      def variable_creator_single_replica(**kwargs):
        v = next_creator(**kwargs)
        return ps_values.CachingVariable(v)
      return variable_creator_single_replica

2.5.4 AggregatingVariable

AggregatingVariable 的作用是對變數進行包裝,使得對該變數的更新可以跨副本聚合。以 _assign_func 為例,可以看到,其使用 _distribute_strategy.extended.update 對變數進行操作。

# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy.
class AggregatingVariable(resource_variable_ops.BaseResourceVariable,
                          core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the  AggregatingVariable .

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the  AggregatingVariable .  To avoid confusion
    with the behavior of deepcopy on a regular  Variable  (which does
    copy into new devices), we only allow a deepcopy of a  AggregatingVariable 
    within its originating strategy scope.

    Args:
      memo: The memoization object for  deepcopy .

    Returns:
      A deep copy of the current  AggregatingVariable .

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      v = copy.deepcopy(self._v, memo)

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        v=v,
        aggregation=self._aggregation)

    memo[id(self)] = copied_variable

    return copied_variable

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

2.6 使用

下面示例展示了 ShardedVariable 如何使用。在 Dense 之中構建了一個 ShardedVariable,就是 self.w,其 shape 是 [100, 10],分割槽之後的結果是兩個 (50, 10) 的張量。

  class Dense(tf.Module):
    def __init__(self, name=None):
      super().__init__(name=name)
      self.w = tf.Variable(tf.random.normal([100, 10]), name='w')

    def __call__(self, x):
      return x * self.w

  # Partition the dense layer into 2 shards.
  variable_partitioner = (
    tf.distribute.experimental.partitioners.FixedShardsPartitioner(
      num_shards = 2))
  strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner = variable_partitioner)
  with strategy.scope():
    dense = Dense() # 位於 strategy 上下文之中,於是生成的變數被自動分成 2 個分割槽。
    
  assert len(dense.variables) == 2
  assert isinstance(dense.variables[0], tf.Variable)
  assert isinstance(dense.variables[1], tf.Variable)
  assert dense.variables[0].shape == (50, 10)
  assert dense.variables[1].shape == (50, 10)

ShardedVariable 也可以看作一種形式上的模型並行:比如把一個大矩陣切分成 A、B 兩個分片,分別放到兩個引數伺服器之上,各自與 C 相乘,最後在 worker 上把相乘結果聚合起來,concatenate 成一個最終的結果張量。

圖 23 合併張量
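
上述 "先分片計算、再合併" 的等價性可以用一段純 numpy 小例子來驗證(示意性質:沿第一維切分、做逐元素相乘,與 2.6 的 Dense 示例保持一致):

import numpy as np

w = np.random.rand(4, 2).astype(np.float32)   # 完整變數
A, B = w[:2], w[2:]                           # 兩個分片,想象分別放在兩臺引數伺服器上
x = np.random.rand(1, 2).astype(np.float32)   # 輸入

full = x * w                                  # 用完整變數計算
sharded = np.concatenate([x * A, x * B], 0)   # 各分片分別計算後在 worker 上 concatenate
np.testing.assert_allclose(full, sharded)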

