[原始碼解析] TensorFlow 分散式之 ParameterServerStrategy V2

羅西的思考 發表於 2022-05-14

[原始碼解析] TensorFlow 分散式之 ParameterServerStrategy V2

對於 ParameterServerStrategy V2,我們將從幾個方面來研究:如何與叢集建立連線,如何生成變數,如何獲取資料,如何執行。其中,變數和作用域我們在前文已經研究過,執行在 MirroredStrategy 裡面也介紹,所以本文主要看看如何使用,如何初始化。在下一篇之中會重點看看如何分發計算。

安利兩個github,都是非常好的學習資料,推薦。

https://github.com/yuhuiaws/ML-study

https://github.com/Jack47/hack-SysML

另外推薦西門宇少的最新大作讓Pipeline在Transformer LM上沿著Token level並行起來——TeraPipe

本系列其他文章是:

[翻譯] TensorFlow 分散式之論文篇 "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[翻譯] TensorFlow 分散式之論文篇 "Implementation of Control Flow in TensorFlow"

[原始碼解析] TensorFlow 分散式環境(1) --- 總體架構

[原始碼解析] TensorFlow 分散式環境(2)---Master 靜態邏輯

[原始碼解析] TensorFlow 分散式環境(3)--- Worker 靜態邏輯

[原始碼解析] TensorFlow 分散式環境(4) --- WorkerCache

[原始碼解析] TensorFlow 分散式環境(5) --- Session

[原始碼解析] TensorFlow 分散式環境(7) --- Worker 動態邏輯

[原始碼解析] TensorFlow 分散式環境(8) --- 通訊機制

[翻譯] 使用 TensorFlow 進行分散式訓練

[原始碼解析] TensorFlow 分散式 DistributedStrategy 之基礎篇

[原始碼解析] TensorFlow 之 分散式變數

[原始碼解析] TensorFlow 分散式之 MirroredStrategy

[原始碼解析] TensorFlow 分散式之 MirroredStrategy 分發計算

[原始碼解析] TensorFlow 分散式之 ParameterServerStrategy V1

1. 如何使用

在 TensorFlow 2 中,引數伺服器訓練由 tf.distribution.experimental.ParameterServerStrategy 類提供支援,該類將訓練步驟分佈到一個可擴充套件到數千個工作者(伴隨著引數伺服器)的叢集。

1.1 訓練方法

支援訓練有兩種主要方法:

  • Keras Model.fit API。如果使用者喜歡用高層次抽象來訓練,則建議使用這種方式。
  • 自定義訓練迴圈(custom training loop)。如果使用者需要自己實現或者定義訓練細節,則可以考慮這種方式。

1.2 叢集

無論選擇何種API( Model.fit 或自定義訓練迴圈),TensorFlow 2中的分散式訓練都會涉及如下概念:一個"叢集" 有若干個"作業(job)",每個作業可能包括一個或多個"任務"。而當使用引數伺服器訓練時,建議使用如下配置:

  • 一個協調者(coordinator ) job(job名稱為 chief)。
  • 多個工作者 jobs(job名稱為 worker)。
  • 多個引數伺服器 jobs(job名稱為 ps)。

協調者負責建立資源、分配訓練任務、寫檢查點和處理任務失敗,工作者引數伺服器則執行 tf.distribution.Server 來聽取協調者的請求。

1.3 使用 Model.fit API 進行訓練

如果使用 "Model.fit" API,則引數伺服器訓練需要協調者使用 tf.distribution.experimental.ParameterServerStrategy 物件和 tf.keras.utils.experimental.DatasetCreator 作為輸入。與其他策略類似,其工作流程包括:建立和編譯模型,準備回撥,呼叫 Model.fit。

1.4 使用自定義迴圈進行訓練

TensorFlow 2 推薦使用一種基於中央協調的架構來進行引數伺服器訓練。每個工作者和引數伺服器都執行一個 tf.distribution.Server,在此基礎上,一個協調者任務負責在工作者和引數伺服器上建立資源,排程功能,並協調訓練。協調器使用 tf.distribution.experimental.coordinator.ClusterCoordinator 來協調叢集,使用 tf.distribution.experimental.ParameterServerStrategy 來定義引數伺服器上的變數和工作者的計算。在自定義訓練迴圈中, tf.distribution.experimental.coordinator.ClusterCoordinator 類是用於協調器的關鍵元件。

  • ClusterCoordinator 類需要與 tf.distribution.Strategy 物件一起工作。
  • 對於引數伺服器訓練, ClusterCoordinator 需要與 tf.distribution.experimental.ParameterServerStrategy 一起工作。
  • 這個 tf.distribution.Strategy 物件需要使用者提供叢集的資訊,並使用這些資訊來定義訓練步驟。然後, ClusterCoordinator 物件將這些訓練步驟的執行分派給遠端工作者。

ClusterCoordinator 提供的最重要的 API 是 schedule 。

  • Schedule API 把一個 tf.function 插入佇列,並立即返回一個類似 future 的 RemoteValue 。
  • 在佇列之中排隊的函式被派發給後臺執行緒中的遠端工作者,他們的 RemoteValue 將被非同步賦值。
  • 由於 schedule 不需要執行分配任務,因此傳遞進來的 tf.function 可以在任何可用的工作者上執行。
  • 如果被執行的工作者在結束之前變得不可用,該 tf.function 將在另一個可用的工作者上重試。
  • 由於函式的執行不是原子性的,所以一個函式可能被執行多次。

除了排程遠端函式這個功能之外,ClusterCoordinator 還幫助在所有工作者上建立資料集,以及當一個工作者從失敗中恢復時重建這些資料集。

1.5 建立叢集

如上所述,一個引數伺服器訓練叢集需要一個協調者任務來執行你的訓練程式,程式包括一個或幾個執行TensorFlow 伺服器( tf.distribution.Server )的工作者和引數伺服器,可能還有一個執行 side-car 評估的評估任務。設定它們的要求是。

  • 協調者(coordinator)任務需要知道所有其他 TensorFlow 伺服器(評估者除外)的地址和埠。
  • 工作者和引數伺服器需要知道他們應該監聽哪個埠。為了簡單起見,使用者通常可以在這些任務上建立 TensorFlow 伺服器時傳入完整的叢集資訊。
  • 評估器(evaluator)任務不需要知道訓練叢集的設定,它也不應該試圖連線到訓練叢集。
  • 工作者和引數伺服器的任務型別應該分為 "worker" 和 "ps" 兩種。出於歷史原因,協調器應使用 "chief" 作為任務型別。

2. 初始化

2.1 用例

以下是如何初始化 ParameterServerStrategy 的樣例,無論是使用 Model.fit 還是自定義迴圈,都需要這步工作。為了使用 GPU 進行訓練,需要為每個工作者分配可見的 GPU。 ParameterServerStrategy 將使用每個工作者上所有可用的 GPU,但有個限制是:所有工作者都應該有相同數量的 GPU 可用。

variable_partitioner = (
    tf.distribute.experimental.partitioners.MinSizePartitioner(
        min_shard_bytes=(256 << 10),
        max_shards=NUM_PS))

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver,
    variable_partitioner=variable_partitioner)

對於 variable_partitioner,這是一個 distribute.experimental.partitioners.Partitioner,其指定如何對變數進行分割槽。如果是 None,變數將不被分割,其特點如下:

  • 此引數取值是 tf.distribute.experimental.partitioners 中預定義的分割槽器。一個常用的分割槽器是 MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps),它為每個分片分配至少 256K,每個 ps 最多得到一個分片。
  • 在策略 scope 下建立的每個變數都會呼叫 variable_partitioner,以指示該變數應如何分割槽。沿著分割槽軸只有一個分割槽的變數(即不需要分割槽)將被建立為一個普通的 tf.Variable 。
  • 只支援第一個/最外層軸的分割槽。
  • Div 分割槽策略被用來對變數進行分割槽。假設我們沿著變數的第一軸分配連續的整數 id,那麼 id 會以連續的方式分配給分片,同時試圖保持每個分片的大小相同。如果 id 不能平均分配給分片的數量,那麼前幾個分片中的每一個將被多分配一個 id。例如,一個變數的第一個維度是 13,它有 13 個 id,它們被分成 5 個分片。 [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]] .
  • 在 strategy.extended.colocate_vars_with 下建立的變數將不會被分割。

2.2 叢集設定

在真實的生產環境中,使用者需要在不同機器上的所有不同程式中執行訓練任務。在每個任務上配置叢集資訊的最簡單方法是設定"TF_CONFIG" 環境變數,並使用 tf.distribution.cluster_resolver.TFConfigClusterResolver 來解析"TF_CONFIG" 。如果使用者使用 Kubernetes 或其他配置模板開始訓練任務,很可能這些模板已經設定了"TF_CONFIG"

2.2.1 設定 "TF_CONFIG" 環境變數

假定你有 3 個工作者,3 個引數伺服器,那麼 worker 1 的 "TF_CONFIG" 可以如下:

os.environ["TF_CONFIG"] = json.dumps({
   "cluster": {
       "worker": ["host1:port","host2:port","host3:port"],
       "ps": ["host4:port","host5:port"],
       "chief": ["host6:port"]
    },
   "task": {"type":"worker","index": 1}
})

2.2.2 使用二進位制檔案

如果你喜歡用一個二進位制檔案來執行所有這些任務,你將需要在程式開始就指明不同分支負責處理不同的角色。

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
if cluster_resolver.task_type in ("worker","ps"):
  # Start a TensorFlow server and wait.
elif cluster_resolver.task_type =="evaluator":
  # Run side-car evaluation
else:
  # Run the coordinator.

如下程式碼啟動一個 TensorFlow server 然後等待完成。

# Set the environment variable to allow reporting worker and ps failure to the
# coordinator. This is a workaround and won't be necessary in the future.
os.environ["GRPC_FAIL_FAST"] ="use_caller"

server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol=cluster_resolver.rpc_layer or"grpc",
    start=True)
server.join()

2.3 初始化方法

初始化方法如下,主要工作是連線到叢集,然後呼叫 _extended 進行繼續初始化。

  def __init__(self, cluster_resolver, variable_partitioner=None):
   """Initializes the TF2 parameter server strategy.

    This initializes the  tf.distribute.experimental.ParameterServerStrategy 
    object to be ready for use with
     tf.distribute.experimental.coordinator.ClusterCoordinator .
   """
    # pyformat: enable
    self._cluster_resolver = cluster_resolver

    self._verify_args_and_config(cluster_resolver)
    self._cluster_coordinator = None

    self._connect_to_cluster(coordinator_name="chief") # 連線到叢集
    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
                                                       variable_partitioner)
    super(ParameterServerStrategyV2, self).__init__(self._extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
       "ParameterServerStrategy")
    self._should_use_with_coordinator = True
    # Used while constructing distributed iterators.
    self._canonicalize_devices = False

2.4 連線到叢集

_connect_to_cluster 起到了連線到叢集的作用,其主要邏輯是設定了 filter,然後呼叫 remote.connect_to_cluster 去連線叢集。

  def _connect_to_cluster(self, coordinator_name):
    if coordinator_name in ["worker","ps"]:
      raise ValueError("coordinator name should not be 'worker' or 'ps'.")
    cluster_spec = self._cluster_resolver.cluster_spec()
    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))

    device_filters = server_lib.ClusterDeviceFilters()
    # For any worker, only the devices on ps and coordinator nodes are visible
    for i in range(self._num_workers):
      device_filters.set_device_filters(
         "worker", i, ["/job:ps","/job:%s" % coordinator_name])
    # Similarly for any ps, only the devices on workers and coordinator are
    # visible
    for i in range(self._num_ps):
      device_filters.set_device_filters(
         "ps", i, ["/job:worker","/job:%s" % coordinator_name])

    # Allow at most one outstanding RPC for each worker at a certain time. This
    # is to simplify worker failure handling in the runtime
    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] ="False"

    remote.connect_to_cluster(
        cluster_spec,
        job_name=coordinator_name,
        protocol=self._cluster_resolver.rpc_layer,
        cluster_device_filters=device_filters)

    distribute_lib.distribution_strategy_replica_gauge.get_cell(
       "ps_strategy_num_workers").set(self._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
       "ps_strategy_num_ps").set(self._num_ps)

connect_to_cluster 方法會連線到給定的叢集,使叢集上的裝置可用。如果給定的本地 job 名稱沒有出現在叢集規範中,它將被自動新增,並且使用本地主機上一個未使用的埠。

工作者如果在被過濾的遠端裝置上訪問資源或啟動程式/功能,將導致一個未知裝置錯誤。對於任何遠端任務,如果沒有裝置過濾器,所有的叢集裝置都是可見的;如果指定了裝置過濾器,任務則只能看到與至少一個過濾器匹配的裝置。任務本身的裝置始終是可見的。

以下是使用樣例。

cdf = tf.config.experimental.ClusterDeviceFilters()
# For any worker, only the devices on PS nodes and itself are visible
for i in range(num_workers):
  cdf.set_device_filters('worker', i, ['/job:ps'])
# Similarly for any ps, only the devices on workers and itself are visible
for i in range(num_ps):
  cdf.set_device_filters('ps', i, ['/job:worker'])

tf.config.experimental_connect_to_cluster(cluster_def,
                                          cluster_device_filters=cdf)

具體 connect_to_cluster 的程式碼如下。

@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
 """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Device filters can be specified to isolate groups of remote tasks to avoid
  undesired accesses between workers. Workers accessing resources or launching
  ops / functions on filtered remote devices will result in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, it can only
  see devices matching at least one filter. Devices on the task itself are
  always visible. Device filters can be particially specified.

  Args:
    cluster_spec_or_resolver: A  ClusterSpec  or  ClusterResolver  describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as "grpc" . If unspecified, will
      use the default from  python/platform/remote_utils.py .
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
    cluster_device_filters: an instance of
       tf.train.experimental/ClusterDeviceFilters  that specify device filters
      to the remote tasks in cluster.
 """
  if not context.executing_eagerly():
    raise ValueError(
       " tf.config.experimental_connect_to_cluster  can only be called in"
       "eager mode."
    )
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
    cluster_spec = cluster_spec_or_resolver
  elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
    if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
      # Do nothing if the master is local.
      return
    cluster_spec = cluster_spec_or_resolver.cluster_spec()
  else:
    raise ValueError(
       " cluster_spec_or_resolver  must be a  ClusterSpec  or a"
       " ClusterResolver .")

  cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
  if cluster_device_filters:
    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
      cluster_device_filters = copy.deepcopy(
          cluster_device_filters._as_cluster_device_filters())  
    else:
      raise ValueError(" cluster_device_filters  must be an instance of"
                      " tf.train.experimental.ClusterDeviceFilters .")

  # Automatically add local job, if not part of the cluster spec.
  if job_name not in cluster_spec.jobs:
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    job_def.tasks[0] ="localhost:{}".format(local_port)

  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol=protocol,
      default_session_config=context.context().config,
      cluster_device_filters=cluster_device_filters)

  if context.get_server_def() is None:
    context.set_server_def(server_def) # 這裡會做處理裝置
  else:
    context.update_server_def(server_def)

  # 配置 master Device  
  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    master_job_name = None
    master_task_id = None
    for job_name in cluster_spec.jobs:
      for task_id in cluster_spec.task_indices(job_name):
        task_address = cluster_spec.task_address(job_name, task_id)
        if master in task_address or task_address in master:
          master_job_name = job_name
          master_task_id = task_id
          break

    if not master_job_name:
      raise ValueError(
         " make_master_device_default  is set to True but cannot find"
         "master %s in the cluster" % master)

    master_device ="/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    master_device = device_util.canonicalize(master_device)
    current_device = device_util.current()
    if current_device:
      current_device = device_util.canonicalize(current_device)
    if current_device and current_device != master_device:
      raise ValueError(" connect_to_cluster  is called inside existing device"
                      "scope %s, which is different from the master device"
                      "scope %s to enter. This is not allowed." %
                       (current_device, master_device))

    if not current_device:
      logging.info("Entering into master device scope: %s", master_device)
      ops.device(master_device).__enter__()

2.5 初始化裝置

set_server_def 會呼叫 _initialize_logical_devices 來初始化邏輯裝置。

  def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
   """Allow setting a server_def on the context.

    When a server def is replaced, it effectively clears a bunch of caches
    within the context. If you attempt to use a tensor object that was pointing
    to a tensor on the remote device, it will raise an error.

    Args:
      server_def: A tensorflow::ServerDef proto. Enables execution on remote
        devices.
      keep_alive_secs: Num. seconds after which the remote end will hang up. As
        long as the client is still alive, the server state for the context will
        be kept alive. If the client is killed (or there is some failure), the
        server will clean up its context keep_alive_secs after the final RPC it
        receives.

    Raises:
      ValueError: if server_def is None.
   """
    if not server_def:
      raise ValueError("server_def is None.")

    self._server_def = server_def

    if self._context_handle:
      server_def_str = server_def.SerializeToString()
      pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
                                         server_def_str)
      self._initialize_logical_devices()

    # Clear all the caches in case there are remote tensors in them.
    self._clear_caches()

_initialize_logical_devices 則會呼叫上下文物件的方法和一些其他方法來實現功能。

  def _initialize_logical_devices(self):
   """Helper to initialize devices."""
    # Store list of devices
    logical_devices = []
    context_devices = []
    device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
    try:
      self._num_gpus = 0
      for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
        dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
        context_devices.append(pydev.canonical_name(dev_name))
        spec = pydev.DeviceSpec.from_string(dev_name)
        # If the job is localhost, we assume that the cluster has not yet been
        # configured and thus clear the job, replica & task.
        if spec.job =="localhost":
          spec = spec.replace(job=None, replica=None, task=None)
        logical_devices.append(
            LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
        dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
        if dev_type =="GPU":
          self._num_gpus += 1

    finally:
      self._logical_devices = logical_devices
      self._context_devices = context_devices
      pywrap_tfe.TF_DeleteDeviceList(device_list)

我們以 TFE_ContextListDevices 為例來看,其呼叫到了 Context 的 ListDevices 方法。

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* l = new TF_DeviceList;
  tensorflow::unwrap(ctx)->ListDevices(&l->response);
  return l;
}

上下文如何實現,就需要具體情況具體分析了,比如下面的生成上下文的程式碼。

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
    tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
        opts->session_options.options,
        static_cast<tensorflow::ContextDevicePlacementPolicy>(
            opts->device_placement_policy),
        opts->async, opts->use_tfrt_distributed_runtime);
#if !defined(IS_MOBILE_PLATFORM)
    tfrt_context->SetDistributedManager(
        tfrt::tf::CreateDistributedManagerContext(
            tfrt_context->GetCoreRuntime()->GetHostContext()));
#endif  // !IS_MOBILE_PLATFORM
    return tensorflow::wrap(tfrt_context);
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;
#endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
  }
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options,"/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::DynamicDeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      opts->async, device_mgr.release(),
      /*device_mgr_owned*/ true, r,
      /*cluster_flr=*/nullptr,
      /*collective_executor_mgr=*/nullptr,
      /*run_eager_op_as_function=*/opts->run_eager_op_as_function);
#if !defined(IS_MOBILE_PLATFORM)
  eager_context->SetDistributedManager(
      std::make_unique<tensorflow::EagerContextDistributedManager>(
          eager_context));
#endif  // !IS_MOBILE_PLATFORM
  return tensorflow::wrap(eager_context);
}

2.6 Master 裝置

在 connect_to_cluster 之中,會呼叫 ops.device(master_device).enter() 來設定 master Device。程式碼位於 tensorflow/python/framework/ops.py。 device_name_or_function 引數可以是一個裝置名稱字串,一個裝置函式,或者是None:

  • 如果它是一個裝置名稱字串,在這個上下文中構建的所有操作將被分配給具有該名稱的裝置,除非被巢狀的 device() 上下文覆蓋。
  • 如果它是一個函式,它將被視為一個從操作物件到裝置名稱字串的函式,並且在每次建立一個新操作時被呼叫。該操作將被分配給具有返回名稱的裝置。
  • 如果它是 None,所有來自包圍上下文(enclosing context)的 device() 呼叫將被忽略。
@tf_export(v1=["device"])
def device(device_name_or_function):
 """Wrapper for  Graph.device()  using the default graph.

  See  tf.Graph.device  for more details.

  Args:
    device_name_or_function: The device name or function to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If eager execution is enabled and a function is passed in.
 """
  if context.executing_eagerly():
    if callable(device_name_or_function):
      raise RuntimeError(
         "tf.device does not support functions when eager execution"
         "is enabled.")
    return context.device(device_name_or_function)
  elif executing_eagerly_outside_functions():
    @tf_contextlib.contextmanager
    def combined(device_name_or_function):
      with get_default_graph().device(device_name_or_function):
        if not callable(device_name_or_function):
          with context.device(device_name_or_function):
            yield
        else:
          yield
    return combined(device_name_or_function)
  else:
    return get_default_graph().device(device_name_or_function)

3. 使用 Model.fit 訓練

Keras 通過 Model.fit 提供了一個易於使用的訓練 API,它在幕後處理訓練迴圈,並且通過可重寫的 train_step 和回撥方法提供了靈活性,也提供了檢查點儲存或 TensorBoard 摘要儲存等功能。通過 Model.fit,同樣的訓練程式碼只需通過簡單地交換策略物件即可被用於其他策略。

3.1 輸入資料

使用引數伺服器訓練的 Model.fit 需要在一個 callable 中提供輸入資料,該 callable 接收一個 tf.distribution.InputContext 型別的引數,並返回一個 tf.data.Dataset 。然後,系統將建立一個 tf.keras.utils.experimental.DatasetCreator 物件,它接受上述的 callable,並通過 input_options 引數建立一個可選的 tf.distribution.InputOptions 物件。

注意,建議用引數伺服器訓練來 shuffle 和 repeat 資料,並在 fit 呼叫中指定 steps_per_epoch,這樣庫就會知道 epoch 的界限。

關於 InputContext 引數的更多資訊,請參見官方 Distributed input 教程。

def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)

  x = tf.random.uniform((10, 10))
  y = tf.random.uniform((10,))

  dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines,
      input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)

  return dataset

dc = tf.keras.utils.experimental.DatasetCreator(dataset_fn)

dataset_fn 中的程式碼將在每個工作者的輸入裝置上被呼叫,這個裝置通常是CPU。

3.2 模型構建和編譯

處理好資料之後,使用者需要建立一個 tf.keras.Model,然後是一個 Model.compile 呼叫,以納入元件,如優化器、度量或引數(如 steps_per_execution)。

with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  model.compile(tf.keras.optimizers.SGD(), loss='mse', steps_per_execution=10)

3.3 回撥和訓練

在你呼叫 model.fit 進行實際訓練之前,還需要為常見的工作準備所需的回撥,例如。

  • ModelCheckpoint :儲存模型的權重。
  • BackupAndRestore :確保訓練進度被自動備份,並在叢集出現不可用情況(如中止或搶佔)時恢復;
  • TensorBoard :將進度報告儲存為摘要檔案,在 TensorBoard 工具中進行視覺化。

注意:由於效能方面的考慮,自定義回撥在與 ParameterServerStrategy 一起使用時不能覆蓋批級(batch level)回撥。請修改你的自定義回撥成為 epoch 級別的呼叫,並將 steps_per_epoch 調整到一個合適的值。此外,當與 ParameterServerStrategy 一起使用時, steps_per_epoch 是 Model.fit 的一個必要引數。

working_dir = '/tmp/my_working_dir'
log_dir = os.path.join(working_dir, 'log')
ckpt_filepath = os.path.join(working_dir, 'ckpt')
backup_dir = os.path.join(working_dir, 'backup')

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=log_dir),
    tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_filepath),
    tf.keras.callbacks.experimental.BackupAndRestore(backup_dir=backup_dir),
]

model.fit(dc, epochs=5, steps_per_epoch=20, callbacks=callbacks)

3.4 直接使用 ClusterCoordinator (optional)

即使你選擇了 Model.fit 訓練路徑,你也可以選擇例項化一個 tf.distribution.experimental.coordinator.ClusterCoordinator 物件來安排你希望在工作者上執行的其他功能。

0x04 自定義訓練

使用 tf.distribution.Strategy 的自定義訓練迴圈為定義訓練迴圈提供了極大的靈活性。通過上面定義的 ParameterServerStrategy (作為 strategy ),使用者可以使用 tf.distribution.experimental.coordinator.ClusterCoordinator 將訓練步驟排程給遠端工作者來執行。

和其他 tf.distribution.Strategy 的訓練迴圈一樣,使用者需要建立一個模型,定義一個資料集和一個步進函式(step function)。為了確保高效的資料集預取,建議使用下面會提到的分散式資料集建立 API。此外,確保在 worker_fn 內呼叫 Strategy.run,這樣可以充分利用分配給工作者的 GPU。

我們接下來看看如何建立這些元件。

4.1 配置資料

首先,編寫一個函式來建立一個資料集,其中包括由 Keras preprocessing layers 所實現的預處理邏輯。我們在 dataset_fn 之外建立這些層,但在 dataset_fn 內應用轉換,因為我們將把 dataset_fn 包裹到 tf.function 中,它不允許在其內部建立變數。

feature_vocab = [
   "avenger","ironman","batman","hulk","spiderman","kingkong","wonder_woman"
]
label_vocab = ["yes","no"]

with strategy.scope():
  feature_lookup_layer = tf.keras.layers.StringLookup(
      vocabulary=feature_vocab,
      mask_token=None)
  label_lookup_layer = tf.keras.layers.StringLookup(
      vocabulary=label_vocab,
      num_oov_indices=0,
      mask_token=None)

  raw_feature_input = tf.keras.layers.Input(
      shape=(3,),
      dtype=tf.string,
      name="feature")
  feature_id_input = feature_lookup_layer(raw_feature_input)
  feature_preprocess_stage = tf.keras.Model(
      {"features": raw_feature_input},
      feature_id_input)

  raw_label_input = tf.keras.layers.Input(
      shape=(1,),
      dtype=tf.string,
      name="label")
  label_id_input = label_lookup_layer(raw_label_input)

  label_preprocess_stage = tf.keras.Model(
      {"label": raw_label_input},
      label_id_input)

以下是構建資料的程式碼。

def feature_and_label_gen(num_examples=200):
  examples = {"features": [],"label": []}
  for _ in range(num_examples):
    features = random.sample(feature_vocab, 3)
    label = ["yes"] if"avenger" in features else ["no"]
    examples["features"].append(features)
    examples["label"].append(label)
  return examples

examples = feature_and_label_gen()

然後,使用 dataset_fn 把訓練資料集包裝起來。

def dataset_fn(_):
  raw_dataset = tf.data.Dataset.from_tensor_slices(examples)

  train_dataset = raw_dataset.map(
      lambda x: (
          {"features": feature_preprocess_stage(x["features"])},
          label_preprocess_stage(x["label"])
      )).shuffle(200).batch(32).repeat()
  return train_dataset

4.2 建立模型

接下來,我們來建立模型和其他物件,要確保在 strategy.scope 之下建立這些變數。

# These variables created under the  strategy.scope  will be placed on parameter
# servers in a round-robin fashion.
with strategy.scope():
  # Create the model. The input needs to be compatible with Keras processing layers.
  model_input = tf.keras.layers.Input(
      shape=(3,), dtype=tf.int64, name="model_input")

  emb_layer = tf.keras.layers.Embedding(
      input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=16384)
  emb_output = tf.reduce_mean(emb_layer(model_input), axis=1)
  dense_output = tf.keras.layers.Dense(units=1, activation="sigmoid")(emb_output)
  model = tf.keras.Model({"features": model_input}, dense_output)

  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
  accuracy = tf.keras.metrics.Accuracy()

然後需要確保使用 FixedShardsPartitioner 將所有變數分成兩個分片,每個分片被分配給不同的引數伺服器。

assert len(emb_layer.weights) == 2
assert emb_layer.weights[0].shape == (4, 16384)
assert emb_layer.weights[1].shape == (4, 16384)
assert emb_layer.weights[0].device =="/job:ps/replica:0/task:0/device:CPU:0"
assert emb_layer.weights[1].device =="/job:ps/replica:0/task:1/device:CPU:0"

4.3 定義訓練步驟

第三步則是使用 tf.function 來建立訓練 step。

@tf.function
def step_fn(iterator):

  def replica_fn(batch_data, labels):
    with tf.GradientTape() as tape:
      pred = model(batch_data, training=True)
      per_example_loss = tf.keras.losses.BinaryCrossentropy(
              reduction=tf.keras.losses.Reduction.NONE)(labels, pred)
      loss = tf.nn.compute_average_loss(per_example_loss)
      gradients = tape.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
    accuracy.update_state(labels, actual_pred)
    return loss

  batch_data, labels = next(iterator)
  losses = strategy.run(replica_fn, args=(batch_data, labels))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)

在上面的訓練步進函式中,在 step_fn 中呼叫 Strategy.run 和 Strategy.reduce 就可以支援每個工作者的多個GPU。工作者被分配 GPU 之後, Strategy.run 將在多個模型副本上分配資料集。

4.4 分配計算到遠端

在使用 ParameterServerStrategy 定義所有的計算後,你將使用 tf.distribution.experimental.coordinator.ClusterCoordinator 類來建立資源並將訓練步驟分配給遠端工作者。

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

然後,為每個工作者(per-worker)建立一個資料集和一個迭代器。在下面的 per_worker_dataset_fn 中,建議將 dataset_fn 包裹到 strategy.distribution_datasets_from_function 中,以允許無縫高效的把資料預取到 GPU。

@tf.function
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

最後一步是使用 ClusterCoordinator.schedule 將計算分配給遠端工作者。

  • schedule 方法把一個 tf.function 插入佇列,並立即返回一個 future-like 的 RemoteValue 。佇列之中的函式將被派發給後臺執行緒中的遠端工作者,RemoteValue 將被非同步填充。
  • 可以使用 join 方法( ClusterCoordinator.join )來等待所有被規劃(scheduled)的函式執行完畢。
num_epoches = 4
steps_per_epoch = 5
for i in range(num_epoches):
  accuracy.reset_states()
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  # Wait at epoch boundaries.
  coordinator.join()
  print ("Finished epoch %d, accuracy is %f." % (i, accuracy.result().numpy()))

下面是如何得到 RemoteValue 的結果。

loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
print ("Final loss is %f" % loss.fetch())

或者,你可以啟動所有的步驟,並在等待完成時做一些事情。

for _ in range(total_steps):
  coordinator.schedule(step_fn, args=(per_worker_iterator,))
while not coordinator.done():
  time.sleep(10)
  # Do something like logging metrics or writing checkpoints.

4.5 建立資料集

上述程式碼中的資料集是使用 ClusterCoordinator.create_per_worker_dataset API 建立的。它為每個工作者建立一個資料集,並返回一個容器物件。你可以呼叫 iter 方法來建立一個屬於每個工作者(per-worker)的迭代器。在工作者執行函式之前, ClusterCoordinator.schedule 方法的輸入引數將被設定成工作者的相應切片(slice)。

目前, ClusterCoordinator.schedule 方法假定worker都是相同的,因此假定不同worker上的資料集是相同的,如果資料集包含 Dataset.shuffle 操作,則資料集可能會被shuffle。正因為如此,建議使用者安排執行有限的步驟,而不是依賴資料集的 OutOfRangeError 。

另一個重要的注意事項是, tf.data 資料集不支援跨任務邊界的隱式序列化和反序列化。所以在傳遞給 ClusterCoordinator.create_per_worker_dataset 的函式內建立整個資料集是很重要的。

5. 執行

5.1 直接執行

如果直接呼叫 run 來執行,則 ParameterServerStrategy 和其他策略套路類似,比如在 parameter_server_strategy_v2 之中呼叫了 mirrored_run,所以我們不在贅述。

  def _call_for_each_replica(self, fn, args, kwargs):
    self._assert_being_scheduled_by_cluster_coordinator()

    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

5.2 ClusterCoordinator

另一種方式是使用 ClusterCoordinator 來執行,我們將在下一章節結合自定義訓練迴圈來進行分析。

6. 效能改進

如果你在使用 ParameterServerStrategy 和 ClusterResolver 訓練時發現效能問題,可能有幾個原因。

一個常見的原因是引數伺服器的負載不平衡,一些過載的引數伺服器已經達到容量。也可能有多種根本原因。緩解這個問題的一些簡單方法是:

  1. 在構建 ParameterServerStrategy 時,通過指定一個 variable_partitioner 來分割你的大型模型變數。
  2. 如果可能的話,避免建立一個所有引數伺服器都需要的熱點(hotspot)變數。例如,在優化器中使用一個恆定的學習率或子類 tf.keras.optimizers.schedules.LearningRateSchedule,因為預設行為是:學習率將成為一個放在特定引數伺服器上的變數,但是此變數在每一步中被所有其他引數伺服器使用。
  3. 在將你的大詞彙表傳遞給 Keras 預處理層之前,對它們進行 shuffle。

效能問題的另一個可能原因是協調器。你的第一個 schedule / join 的實現是基於Python的,因此可能有執行緒開銷。另外,協調器和工作者之間的延遲也可能很大。如果是這種情況,那麼建議:

  • 對於 Model.fit,你可以將 Model.compile 提供的 steps_per_execution 引數設定為大於1的值。
  • 對於一個自定義的訓練迴圈,你可以將多個步驟打包到一個 tf.function 中。
steps_per_invocation = 10

@tf.function
def step_fn(iterator):
  for _ in range(steps_per_invocation):
    features, labels = next(iterator)
    def replica_fn(features, labels):
      ...

    strategy.run(replica_fn, args=(features, labels))

隨著庫的進一步優化,希望可以讓大多數使用者在未來不必手動打包步驟。此外,提高效能的一個小竅門是安排沒有返回值的函式。

7. 已知限制

在上述章節中已經涉及了大部分已知的限制。本節提供一個總結。

7.1 ParameterServerStrategy

  • os.environment["grpc_fail_fast"]="use_caller" 在包括協調器在內的每個任務上都需要,以使容錯正常工作。
  • 不支援同步的引數伺服器訓練。
  • 通常需要將多個步驟打包到一個函式中,以實現最佳效能。
  • 不支援通過 tf.saved_model.load 載入含有分片變數的儲存模型。注意使用 TensorFlow Serving 載入這樣的 saved_model 是可以的。
  • 不支援將包含分片優化器插槽(slot)變數的檢查點載入到不同數量的分片中。
  • 不支援在不重啟協調者任務的情況下從引數伺服器故障中恢復。
  • 使用 tf.lookup.StaticHashTable(它通常被一些 Keras 預處理層採用,如 tf.keras.layer.IntegerLookup 、 tf.keras.layer.StringLookup 和 tf.keras.layer.TextVectorization )將導致在這一步之中引數伺服器訓練所使用的資源被放在協調器上。這會影響從工作者到協調器的查詢RPC的效能。這是目前需要解決的一個高度優先事項。

7.2 Model.fit

  • steps_per_epoch 引數在 Model.fit 中是必需的。你可以選擇一個值來確保epoch之內被分割恰當。

  • 由於效能原因, ParameterServerStrategy 不支援批量級自定義回撥。你應該將這些呼叫轉換為epoch級的呼叫,並適當選擇 steps_per_epoch,以便每隔 steps_per_epoch 步數呼叫這些回撥。內建回撥不受影響:它們的批處理級呼叫已經被修改為可執行的。官方正在計劃為"ParameterServerStrategy"支援批量呼叫。

  • 出於同樣的原因,與其他策略不同,進度條和指標只在epoch邊界被記錄。

  • 不支援 run_eagerly 。

7.3 自定義迴圈

  • ClusterCoordinator.schedule 不支援資料集的訪問量保證(visitation guarantees)。

0xFF 參考

https://www.youtube.com/watch?v=B2Tpv_N7wkg&ab_channel=TensorFlow

[中字] TFRT: 新的 TensorFlow 執行庫 - TF Dev Summit '20

深入理解 TFRT

Inside TensorFlow: Eager execution runtime

【 深度學習框架tensorflow: Inside TensorFlow 】Inside TensorFlow(合輯)

https://github.com/tensorflow/docs-l10n/blob/07e15a23c7fa397bc44acbf20f997f7cb268ab1c/site/en-snapshot/tutorials/distribute/parameter_server_training.ipynb