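[Source Code Analysis] TensorFlow Distributed: How MirroredStrategy Distributes Computation
In the previous article we analyzed the basic architecture of MirroredStrategy and how it updates variables. In this article we look at how MirroredStrategy executes. Specifically, we want to understand three things. First, how does MirroredStrategy run the training function on remote device nodes, i.e. how does it distribute computation? Second, how does MirroredStrategy connect to the TF runtime we analyzed earlier? Third, how does it relate to concepts such as master and worker?
Two GitHub repositories are worth recommending. Both are excellent learning material: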
https://github.com/yuhuiaws/ML-study
https://github.com/Jack47/hack-SysML
I also recommend 西門宇少's latest article, "Parallelizing the Pipeline at the Token Level in Transformer LMs: TeraPipe".
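Other articles in this series:
[Translation] TensorFlow Distributed: the paper "Implementation of Control Flow in TensorFlow"
[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture
[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic
[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic
[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache
[Source Code Analysis] TensorFlow Distributed Environment (5) --- Session
[Source Code Analysis] TensorFlow Distributed Environment (7) --- Worker Dynamic Logic
[Source Code Analysis] TensorFlow Distributed Environment (8) --- Communication Mechanism
[Source Code Analysis] TensorFlow Distributed: DistributedStrategy Basics
[Source Code Analysis] TensorFlow Distributed: MirroredStrategy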
0x1. Execution
The example code is shown below. Our reading starts from strategy.run.
```python
>>> def run(strategy):
...   with strategy.scope():
...     v = tf.Variable(0.)
...     strategy.run(step_fn, args=(v,))
...   return v
```
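1.1 Base Class Strategy
Strategy's run method is the primary way to distribute computation with a tf.distribute object. It calls fn on each replica. If args or kwargs contain tf.distribute.DistributedValues, then fn runs on each replica with the component of those tf.distribute.DistributedValues that corresponds to that replica.
fn is invoked in a replica context, so it can call tf.distribute.get_replica_context() to access members such as all_reduce. Each argument in args or kwargs can be a nested structure of tensors, e.g. a list of tensors. In that case, args and kwargs are passed unchanged to the fn invoked on each replica. Alternatively, args or kwargs can be tf.distribute.DistributedValues that contain tensors or composite tensors (tf.compat.v1.TensorInfo.CompositeTensor). In that case, each fn call receives the component of the tf.distribute.DistributedValues that corresponds to its replica.
Important: depending on the tf.distribute.Strategy implementation and on whether eager execution is enabled, fn may be called one or more times. fn is called once per replica to generate a TensorFlow graph in two cases: when fn is annotated with tf.function, and when tf.distribute.Strategy.run is called inside a tf.function (eager execution is disabled inside tf.function by default). That graph is then reused to execute new inputs. Otherwise, when eager execution is enabled, fn is called once per replica at every step, just like regular Python code.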
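The per-replica argument rule can be made concrete with a minimal pure-Python sketch. ToyPerReplica, toy_select_replica and toy_run are hypothetical names invented for illustration, not TF APIs. They only mimic what distribute_utils.select_replica does in the listings below. A per-replica value contributes the component for the current replica, and ordinary values are passed unchanged to every replica.

```python
class ToyPerReplica:
    """Hypothetical stand-in for tf.distribute.DistributedValues."""

    def __init__(self, values):
        self.values = tuple(values)  # one component per replica


def toy_select_replica(replica_id, structure):
    """Pick replica `replica_id`'s component out of a nested structure."""
    if isinstance(structure, ToyPerReplica):
        return structure.values[replica_id]
    if isinstance(structure, (list, tuple)):
        return type(structure)(toy_select_replica(replica_id, s) for s in structure)
    if isinstance(structure, dict):
        return {k: toy_select_replica(replica_id, v) for k, v in structure.items()}
    return structure  # plain values are shared by every replica


def toy_run(fn, num_replicas, args=(), kwargs=None):
    """Call fn once per replica with that replica's view of args/kwargs."""
    kwargs = kwargs or {}
    return tuple(
        fn(*toy_select_replica(i, args), **toy_select_replica(i, kwargs))
        for i in range(num_replicas))


per_replica_x = ToyPerReplica([10, 20])
print(toy_run(lambda x, y: x + y, 2, args=(per_replica_x, 1)))  # (11, 21)
```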
Concretely, the run method calls call_for_each_replica.
```python
def run(self, fn, args=(), kwargs=None, options=None):
  """Invokes fn on each replica, with the given arguments.

  This method is the primary way to distribute your computation with a
  tf.distribute object. It invokes fn on each replica. If args or kwargs
  have tf.distribute.DistributedValues, such as those produced by a
  tf.distribute.DistributedDataset from
  tf.distribute.Strategy.experimental_distribute_dataset or
  tf.distribute.Strategy.distribute_datasets_from_function,
  when fn is executed on a particular replica, it will be executed with the
  component of tf.distribute.DistributedValues that correspond to that
  replica.

  fn is invoked under a replica context. fn may call
  tf.distribute.get_replica_context() to access members such as
  all_reduce. Please see the module-level docstring of tf.distribute for the
  concept of replica context.

  All arguments in args or kwargs can be a nested structure of tensors,
  e.g. a list of tensors, in which case args and kwargs will be passed to
  the fn invoked on each replica. Or args or kwargs can be
  tf.distribute.DistributedValues containing tensors or composite tensors,
  i.e. tf.compat.v1.TensorInfo.CompositeTensor, in which case each fn call
  will get the component of a tf.distribute.DistributedValues corresponding
  to its replica. Note that arbitrary Python values that are not of the types
  above are not supported.

  IMPORTANT: Depending on the implementation of tf.distribute.Strategy and
  whether eager execution is enabled, fn may be called one or more times. If
  fn is annotated with tf.function or tf.distribute.Strategy.run is
  called inside a tf.function (eager execution is disabled inside a
  tf.function by default), fn is called once per replica to generate a
  Tensorflow graph, which will then be reused for execution with new inputs.
  Otherwise, if eager execution is enabled, fn will be called once per
  replica every step just like regular python code.

  Args:
    fn: The function to run on each replica.
    args: Optional positional arguments to fn. Its element can be a tensor,
      a nested structure of tensors or a tf.distribute.DistributedValues.
    kwargs: Optional keyword arguments to fn. Its element can be a tensor,
      a nested structure of tensors or a tf.distribute.DistributedValues.
    options: An optional instance of tf.distribute.RunOptions specifying
      the options to run fn.

  Returns:
    Merged return value of fn across replicas. The structure of the return
    value is the same as the return value from fn. Each element in the
    structure can either be tf.distribute.DistributedValues, Tensor
    objects, or Tensors (for example, if running on a single replica).
  """
  del options

  if not isinstance(args, (list, tuple)):
    raise ValueError(
        "positional args must be a list or tuple, got {}".format(type(args)))

  with self.scope():
    # tf.distribute supports Eager functions, so AutoGraph should not be
    # applied when the caller is also in Eager mode.
    fn = autograph.tf_convert(
        fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
```
1.2 StrategyExtendedV1
call_for_each_replica is defined in StrategyExtendedV2. StrategyExtendedV1 derives from StrategyExtendedV2, so both versions end up calling this same call_for_each_replica method.
````python
def call_for_each_replica(self, fn, args=(), kwargs=None):
  """Run fn once per replica.

  fn may call tf.get_replica_context() to access methods such as
  replica_id_in_sync_group and merge_call().

  merge_call() is used to communicate between the replicas and
  re-enter the cross-replica context. All replicas pause their execution
  having encountered a merge_call() call. After that the
  merge_fn-function is executed. Its results are then unwrapped and
  given back to each replica call. After that execution resumes until
  fn is complete or encounters another merge_call(). Example:

  ```python
  # Called once in "cross-replica" context.
  def merge_fn(distribution, three_plus_replica_id):
    # sum the values across replicas
    return sum(distribution.experimental_local_results(three_plus_replica_id))

  # Called once per replica in distribution, in a "replica" context.
  def fn(three):
    replica_ctx = tf.get_replica_context()
    v = three + replica_ctx.replica_id_in_sync_group
    # Computes the sum of the v values across all replicas.
    s = replica_ctx.merge_call(merge_fn, args=(v,))
    return s + v

  with distribution.scope():
    # in "cross-replica" context
    ...
    merged_results = distribution.run(fn, args=[3])
    # merged_results has the values from every replica execution of fn.
    # This statement prints a list:
    print(distribution.experimental_local_results(merged_results))
  ```

  Args:
    fn: function to run (will be run once per replica).
    args: Tuple or list with positional arguments for fn.
    kwargs: Dict with keyword arguments for fn.

  Returns:
    Merged return value of fn across all replicas.
  """
  _require_cross_replica_or_default_context_extended(self)
  if kwargs is None:
    kwargs = {}
  with self._container_strategy().scope():
    return self._call_for_each_replica(fn, args, kwargs)
````
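The pause/merge/resume protocol described in the docstring can be modelled without threads by writing each replica's fn as a Python generator. In this sketch, a `yield` stands in for merge_call(). This is a hypothetical simplification, not how TF implements it (TF uses one thread per replica, as section 2.2 shows). toy_call_for_each_replica is an invented name. Also, merge_fn receives a plain list instead of (distribution, DistributedValues).

```python
def toy_call_for_each_replica(fn, num_replicas, *args):
    """Run the generator function `fn` once per replica (hypothetical sketch).

    Each `yield (merge_fn, value)` inside `fn` plays the role of merge_call():
    every replica pauses there, merge_fn is called once with the list of
    per-replica values, and the merged result is sent back to every replica.
    """
    gens = [fn(replica_id, *args) for replica_id in range(num_replicas)]
    to_send = [None] * num_replicas  # a fresh generator must first receive None
    results = [None] * num_replicas
    while True:
        requests, done = [], []
        for i, gen in enumerate(gens):
            try:
                requests.append(gen.send(to_send[i]))  # run until next merge_call
                done.append(False)
            except StopIteration as stop:              # fn returned
                results[i] = stop.value
                done.append(True)
        if all(done):
            return results
        if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "merge_call() calls.")
        merge_fn = requests[0][0]                      # like threads[0].merge_fn
        merged = merge_fn([value for _, value in requests])
        to_send = [merged] * num_replicas


# The docstring example, rewritten for this sketch.
def merge_fn(per_replica_values):
    return sum(per_replica_values)   # runs once, "cross-replica" style


def fn(replica_id, three):
    v = three + replica_id
    s = yield (merge_fn, v)          # stands in for replica_ctx.merge_call(...)
    return s + v


print(toy_call_for_each_replica(fn, 2, 3))  # [10, 11]
```

With two replicas, v is 3 and 4, the merge yields 7, and the replicas return 10 and 11. This matches the "sum across replicas, then add the local v" behaviour the docstring describes.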
1.3 MirroredExtended
_call_for_each_replica is implemented in MirroredExtended. It simply forwards to mirrored_run.call_for_each_replica.
```python
def _call_for_each_replica(self, fn, args, kwargs):
  return mirrored_run.call_for_each_replica(
      self._container_strategy(), fn, args, kwargs)
```
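The three hops above form a template-method chain: Strategy.run, then StrategyExtendedV2.call_for_each_replica, then the MirroredExtended override. The sketch below reproduces only that control flow. ToyStrategy, ToyExtended and ToyMirroredExtended are hypothetical classes. The real versions also enter strategy.scope(), apply AutoGraph and check the replica context.

```python
class ToyExtended:
    """Plays the role of StrategyExtendedV2 (hypothetical)."""

    def call_for_each_replica(self, fn, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return self._call_for_each_replica(fn, args, kwargs)

    def _call_for_each_replica(self, fn, args, kwargs):
        raise NotImplementedError("must be implemented in derived classes")


class ToyMirroredExtended(ToyExtended):
    """Plays the role of MirroredExtended (hypothetical)."""

    def __init__(self, num_replicas):
        self.num_replicas = num_replicas

    def _call_for_each_replica(self, fn, args, kwargs):
        # The real override forwards to mirrored_run.call_for_each_replica;
        # here we simply call fn once per replica, sequentially.
        return tuple(fn(*args, **kwargs) for _ in range(self.num_replicas))


class ToyStrategy:
    """Plays the role of tf.distribute.Strategy (hypothetical)."""

    def __init__(self, extended):
        self._extended = extended

    def run(self, fn, args=(), kwargs=None):
        if not isinstance(args, (list, tuple)):
            raise ValueError(
                "positional args must be a list or tuple, got {}".format(type(args)))
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)


strategy = ToyStrategy(ToyMirroredExtended(num_replicas=2))
print(strategy.run(lambda x, scale=1: x * scale, args=(5,), kwargs={"scale": 2}))  # (10, 10)
```

Keeping the public entry point in the base class and letting each strategy override only the private hook is what allows Strategy.run to stay the same across strategies.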
0x2. mirrored_run
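Here mirrored_run refers to the module defined in the mirrored_run.py file.
2.1 call_for_each_replica
Inside mirrored_run we first arrive at call_for_each_replica, whose job is to call fn on each device. If fn is a tf.function, it is wrapped so that the per-replica dispatch happens inside the tf.function. The wrapper is cached per strategy and per fn. Otherwise, AutoGraph conversion is applied to fn and _call_for_each_replica is called directly.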
```python
def call_for_each_replica(strategy, fn, args=None, kwargs=None):
  """Call fn on each worker devices(replica).

  It's highly recommended to wrap the call to this function inside a
  tf.function, otherwise the performance is poor.

  Args:
    strategy: tf.distribute.Strategy.
    fn: function to call on each worker devices.
    args: positional arguments to fn.
    kwargs: keyword arguments to fn.

  Returns:
    Wrapped returned value of fn from all replicas.
  """
  if args is None:
    args = ()
  if kwargs is None:
    kwargs = {}

  if isinstance(fn, def_function.Function):
    # Don't lift up the tf.function decoration if fn is compiled with XLA
    # and all devices are GPU. In this case we will use collectives to do
    # cross-device communication, thus no merge_call is in the path.
    if fn._jit_compile and all(
        [_is_gpu_device(d) for d in strategy.extended.worker_devices]):
      return _call_for_each_replica(strategy, fn, args, kwargs)

    if strategy not in _cfer_fn_cache:
      _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary()
    wrapped = _cfer_fn_cache[strategy].get(fn)
    if wrapped is None:
      # We need to wrap fn such that it triggers _call_for_each_replica inside
      # the tf.function. We use _clone() instead of @tf.function wrapped
      # call_for_each_replica() because we would like to retain the arguments to
      # the @tf.function decorator of fn.
      wrapped = fn._clone(
          python_function=functools.partial(call_for_each_replica, strategy,
                                            fn.python_function))
      _cfer_fn_cache[strategy][fn] = wrapped
    return wrapped(args, kwargs)
  else:
    # When a tf.function is wrapped to trigger _call_for_each_replica (see
    # the other branch above), AutoGraph stops conversion at
    # _call_for_each_replica itself (TF library functions are allowlisted).
    # This makes sure that the Python function that originally passed to
    # the tf.function is still converted.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())

  return _call_for_each_replica(strategy, fn, args, kwargs)
```
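2.2 Creating Threads
_call_for_each_replica creates one _MirroredReplicaThread per device. Each thread runs fn for its own replica, and the main thread keeps coordinating them until every fn has finished. Note that run_concurrently is set to False in this code, so the main thread resumes the replica threads one at a time. Each thread runs until it either finishes fn or reaches a merge_call, and only then is the next thread resumed. Once all threads are paused at a merge_call, the main thread runs merge_fn once and hands each thread its part of the result.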
```python
def _call_for_each_replica(distribution, fn, args, kwargs):
  """Run fn in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for fn
    kwargs: keyword arguments for fn.

  Returns:
    Merged return value of fn across all replicas.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times from the available devices.
  """
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}
  devices = distribution.extended.worker_devices

  threads = []
  for index in range(len(devices)):  # iterate over the devices
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(distribution, coord, index, devices,
                               variable_creator_fn, fn,
                               distribute_utils.caching_scope_local,
                               distribute_utils.select_replica(index, args),
                               distribute_utils.select_replica(index, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When fn starts should_run event is set on _MirroredReplicaThread
  # (MRT) threads. The execution waits until
  # MRT.has_paused is set, which indicates that either fn is
  # complete or a get_replica_context().merge_call() is called. If fn is
  # complete, then MRT.done is set to True. Otherwise, arguments
  # of get_replica_context().merge_call from all paused threads are grouped
  # and the merge_fn is performed. Results of the
  # get_replica_context().merge_call are then set to MRT.merge_result.
  # Each such get_replica_context().merge_call call returns the
  # MRT.merge_result for that thread when MRT.should_run event
  # is reset again. Execution of fn resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = distribute_utils.regroup(
              tuple(t.merge_args for t in threads))
          merge_kwargs = distribute_utils.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)
          with ops.name_scope(mtt_captured_name_scope), \
              ops.control_dependencies(mtt_captured_control_deps), \
              variable_scope.variable_scope(mtt_captured_var_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for r, t in enumerate(threads):
            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return distribute_utils.regroup(tuple(t.main_result for t in threads))
```
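The should_run/has_paused handshake is easier to follow in a stripped-down, runnable form. The sketch below is a hypothetical simplification of _MirroredReplicaThread and the run_concurrently=False branch above, using only threading.Event. It makes three substitutions. Coordinator-based error handling is replaced by a simple stop flag, regroup is replaced by zip, and every replica receives the whole merge result instead of select_replica(r, ...). ToyReplicaThread, _ToyStop and toy_threaded_call_for_each_replica are invented names.

```python
import threading


class _ToyStop(Exception):
    """Raised inside a replica thread when the main thread asks it to stop."""


class ToyReplicaThread(threading.Thread):
    """Hypothetical, heavily simplified _MirroredReplicaThread."""

    def __init__(self, replica_id, fn):
        super().__init__(daemon=True)
        self.replica_id = replica_id
        self.fn = fn
        self.should_run = threading.Event()  # main thread -> replica: "run now"
        self.has_paused = threading.Event()  # replica -> main: "paused or done"
        self.stop_requested = False
        self.done = False
        self.main_result = None
        self.merge_fn = self.merge_args = self.merge_result = None

    def merge_call(self, merge_fn, *args):
        # Record the request, hand control back, and block until resumed.
        self.merge_fn, self.merge_args = merge_fn, args
        self.has_paused.set()
        self.should_run.wait()
        self.should_run.clear()
        if self.stop_requested:
            raise _ToyStop()
        return self.merge_result

    def run(self):
        self.should_run.wait()
        self.should_run.clear()
        try:
            if not self.stop_requested:
                self.main_result = self.fn(self)
                self.done = True
        except _ToyStop:
            pass
        finally:
            self.has_paused.set()


def toy_threaded_call_for_each_replica(fn, num_replicas):
    threads = [ToyReplicaThread(i, fn) for i in range(num_replicas)]
    for t in threads:
        t.start()
    try:
        while True:
            done = []
            for t in threads:          # the run_concurrently=False branch:
                t.should_run.set()     # resume one replica ...
                t.has_paused.wait()    # ... and wait until it pauses or finishes
                t.has_paused.clear()
                done.append(t.done)
            if all(done):
                break
            if any(done):
                raise RuntimeError("Some replicas made a different number of "
                                   "merge_call() calls.")
            # Like regroup(): turn per-replica arg tuples into per-arg tuples.
            merge_args = tuple(zip(*(t.merge_args for t in threads)))
            merge_result = threads[0].merge_fn(*merge_args)
            for t in threads:
                t.merge_result = merge_result
    finally:
        for t in threads:              # wake everyone so that join() cannot hang
            t.stop_requested = True
            t.should_run.set()
        for t in threads:
            t.join()
    return tuple(t.main_result for t in threads)


def step(replica):
    v = 3 + replica.replica_id
    s = replica.merge_call(lambda values: sum(values), v)
    return s + v


print(toy_threaded_call_for_each_replica(step, 2))  # (10, 11)
```

Because the main thread waits on each has_paused before setting the next should_run, at most one replica thread executes fn at any moment in this mode.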
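2.3 Thread Definition
The definition of _MirroredReplicaThread is easy to follow: the thread runs a function on one device. The important detail is that its constructor calls context.ensure_initialized(), so next we need to look at the Context concept.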
```python
class _MirroredReplicaThread(threading.Thread):
  """A thread that runs() a function on a device."""

  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, fn,
               caching_scope, args, kwargs):
    super(_MirroredReplicaThread, self).__init__()
    self.coord = coord
    self.distribution = dist
    self.devices = devices
    self.replica_id = replica_id
    self.replica_id_in_sync_group = (
        dist.extended._get_replica_id_in_sync_group(replica_id))

    self.variable_creator_fn = variable_creator_fn
    # State needed to run and return the results of fn.
    self.main_fn = fn
    self.main_args = args
    self.main_kwargs = kwargs
    self.main_result = None
    self.done = False
    # State needed to run the next merge_call() (if any) requested via
    # ReplicaContext.
    self.merge_fn = None
    self.merge_args = None
    self.merge_kwargs = None
    self.merge_result = None
    self.captured_name_scope = None
    self.captured_var_scope = None
    try:
      self.caching_scope_entered = caching_scope.new_cache_scope_count
      self.caching_scope_exited = caching_scope.cache_scope_exited_count
    except AttributeError:
      self.caching_scope_entered = None
      self.caching_scope_exited = None

    # We use a thread.Event for the main thread to signal when this
    # thread should start running (should_run), and another for
    # this thread to transfer control back to the main thread
    # (has_paused, either when it gets to a
    # get_replica_context().merge_call or when fn returns). In
    # either case the event starts cleared, is signaled by calling
    # set(). The receiving thread waits for the signal by calling
    # wait() and then immediately clearing the event using clear().
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    # These fields have to do with inheriting various contexts from the
    # parent thread:
    context.ensure_initialized()  # make sure the eager context is initialized
    ctx = context.context()  # get the (global) context object
    self.in_eager = ctx.executing_eagerly()
    self.record_thread_local_summary_state()
    self.record_thread_local_eager_context_state()
    self.context_device_policy = (
        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
            ctx._context_handle))
    self.graph = ops.get_default_graph()
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()
    self._variable_creator_stack = self.graph._variable_creator_stack[:]
    self._var_scope = variable_scope.get_variable_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    self._name_scope = self.graph.get_name_scope()
    if self._name_scope:
      self._name_scope += "/"
    if self.replica_id > 0:
      if not self._name_scope:
        self._name_scope = ""
      self._name_scope += "replica_%d/" % self.replica_id

  def run(self):
    self.should_run.wait()
    self.should_run.clear()
    try:
      if self.coord.should_stop():
        return
      self.restore_thread_local_summary_state()
      self.restore_thread_local_eager_context_state()
      if (self.caching_scope_entered is not None and
          self.caching_scope_exited is not None):
        distribute_utils.caching_scope_local.new_cache_scope_count = self.caching_scope_entered
        distribute_utils.caching_scope_local.cache_scope_exited_count = self.caching_scope_exited
      # ops.device(self.devices[self.replica_id]) below pins this thread's
      # ops to its own device. (A comment cannot sit between the
      # backslash-continued lines of the with statement.)
      with self.coord.stop_on_exception(), \
          _enter_graph(self._init_graph, self._init_in_eager), \
          _enter_graph(self.graph, self.in_eager,
                       self._variable_creator_stack), \
          context.device_policy(self.context_device_policy), \
          _MirroredReplicaContext(self.distribution,
                                  self.replica_id_in_sync_group), \
          ops.device(self.devices[self.replica_id]), \
          ops.name_scope(self._name_scope), \
          variable_scope.variable_scope(
              self._var_scope, reuse=self.replica_id > 0), \
          variable_scope.variable_creator_scope(self.variable_creator_fn):
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
        self.done = True
    finally:
      self.has_paused.set()

  def record_thread_local_summary_state(self):
    """Record the thread local summary state in self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state
    self._summary_step = summary_state.step
    self._summary_writer = summary_state.writer
    self._summary_recording = summary_state.is_recording
    self._summary_recording_distribution_strategy = (
        summary_state.is_recording_distribution_strategy)

  def restore_thread_local_summary_state(self):
    """Restore thread local summary state from self."""
    summary_state = summary_ops_v2._summary_state
    summary_state.step = self._summary_step
    summary_state.writer = self._summary_writer
    summary_state.is_recording = self._summary_recording
    summary_state.is_recording_distribution_strategy = (
        self._summary_recording_distribution_strategy)

  def record_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data
    self._eager_context_op_callbacks = eager_context_state.op_callbacks

  def restore_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data
    eager_context_state.op_callbacks = self._eager_context_op_callbacks
```
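The constructor's name-scope logic decides how ops from each replica are named. Replica 0 keeps the caller's scope, and every other replica gets an extra "replica_N/" level. run() also enters variable_scope with reuse=True for every replica except the first. The helper below is a hypothetical, standalone restatement of just those rules, so the resulting strings can be inspected directly.

```python
def toy_replica_scopes(parent_name_scope, replica_id):
    """Hypothetical helper mirroring the scope rules in _MirroredReplicaThread.

    Returns (name_scope, reuse): the name scope the replica thread enters and
    whether its variable_scope is entered with reuse=True.
    """
    name_scope = parent_name_scope or ""
    if name_scope:
        name_scope += "/"  # a trailing "/" lets TF re-enter this exact scope
    if replica_id > 0:
        name_scope += "replica_%d/" % replica_id
    return name_scope, replica_id > 0


for replica_id in range(3):
    print(replica_id, toy_replica_scopes("model", replica_id))
# 0 ('model/', False)
# 1 ('model/replica_1/', True)
# 2 ('model/replica_2/', True)
```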
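The logic so far is as follows:
Figure 1: How execution works
Concretely, the logic is roughly as follows, assuming two devices and therefore two threads.
At this point several threads have been started locally to run the training step. Next, let's see how the computation is dispatched to remote workers.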
0x3. Context
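All the TF distributed machinery we examined earlier was built on top of Session. TF 2 removed Session from its main API (it survives only as tf.compat.v1.Session), so we need a concept that plays the corresponding role, and that concept is Context. Session's job is to interact with the TF runtime. Context plays a similar role: it holds all the information needed to interact with the runtime, but its lifetime is much longer than a session's. To some extent, Context takes over the role that Master played in the TF 1 Session world.
Its definition is shown below. The TODO comments suggest that the TF team intends to rename it (to EagerContext or EagerRuntime) to make its relationship with eager execution explicit: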
```python
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
  """Environment in which eager operations execute."""

  # TODO(agarwal): create and link in some documentation for `execution_mode`.
  def __init__(self,
               config=None,
               device_policy=None,
               execution_mode=None,
               server_def=None):
    """Creates a new Context.

    Args:
      config: (Optional.) A `ConfigProto` protocol buffer with configuration
        options for the Context. Note that a lot of these options may be
        currently unimplemented or irrelevant when eager execution is enabled.
      device_policy: (Optional.) What policy to use when trying to run an
        operation on a device with inputs which are not on that device. When set
        to None, an appropriate value will be picked automatically. The value
        picked may change between TensorFlow releases. Defaults to
        DEVICE_PLACEMENT_SILENT.
        Valid values:
        - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
          correct.
        - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right
          device but raises a warning.
        - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide
          performance problems.
        - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
          raising errors on the other ones.
      execution_mode: (Optional.) Policy controlling how operations dispatched
        are actually executed. When set to None, an appropriate value will be
        picked automatically. The value picked may change between TensorFlow
        releases.
        Valid values:
        - SYNC: executes each operation synchronously.
        - ASYNC: executes each operation asynchronously. These operations may
          return "non-ready" handles.
      server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution
        on remote devices. GrpcServers need to be started by creating an
        identical server_def to this, and setting the appropriate task_indexes,
        so that the servers can communicate. It will then be possible to execute
        operations on remote devices.

    Raises:
     ValueError: If execution_mode is not valid.
    """
    # This _id is used only to index the tensor caches.
    # TODO(iga): Remove this when tensor caches are moved to C++.
    self._id = _context_id_counter.increment_and_get()
    self._tensor_cache_deleter = _TensorCacheDeleter(self._id)
    _tensor_caches_map[self._id] = _TensorCaches()

    self._config = config
    self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
        self,
        is_eager=lambda: default_execution_mode == EAGER_MODE,
        device_spec=_starting_device_spec)
    self._context_switches = _ContextSwitchStack(self.executing_eagerly())
    self._context_handle = None
    self._context_devices = None
    self._seed = None
    self._initialize_lock = threading.Lock()
    self._initialized = False
    if device_policy is None:
      device_policy = DEVICE_PLACEMENT_SILENT
    self._device_policy = device_policy
    self._mirroring_policy = None
    if execution_mode not in (None, SYNC, ASYNC):
      raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" %
                       execution_mode)
    if execution_mode is None:
      execution_mode = SYNC
    self._default_is_async = execution_mode == ASYNC
    self._use_tfrt = is_tfrt_enabled()
    self._use_tfrt_distributed_runtime = None
    self._run_eager_op_as_function = run_eager_op_as_function_enabled()
    self._server_def = server_def
    self._collective_ops_server_def = None
    self._collective_leader = None
    self._collective_scoped_allocator_enabled_ops = None
    self._collective_use_nccl_communication = None
    self._collective_device_filters = None
    self._coordination_service = None

    self._device_lock = threading.Lock()
    self._physical_devices = None
    self._physical_device_to_index = None
    self._visible_device_list = []
    self._memory_growth_map = None
    self._virtual_device_map = {}

    # Values set after construction
    self._optimizer_jit = None
    self._intra_op_parallelism_threads = None
    self._inter_op_parallelism_threads = None
    self._soft_device_placement = None
    self._log_device_placement = None
    self._enable_mlir_graph_optimization = None
    self._optimizer_experimental_options = {}

    _python_eager_context_create_counter.get_cell().increase_by(1)
```
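The constructor's handling of device_policy and execution_mode is a small piece of defaulting and validation logic. The sketch below restates it, together with the four placement policies exactly as the docstring describes them. DevicePolicy, SYNC/ASYNC as strings, toy_resolve_options and toy_misplaced_input are hypothetical stand-ins. The real constants come from pywrap_tfe, and the real placement checks happen in the C++ runtime.

```python
import enum
import warnings


class DevicePolicy(enum.Enum):
    """Hypothetical stand-ins for the DEVICE_PLACEMENT_* constants."""
    EXPLICIT = "explicit"
    WARN = "warn"
    SILENT = "silent"
    SILENT_FOR_INT32 = "silent_for_int32"


SYNC, ASYNC = "sync", "async"  # hypothetical stand-ins for context.SYNC/ASYNC


def toy_resolve_options(device_policy=None, execution_mode=None):
    """Apply the same defaults and validation as Context.__init__ (sketch)."""
    if device_policy is None:
        device_policy = DevicePolicy.SILENT
    if execution_mode not in (None, SYNC, ASYNC):
        raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" %
                         execution_mode)
    if execution_mode is None:
        execution_mode = SYNC
    return device_policy, execution_mode == ASYNC  # (policy, default_is_async)


def toy_misplaced_input(policy, dtype):
    """What each policy does when an op input lives on another device."""
    if policy is DevicePolicy.EXPLICIT:
        raise RuntimeError("input is not on the op's device")
    if policy is DevicePolicy.WARN:
        warnings.warn("copying input to the op's device")
        return "copied"
    if policy is DevicePolicy.SILENT_FOR_INT32 and dtype != "int32":
        raise RuntimeError("only int32 inputs are copied silently")
    return "copied"


policy, is_async = toy_resolve_options()
print(policy.name, is_async)                                         # SILENT False
print(toy_misplaced_input(DevicePolicy.SILENT_FOR_INT32, "int32"))   # copied
```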
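Next, let's walk through the initialization flow.
3.1 ensure_initialized
The Python Context is a wrapper around the C++ context, and ensure_initialized is the method that guarantees the wrapped context has been initialized. The module-level function simply forwards to the global context object: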
```python
def ensure_initialized():
  """Initialize the context."""
  context().ensure_initialized()
```
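The method itself is shown below. It creates a TFE_ContextOptions object and fills it through a series of setters named TFE_ContextOptionsSetXXX. It then creates the C++ context with TFE_NewContext. If a server_def was supplied, it finally calls TFE_ContextSetServerDef to connect the context to the cluster: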
```python
def ensure_initialized(self):
  """Initialize handle and devices if not already done so."""
  if self._initialized:
    return
  with self._initialize_lock:
    if self._initialized:
      return
    assert self._context_devices is None
    opts = pywrap_tfe.TFE_NewContextOptions()
    try:
      config_str = self.config.SerializeToString()
      pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
      if self._device_policy is not None:
        pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
            opts, self._device_policy)
      if self._mirroring_policy is not None:
        pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
            opts, self._mirroring_policy)
      if self._default_is_async == ASYNC:
        pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
      if self._use_tfrt is not None:
        pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt)
      if self._use_tfrt is not None and \
          self._use_tfrt_distributed_runtime is not None:
        pywrap_tfe.TFE_ContextOptionsSetTfrtDistributedRuntime(
            opts, self._use_tfrt_distributed_runtime)
      pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(
          opts, self._run_eager_op_as_function)
      context_handle = pywrap_tfe.TFE_NewContext(opts)
    finally:
      pywrap_tfe.TFE_DeleteContextOptions(opts)
    if self._server_def is not None:
      server_def_str = self._server_def.SerializeToString()
      pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
                                         server_def_str)
    elif self._collective_ops_server_def is not None:
      server_def_str = self._collective_ops_server_def.SerializeToString()
      pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)

    self._context_handle = context_handle
    self._initialize_logical_devices()
    self._initialized = True
```
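ensure_initialized is called repeatedly (every _MirroredReplicaThread constructor calls it) and may be reached from several threads. It therefore uses double-checked locking: an unlocked fast-path check, then a second check under _initialize_lock. The sketch below isolates that pattern. ToyContext and init_count are hypothetical, and object() merely stands in for the handle that TFE_NewContext returns.

```python
import threading


class ToyContext:
    """Hypothetical sketch of the lazy, thread-safe initialization pattern."""

    def __init__(self):
        self._initialize_lock = threading.Lock()
        self._initialized = False
        self._context_handle = None
        self.init_count = 0  # only here to show how often initialization runs

    def ensure_initialized(self):
        if self._initialized:        # fast path: no locking once initialized
            return
        with self._initialize_lock:
            if self._initialized:    # another thread finished first
                return
            self.init_count += 1
            self._context_handle = object()  # stands in for TFE_NewContext(opts)
            self._initialized = True         # set last, after the handle exists


ctx = ToyContext()
workers = [threading.Thread(target=ctx.ensure_initialized) for _ in range(8)]
for w in workers:
    w.start()
for w in workers:
    w.join()
print(ctx.init_count)  # 1
```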
3.2 TFE_ContextSetServerDef
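Let's follow TFE_ContextSetServerDef, whose code is in tensorflow/c/eager/c_api.cc. After parsing the serialized ServerDef, it calls SetOrUpdateServerDef on the object returned by GetDistributedManager(), with reset_context set to true.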
```cpp
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
                                                   int keep_alive_secs,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextSetServerDef not supported on mobile");
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
          server_def, /*reset_context=*/true, keep_alive_secs);
#endif  // !IS_MOBILE_PLATFORM
}
```
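TFE_ContextSetServerDef is a thin C-API shim with three steps. It rejects mobile builds, it parses the serialized bytes, and it then delegates to the distributed manager with reset_context fixed to true. The sketch below mirrors only that control flow in Python. JSON stands in for the ServerDef protobuf, and ToyDistributedManager and toy_context_set_server_def are hypothetical names, not TF bindings.

```python
import json


class ToyDistributedManager:
    """Hypothetical stand-in for EagerContextDistributedManager."""

    def __init__(self):
        self.calls = []

    def set_or_update_server_def(self, server_def, reset_context, keep_alive_secs):
        self.calls.append((server_def, reset_context, keep_alive_secs))
        return "OK"


def toy_context_set_server_def(manager, keep_alive_secs, proto_bytes,
                               is_mobile=False):
    """Mirror the control flow of TFE_ContextSetServerDef (sketch).

    JSON stands in for the serialized ServerDef protocol buffer.
    """
    if is_mobile:
        return "Unimplemented: TFE_ContextSetServerDef not supported on mobile"
    try:
        server_def = json.loads(proto_bytes)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        return "InvalidArgument: Invalid tensorflow.ServerDef protocol buffer"
    # This entry point always requests a full context reset.
    return manager.set_or_update_server_def(
        server_def, reset_context=True, keep_alive_secs=keep_alive_secs)


manager = ToyDistributedManager()
payload = json.dumps({"job_name": "worker", "task_index": 0}).encode()
print(toy_context_set_server_def(manager, 600, payload))  # OK
```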
3.3 EagerContextDistributedManager
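The code of EagerContextDistributedManager is in tensorflow/core/common_runtime/eager/context_distributed_manager.cc. When the context is being reset, SetOrUpdateServerDef first registers the cluster device filters, if any, for each remote task. It then calls UpdateContextWithServerDef.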
```cpp
Status EagerContextDistributedManager::SetOrUpdateServerDef(
    const ServerDef& server_def, bool reset_context, int keep_alive_secs) {
  if (server_def.has_cluster_device_filters()) {
    if (reset_context) {
      const auto& cdf = server_def.cluster_device_filters();
      for (const auto& jdf : cdf.jobs()) {
        const string remote_prefix = "/job:" + jdf.name() + "/task:";
        for (const auto& tdf : jdf.tasks()) {
          const int32_t task_index = tdf.first;
          std::vector<string> device_filters(tdf.second.device_filters_size());
          for (int i = 0; i < tdf.second.device_filters_size(); i++) {
            device_filters[i] = tdf.second.device_filters(i);
          }
          const string remote_worker =
              strings::StrCat(remote_prefix, task_index);
          TF_RETURN_IF_ERROR(
              context_->SetRemoteDeviceFilters(remote_worker, device_filters));
        }
      }
    }
  }
  // Hand over to UpdateContextWithServerDef
  return UpdateContextWithServerDef(context_, server_def, reset_context,
                                    keep_alive_secs);
}
```
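The nested loop above flattens the per-job, per-task device filters into one entry per remote task, named "/job:<job>/task:<index>". Each entry is then passed to SetRemoteDeviceFilters. The hypothetical helper below performs the same flattening on a plain dict, so the resulting task names can be checked. The dict layout {job: {task_index: [filters]}} is an assumed simplification of the ClusterDeviceFilters proto.

```python
def toy_remote_device_filters(cluster_device_filters):
    """Flatten {job_name: {task_index: [filter, ...]}} into
    {"/job:<job_name>/task:<task_index>": [filter, ...]} (hypothetical helper).
    """
    result = {}
    for job_name, tasks in cluster_device_filters.items():
        remote_prefix = "/job:" + job_name + "/task:"
        for task_index, device_filters in tasks.items():
            # In C++, each pair here becomes one SetRemoteDeviceFilters call.
            result[remote_prefix + str(task_index)] = list(device_filters)
    return result


filters = {"worker": {0: ["/job:ps"], 1: ["/job:ps", "/job:worker/task:1"]}}
print(toy_remote_device_filters(filters))
# {'/job:worker/task:0': ['/job:ps'], '/job:worker/task:1': ['/job:ps', '/job:worker/task:1']}
```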
3.4 UpdateContextWithServerDef
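UpdateContextWithServerDef has several key steps:
- It creates a DistributedFunctionLibraryRuntime (cluster_flr).
- It builds a CreateContextRequest and calls CreateRemoteContexts to send it to the remote workers.
Here we see a series of familiar names, such as grpc_server, curr_remote_workers, master_env and worker_session. All of them are runtime concepts we met earlier. So although the Session API is gone, the runtime still uses these concepts internally, and Context simply reorganizes and encapsulates them. Note also that the RemoteMgr is created with is_master=true and that the context is initialized through InitializeRemoteMaster. In other words, the process that runs this code acts as the master of the cluster.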
```cpp
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
    tensorflow::eager::CreateClusterFLR(context_id, context,
                                        worker_session.get());
auto remote_mgr = std::make_unique<tensorflow::eager::RemoteMgr>(
    /*is_master=*/true, context);

LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
    std::move(new_server), grpc_server->worker_env(), worker_session,
    std::move(remote_eager_workers), std::move(new_remote_device_mgr),
    remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
    std::move(remote_mgr)));
```
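The full listing that follows contains an update path (reset_context == false). In that path the master compares the old and new worker lists. It removes devices of removed workers, adds devices of added workers, and creates remote contexts on the added workers with context_view_id + 1. A worker that keeps its task name but is backed by a new process counts as "replaced" and is treated as removed and then added back. The sketch below restates that bookkeeping with Python sets. toy_worker_name and toy_differentiate_workers are hypothetical helpers standing in for the C++ DifferentiateWorkerLists and the replaced-worker handling, and the replaced list is supplied by hand instead of being detected over RPC.

```python
def toy_worker_name(job_name, task_index):
    # Same format as worker_name in UpdateContextWithServerDef.
    return "/job:%s/replica:0/task:%d" % (job_name, task_index)


def toy_differentiate_workers(curr_remote_workers, remote_workers,
                              replaced_workers=()):
    """Split workers the way the update (non-reset) branch does (sketch).

    replaced_workers: existing workers that are now backed by a different
    process; they are moved from `existing` to both `removed` and `added`.
    """
    curr, new = set(curr_remote_workers), set(remote_workers)
    added = sorted(new - curr)
    removed = sorted(curr - new)
    existing = sorted(curr & new)
    for w in replaced_workers:
        if w in existing:
            existing.remove(w)
            removed.append(w)
            added.append(w)
    return added, removed, existing


old = [toy_worker_name("worker", i) for i in (0, 1, 2)]
new = [toy_worker_name("worker", i) for i in (1, 2, 3)]
added, removed, existing = toy_differentiate_workers(
    old, new, replaced_workers=[toy_worker_name("worker", 2)])
print(added)     # ['/job:worker/replica:0/task:3', '/job:worker/replica:0/task:2']
print(removed)   # ['/job:worker/replica:0/task:0', '/job:worker/replica:0/task:2']
print(existing)  # ['/job:worker/replica:0/task:1']
```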
The full code of UpdateContextWithServerDef is as follows:
tensorflow::Status UpdateContextWithServerDef(
EagerContext* context, const tensorflow::ServerDef& server_def,
bool reset_context, int keep_alive_secs) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
// message.
#define LOG_AND_RETURN_IF_ERROR(...) \
do { \
const ::tensorflow::Status _status = (__VA_ARGS__); \
if (TF_PREDICT_FALSE(!_status.ok())) { \
LOG(ERROR) << _status.error_message(); \
return _status; \
} \
} while (0);
string worker_name =
tensorflow::strings::StrCat("/job:", server_def.job_name(),
"/replica:0/task:", server_def.task_index());
// List of current remote workers before updating server_def. Unused if
// resetting the server_def.
std::vector<string> curr_remote_workers;
// List of updated remote workers.
std::vector<string> remote_workers;
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
tensorflow::DeviceMgr* device_mgr =
AreLocalDevicesCompatible(context, server_def)
? context->local_device_mgr()
: nullptr;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
server_def, {device_mgr}, &new_server));
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
// No need to check the cast here, since ListRemoteWorkers already checks
// if the server is a GRPC server or not.
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
}
tensorflow::uint64 context_id = context->GetContextId();
tensorflow::uint64 context_view_id = context->GetContextViewId();
if (reset_context) {
context_id = tensorflow::EagerContext::NewContextId();
context_view_id = 0;
// Make master eager context accessible by local eager service, which might
// receive send tensor requests from remote workers.
LOG_AND_RETURN_IF_ERROR(
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
}
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
LOG_AND_RETURN_IF_ERROR(
grpc_server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers));
// For cluster update, use a status group to aggregate statuses from
// * adding and removing remote devices
// * creating remote contexts on newly added workers
// * updating remote contexts on existing workers
// * updating the master context
// Note that we should not return immediately on errors in the middle of these
// updates to prevent cluster from having inconsistent context views.
//
// Unused if reset_context is True.
tensorflow::StatusGroup sg;
// When updating an existing context, populate the following lists with:
// * added_workers: set(remote_workers) - set(curr_remote_workers)
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
// * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
// * replaced_workers: workers with the same task names and potentially the
// same hostname:port s, but replaced by different processes
std::vector<string> added_workers;
std::vector<string> removed_workers;
std::vector<string> existing_workers;
std::vector<string> replaced_workers;
// New remote device manager created for new server_def. Unused if updating
// server_def.
std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
remote_workers, grpc_server->master_env()->worker_cache,
&new_remote_device_mgr));
remote_device_mgr = new_remote_device_mgr.get();
} else {
context->ClearCachesAndDefaultExecutor();
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
std::sort(remote_workers.begin(), remote_workers.end());
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
&added_workers, &removed_workers,
&existing_workers);
sg.Update(GetReplacedFromExistingWorkers(
&existing_workers, context_id, context->GetContextViewId(), server_def,
remote_eager_workers.get(), &replaced_workers));
if (!replaced_workers.empty()) {
// Treat replaced workers as removed then added back, so that we recreate
// remote devices and contexts, and re-register functions on those workers
removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
replaced_workers.end());
added_workers.insert(added_workers.end(), replaced_workers.begin(),
replaced_workers.end());
for (const string& w : replaced_workers) {
existing_workers.erase(
std::remove(existing_workers.begin(), existing_workers.end(), w),
existing_workers.end());
}
}
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
sg.Update(AddRemoteDevicesToMgr(added_workers,
grpc_server->master_env()->worker_cache,
remote_device_mgr));
}
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
// This request makes sure that we can create Rendezvous properly between
// Local and Remote context.
tensorflow::eager::CreateContextRequest base_request; // Build the base CreateContextRequest
for (const auto& da : cluster_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
// Initialize remote eager workers.
if (reset_context) {
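const tensorflow::Status s = CreateRemoteContexts(
context, remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
base_request);
// Remote tasks may fail after GetAllRemoteDevices. Only log the error here
// instead of returning it, so the server object is not destroyed; later ops
// sent to the failed workers will surface further errors.
if (TF_PREDICT_FALSE(!s.ok())) {
LOG(ERROR) << "Error when creating contexts on remote targets: "
<< s.error_message();
}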
} else {
if (sg.ok()) {
// Create remote contexts on the newly added workers only if the master
// has collected all device information from them (i.e., the
// GetAllRemoteDevices call returns successfully). Note that in rare cases
// GetAllRemoteDevices can still fail even with RPCs configured to wait
// until the remote workers become alive. If the master creates remote
// contexts on the workers whose devices are still not collected, those
// workers will be treated as existing workers subsequently, so the master
// will never get devices from them even with retrying UpdateServerDef.
sg.Update(CreateRemoteContexts(
context, added_workers, context_id, context_view_id + 1,
keep_alive_secs, server_def, remote_eager_workers.get(),
context->Executor().Async(), base_request));
}
if (!existing_workers.empty()) {
// The master's context_view_id will be incremented by one in the
// UpdateRemoteMaster call later. We want existing workers to also have
// the updated context_view_id, so we must set their context_view_id to
// the master's current context_view_id + 1.
sg.Update(UpdateRemoteContexts(context, existing_workers, added_workers,
removed_workers, context_id,
context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
}
}
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
if (reset_context) {
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
auto* device_mgr = grpc_server->worker_env()->device_mgr;
std::shared_ptr<tensorflow::WorkerSession> worker_session;
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->CreateSession(
session_name, server_def, base_request.cluster_device_attributes(),
true));
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
// Initialize remote tensor communication based on worker session.
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
auto remote_mgr = std::make_unique<tensorflow::eager::RemoteMgr>(
/*is_master=*/true, context);
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
std::move(new_server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
std::move(remote_mgr)));
// NOTE: We start the server after all other initialization, because the
// GrpcServer cannot be destroyed after it is started.
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
} else {
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
session_name, server_def, base_request.cluster_device_attributes(),
/*isolate_session_state=*/true));
sg.Update(context->UpdateRemoteMaster(context_id,
std::move(remote_eager_workers),
added_workers, removed_workers));
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
}
#undef LOG_AND_RETURN_IF_ERROR
return tensorflow::Status::OK();
}
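To make the bookkeeping in the update branch concrete, here is a minimal Python sketch of how the four worker lists relate. It is not TensorFlow code: the function name differentiate_worker_lists and the replaced argument (which the C++ code obtains from GetReplacedFromExistingWorkers) are assumptions for illustration.

```python
from typing import Iterable, List, Set, Tuple


def differentiate_worker_lists(
    current: Iterable[str], requested: Iterable[str], replaced: Set[str]
) -> Tuple[List[str], List[str], List[str]]:
    """Hypothetical sketch: return (added, removed, existing) worker names.

    `replaced` models workers that kept their task name but now run in a new
    process; they are treated as removed and then added back, so that their
    remote devices and contexts are recreated.
    """
    cur, req = set(current), set(requested)
    added = sorted(req - cur)      # set(remote_workers) - set(curr_remote_workers)
    removed = sorted(cur - req)    # set(curr_remote_workers) - set(remote_workers)
    existing = sorted(cur & req)   # intersection of both sets
    rep = [w for w in existing if w in replaced]
    existing = [w for w in existing if w not in replaced]
    return added + rep, removed + rep, existing
```

For example, moving from workers {w0, w1} to {w1, w2} while w1 was restarted yields added = [w2, w1], removed = [w0, w1] and existing = [], so w1 is torn down and registered again.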
3.5 CreateRemoteContexts
The CreateRemoteContexts method creates a context on each remote worker. Because this involves remote processes, it goes through the gRPC mechanism.
tensorflow::Status CreateRemoteContexts(
EagerContext* context, const std::vector<string>& remote_workers,
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
counter.DecrementCount();
continue;
}
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
std::vector<bool> filtered_device_mask;
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(),
base_request.cluster_device_attributes_size());
// Use j so the inner loop does not shadow the outer worker index i.
for (int j = 0; j < filtered_device_mask.size(); j++) {
if (filtered_device_mask[j]) {
const auto& da = base_request.cluster_device_attributes(j);
*request.add_cluster_device_attributes() = da;
}
}
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
request.set_lazy_copy_remote_function_inputs(true);
eager_client->CreateContextAsync(
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
tensorflow::StatusGroup sg;
for (int i = 0; i < num_remote_workers; i++) {
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
sg.Update(statuses[i]);
}
}
return sg.as_summary_status();
}
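The core of CreateRemoteContexts is a fan-out/fan-in: fire one asynchronous RPC per worker, let each callback record its own status and decrement a counter, then wait until the counter reaches zero and merge the errors. The Python sketch below mirrors that pattern under stated assumptions: BlockingCounter is a small re-implementation rather than the TF class, errors are plain strings joined with "; " instead of a StatusGroup summary, and fake_send_async is a made-up transport used only for demonstration.

```python
import threading
from typing import Callable, List, Optional

Status = Optional[str]  # None means OK; a string is an error message


class BlockingCounter:
    """Minimal stand-in for tensorflow::BlockingCounter."""

    def __init__(self, initial_count: int) -> None:
        self._count = initial_count
        self._cv = threading.Condition()

    def decrement_count(self) -> None:
        with self._cv:
            self._count -= 1
            if self._count == 0:
                self._cv.notify_all()

    def wait(self) -> None:
        with self._cv:
            while self._count > 0:
                self._cv.wait()


def create_remote_contexts(
    workers: List[str], send_async: Callable[[str, Callable[[Status], None]], None]
) -> Status:
    statuses: List[Status] = [None] * len(workers)
    counter = BlockingCounter(len(workers))
    for i, worker in enumerate(workers):
        def done(status: Status, i: int = i) -> None:
            statuses[i] = status       # each callback writes only its own slot
            counter.decrement_count()
        send_async(worker, done)       # returns immediately, like CreateContextAsync
    counter.wait()                     # block until every RPC has completed
    errors = [s for s in statuses if s is not None]
    return "; ".join(errors) if errors else None


def fake_send_async(worker: str, done: Callable[[Status], None]) -> None:
    """Hypothetical transport: completes on another thread, fails for 'bad*'."""
    status = f"Unavailable: {worker}" if worker.startswith("bad") else None
    threading.Thread(target=done, args=(status,)).start()
```

Pre-allocating one status slot per worker means the callbacks never contend on a shared list; only the counter needs synchronization.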
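3.6 CreateContextAsync
The CreateContextAsync method sends the CreateContextRequest to the remote worker.
3.6.1 EagerClient
EagerClient is the abstract client interface of the eager service; gRPC is one transport that implements it.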
// This is a base class that can be implemented by a variety of
// transports (e.g. gRPC which for each of the client methods makes an RPC).
class EagerClient : public core::RefCounted {
public:
~EagerClient() override {}
#define CLIENT_METHOD(method) \
virtual void method##Async(const method##Request* request, \
method##Response* response, \
StatusCallback done) = 0;
CLIENT_METHOD(CreateContext);
CLIENT_METHOD(UpdateContext);
CLIENT_METHOD(WaitQueueDone);
CLIENT_METHOD(KeepAlive);
CLIENT_METHOD(CloseContext);
#undef CLIENT_METHOD
#define CLIENT_CANCELABLE_METHOD(method) \
virtual void method##Async( \
CallOptions* call_opts, const method##Request* request, \
method##Response* response, StatusCallback done) = 0;
CLIENT_CANCELABLE_METHOD(Enqueue);
CLIENT_CANCELABLE_METHOD(RunComponentFunction);
#undef CLIENT_CANCELABLE_METHOD
// Feeds request into the request stream of EagerService::StreamingEnqueue.
// response will be filled with the response for this request. The
// 1-to-1 correspondence between requests and responses is a property
// of the current service implementation. When the response is received,
// done is invoked with the current status of the StreamingEnqueue call.
// The status can contain an error because of an earlier request in the
// current streaming call.
// The client initiates a streaming call the first time StreamingEnqueueAsync
// is invoked and keeps it open until some error condition.
// Similarly to the methods above, the request can be deleted as soon as
// StreamingEnqueueAsync returns.
virtual void StreamingEnqueueAsync(CallOptions* call_opts,
const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) = 0;
virtual bool allow_multiple_pending_requests() const = 0;
};
3.6.2 GrpcEagerClient
GrpcEagerClient is the gRPC implementation of EagerClient.
class GrpcEagerClient : public EagerClient {
public:
GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
GrpcEagerClientThread* thread, const string& target)
: stub_(channel), thread_(thread), target_(target) {
// Hold a reference to make sure the corresponding EagerClientThread
// outlives the client.
thread_->Ref();
cq_ = thread->completion_queue();
}
~GrpcEagerClient() override { thread_->Unref(); }
bool allow_multiple_pending_requests() const override {
return EnableStreaming();
}
#define CLIENT_METHOD(method) \
void method##Async(const method##Request* request, \
method##Response* response, StatusCallback done) \
override { \
StatusCallback done_wrapped = callback_wrapper(std::move(done)); \
new RPCState<protobuf::Message>( \
&stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
response, std::move(done_wrapped), /*call_opts=*/nullptr, \
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true, \
&target_); \
}
CLIENT_METHOD(CreateContext);
CLIENT_METHOD(UpdateContext);
CLIENT_METHOD(WaitQueueDone);
CLIENT_METHOD(KeepAlive);
#undef CLIENT_METHOD
#define CLIENT_CANCELABLE_METHOD(method) \
void method##Async(CallOptions* call_opts, const method##Request* request, \
method##Response* response, StatusCallback done) \
override { \
StatusCallback done_wrapped = callback_wrapper(std::move(done)); \
new RPCState<protobuf::Message>( \
&stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
response, std::move(done_wrapped), call_opts, /*threadpool=*/nullptr, \
/*max_retries=*/0, /*fail_fast=*/true, &target_); \
}
CLIENT_CANCELABLE_METHOD(Enqueue);
CLIENT_CANCELABLE_METHOD(RunComponentFunction);
#undef CLIENT_CANCELABLE_METHOD
void CloseContextAsync(const CloseContextRequest* request,
CloseContextResponse* response,
StatusCallback done) override {
StatusCallback done_wrapped = callback_wrapper(std::move(done));
new RPCState<protobuf::Message>(
&stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
response, std::move(done_wrapped), /*call_opts=*/nullptr,
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
&target_);
mutex_lock l(mu_);
const auto& it = enqueue_dispatchers_.find(request->context_id());
if (it != enqueue_dispatchers_.end()) {
it->second.CancelCall();
enqueue_dispatchers_.erase(it);
} else if (EnableStreaming()) {
LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
<< " does not seem to exist.";
}
}
void StreamingEnqueueAsync(CallOptions* call_opts,
const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) override {
StatusCallback done_wrapped = callback_wrapper(std::move(done));
if (EnableStreaming()) {
mutex_lock l(mu_);
auto it = enqueue_dispatchers_.find(request->context_id());
if (it == enqueue_dispatchers_.end()) {
auto it_and_bool = enqueue_dispatchers_.emplace(
std::piecewise_construct,
std::forward_as_tuple(request->context_id()),
std::forward_as_tuple(
&stub_, cq_,
"/tensorflow.eager.EagerService/StreamingEnqueue"));
it = it_and_bool.first;
}
// TODO(haoyuzhang): Consider supporting cancellation for streaming RPC?
it->second.SendNextRequest(*request, response, std::move(done_wrapped));
} else {
Notification n;
Status status;
EnqueueAsync(call_opts, request, response,
[&n, &status](const Status& s) {
status.Update(s);
n.Notify();
});
n.WaitForNotification();
done_wrapped(status);
}
}
private:
::grpc::GenericStub stub_;
const GrpcEagerClientThread* thread_;
const string target_;
::grpc::CompletionQueue* cq_;
mutable mutex mu_;
std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
enqueue_dispatchers_ TF_GUARDED_BY(mu_);
StatusCallback callback_wrapper(StatusCallback done) {
Ref();
return [this, done = std::move(done)](const Status& status) {
done(status);
this->Unref();
};
}
};
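GrpcEagerClient wraps every done callback with callback_wrapper, which takes a reference (Ref) before the RPC is issued and drops it (Unref) after the callback has run. This keeps the client alive while any RPC is still in flight, even if its owner has already released it. The Python sketch below imitates that idea; RefCounted and ToyEagerClient are hypothetical stand-ins, and setting destroyed = True replaces the real delete.

```python
import threading
from typing import Callable


class RefCounted:
    """Minimal stand-in for core::RefCounted; starts with one reference."""

    def __init__(self) -> None:
        self._refs = 1
        self._lock = threading.Lock()
        self.destroyed = False

    def ref(self) -> None:
        with self._lock:
            self._refs += 1

    def unref(self) -> None:
        with self._lock:
            self._refs -= 1
            last = self._refs == 0
        if last:
            self.destroyed = True  # stands in for `delete this`


class ToyEagerClient(RefCounted):
    """Hypothetical client showing the callback_wrapper idea."""

    def callback_wrapper(self, done: Callable[[str], None]) -> Callable[[str], None]:
        self.ref()  # the pending RPC holds its own reference to the client

        def wrapped(status: str) -> None:
            done(status)
            self.unref()  # released only after the user callback has run

        return wrapped
```

If the owner drops its reference while an RPC is pending, the client survives until that RPC's callback fires, which is exactly the lifetime guarantee the C++ lambda provides.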
This gives us the following logic so far:
Figure: Context-related logic
0x4. Communication Protocol
Here we meet tensorflow/core/protobuf/eager_service.proto, a file we came across during the earlier runtime analysis but never examined, so let's look at it now.
4.1 Creating a Remote Context
First, let's see how a remote context is created. The messages are defined as follows:
message CreateContextRequest {
// Identifies the full cluster, and this particular worker's position within.
ServerDef server_def = 1;
// Whether the ops on the worker should be executed synchronously or
// asynchronously. By default, ops are executed synchronously.
bool async = 2;
// Number of seconds to keep the context alive. If more than keep_alive_secs
// has passed since a particular context has been communicated with, it will
// be garbage collected.
int64 keep_alive_secs = 3;
// This is the version for all the ops that will be enqueued by the client.
VersionDef version_def = 4;
// Device attributes in the cluster
repeated DeviceAttributes cluster_device_attributes = 6;
// The ID of the created context. This is usually a randomly generated number,
// that will be used to identify the context in future requests to the
// service. Contexts are not persisted through server restarts.
// This ID will be used for all future communications as well. It is essential
// that both ends use this ID for selecting a rendezvous to get everything to
// match.
fixed64 context_id = 7;
// The view ID of the context.
fixed64 context_view_id = 8;
// For a multi device function, if false, eagerly copy all remote inputs to
// the default function device; if true, lazily copy remote inputs to their
// target devices after function instantiation to avoid redundant copies.
bool lazy_copy_remote_function_inputs = 9;
reserved 5;
}
message CreateContextResponse {
// List of devices that are locally accessible to the worker.
repeated DeviceAttributes device_attributes = 2;
reserved 1;
}
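These fields are filled in by CreateRemoteContexts: every worker receives the same server_def, but with job_name and task_index overwritten so that the worker learns its own position in the cluster (as the comment on server_def says). The Python sketch below models the request as a plain dict. The regular expression is a simplified assumption (TF uses DeviceNameUtils::ParseFullName, which accepts more forms), and the device filtering of cluster_device_attributes is left out.

```python
import copy
import re
from typing import Any, Dict, Tuple

# Simplified pattern; TF's DeviceNameUtils::ParseFullName accepts more forms.
_WORKER_NAME = re.compile(r"^/job:([^/]+)/replica:(\d+)/task:(\d+)$")


def parse_worker_name(name: str) -> Tuple[str, int]:
    """Return (job, task) from a name like '/job:worker/replica:0/task:1'."""
    m = _WORKER_NAME.match(name)
    if m is None:
        raise ValueError(f"cannot parse worker name: {name!r}")
    return m.group(1), int(m.group(3))


def make_create_context_request(
    server_def: Dict[str, Any], worker: str, context_id: int,
    context_view_id: int, keep_alive_secs: int, is_async: bool,
) -> Dict[str, Any]:
    job, task = parse_worker_name(worker)
    sd = copy.deepcopy(server_def)  # each worker gets its own copy
    sd["job_name"] = job            # tell the worker who it is
    sd["task_index"] = task
    return {
        "server_def": sd,
        "async": is_async,
        "keep_alive_secs": keep_alive_secs,
        "context_id": context_id,
        "context_view_id": context_view_id,
        "lazy_copy_remote_function_inputs": True,
    }
```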
4.2 Running a Function
Next, the messages used to run a (component) function:
message RunComponentFunctionRequest {
fixed64 context_id = 1;
Operation operation = 2;
// The output indices of its parent function.
repeated int32 output_num = 3;
}
message RunComponentFunctionResponse {
repeated TensorShapeProto shape = 1;
repeated TensorProto tensor = 2;
}
With the protocol in place, let's look at the corresponding service.
0x5. Eager Service
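The Eager service defines a TensorFlow service that eagerly executes operations on a set of local devices on behalf of a remote Eager executor. The service keeps track of the clients and devices it has access to, allows a client to enqueue ops on any device it can access, and schedules data transfers from/to any of its peers.
A client can create multiple contexts in order to execute operations independently, but data cannot be shared between two contexts. Note: even though contexts generated by clients should be independent, the lower-level TensorFlow execution engine is not, so they may share some data (for example, a device's ResourceMgr).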
////////////////////////////////////////////////////////////////////////////////
//
// Eager Service defines a TensorFlow service that executes operations eagerly
// on a set of local devices, on behalf of a remote Eager executor.
//
// The service impl will keep track of the various clients and devices it has
// access to and allows the client to enqueue ops on any devices that it is able
// to access and schedule data transfers from/to any of the peers.
//
// A client can generate multiple contexts to be able to independently execute
// operations, but cannot share data between the two contexts.
//
// NOTE: Even though contexts generated by clients should be independent, the
// lower level tensorflow execution engine is not, so they might share some data
// (e.g. a Device's ResourceMgr).
//
////////////////////////////////////////////////////////////////////////////////
service EagerService {
// This initializes the worker, informing it about the other workers in the
// cluster and exchanging authentication tokens which will be used in all
// other RPCs to detect whether the worker has restarted.
rpc CreateContext(CreateContextRequest) returns (CreateContextResponse);
// This updates the eager context on an existing worker when updating the set
// of servers in a distributed eager cluster.
rpc UpdateContext(UpdateContextRequest) returns (UpdateContextResponse);
// This takes a list of Execute and DeleteTensorHandle operations and enqueues
// (in async mode) or executes (in sync mode) them on the remote server.
// All outputs of ops which were not explicitly deleted with
// DeleteTensorHandle entries will be assumed to be alive and are usable by
// future calls to Enqueue.
rpc Enqueue(EnqueueRequest) returns (EnqueueResponse);
// A streaming version of Enqueue.
// Current server implementation sends one response per received request.
// The benefit for using a streaming version is that subsequent requests
// can be sent without waiting for a response to the previous request. This
// synchronization is required in the regular Enqueue call because gRPC does
// not guarantee to preserve request order.
rpc StreamingEnqueue(stream EnqueueRequest) returns (stream EnqueueResponse);
// Takes a set of op IDs and waits until those ops are done. Returns any error
// in the stream so far.
rpc WaitQueueDone(WaitQueueDoneRequest) returns (WaitQueueDoneResponse);
// This takes an Eager operation and executes it in async mode on the remote
// server. Different from EnqueueRequest, ops/functions sent through this
// type of requests are allowed to execute in parallel and no ordering is
// preserved by RPC stream or executor.
// This request type should only be used for executing component functions.
// Ordering of component functions should be enforced by their corresponding
// main functions. The runtime ensures the following invariants for component
// functions (CFs) and their main functions (MFs):
// (1) MF1 -> MF2 ==> CF1 -> CF2 ("->" indicates order of execution);
// (2) MF1 || MF2 ==> CF1 || CF2 ("||" indicates possible parallel execution);
// (3) For CF1 and CF2 that come from the same MF, CF1 || CF2
// For executing ops/main functions, use Enqueue or StreamingEnqueue instead
// for correct ordering.
rpc RunComponentFunction(RunComponentFunctionRequest)
returns (RunComponentFunctionResponse);
// Contexts are always created with a deadline and no RPCs within a deadline
// will trigger a context garbage collection. KeepAlive calls can be used to
// delay this. It can also be used to validate the existence of a context ID
// on remote eager worker. If the context is on remote worker, return the same
// ID and the current context view ID. This is useful for checking if the
// remote worker (potentially with the same task name and hostname / port) is
// replaced with a new process.
rpc KeepAlive(KeepAliveRequest) returns (KeepAliveResponse);
// Closes the context. No calls to other methods using the existing context ID
// are valid after this.
rpc CloseContext(CloseContextRequest) returns (CloseContextResponse);
}
5.1 AsyncServiceInterface
AsyncServiceInterface is the asynchronous interface for handling RPCs; GrpcEagerServiceImpl, covered next, derives from it.
// Represents an abstract asynchronous service that handles incoming
// RPCs with a polling loop.
class AsyncServiceInterface {
public:
virtual ~AsyncServiceInterface() {}
// A blocking method that should be called to handle incoming RPCs.
// This method will block until the service shuts down.
virtual void HandleRPCsLoop() = 0;
// Starts shutting down this service.
//
// NOTE(mrry): To shut down this service completely, the caller must
// also shut down any servers that might share ownership of this
// service's resources (e.g. completion queues).
virtual void Shutdown() = 0;
};
5.2 GrpcEagerServiceImpl
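GrpcEagerServiceImpl is the gRPC service. Its HandleRPCsLoop runs on a GrpcServer thread, and each handler schedules the actual work on env_->compute_pool (streaming Enqueue requests use a dedicated single thread instead). The key member is local_impl_, an EagerServiceImpl that implements the actual business logic: when a message arrives, the handler calls local_impl_.method(&call->request, &call->response) to run it.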
EagerServiceImpl local_impl_;
GrpcEagerServiceImpl is defined as follows:
// This class is a wrapper that handles communication for gRPC.
class GrpcEagerServiceImpl : public AsyncServiceInterface {
public:
template <class RequestMessage, class ResponseMessage>
using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
RequestMessage, ResponseMessage>;
template <class RequestMessage, class ResponseMessage>
using StreamingCall =
ServerBidirectionalStreamingCall<GrpcEagerServiceImpl,
grpc::EagerService::AsyncService,
RequestMessage, ResponseMessage>;
GrpcEagerServiceImpl(const WorkerEnv* env,
::grpc::ServerBuilder* server_builder);
virtual ~GrpcEagerServiceImpl() {}
// Create a master context in eager service.
Status CreateMasterContext(const tensorflow::uint64 context_id,
EagerContext* context);
void HandleRPCsLoop() override;
void Shutdown() override;
private:
#define HANDLER(method) \
void method##Handler(EagerCall<method##Request, method##Response>* call) { \
env_->compute_pool->Schedule([this, call]() { \
call->SendResponse( \
ToGrpcStatus(local_impl_.method(&call->request, &call->response))); \
}); \
Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequest(&service_, cq_.get(), \
&grpc::EagerService::AsyncService::Request##method, \
&GrpcEagerServiceImpl::method##Handler, false); \
}
HANDLER(CreateContext);
HANDLER(UpdateContext);
HANDLER(WaitQueueDone);
HANDLER(KeepAlive);
HANDLER(CloseContext);
#undef HANDLER
void EnqueueHandler(EagerCall<EnqueueRequest, EnqueueResponse>* call) {
env_->compute_pool->Schedule([this, call]() {
auto call_opts = std::make_shared<CallOptions>();
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
call->SendResponse(ToGrpcStatus(local_impl_.Enqueue(
call_opts.get(), &call->request, &call->response)));
});
Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService, EnqueueRequest,
EnqueueResponse>::
EnqueueRequest(&service_, cq_.get(),
&grpc::EagerService::AsyncService::RequestEnqueue,
&GrpcEagerServiceImpl::EnqueueHandler,
/*supports_cancel=*/true);
}
void RunComponentFunctionHandler(
EagerCall<RunComponentFunctionRequest, RunComponentFunctionResponse>*
call) {
env_->compute_pool->Schedule([this, call]() {
auto call_opts = std::make_shared<CallOptions>();
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
local_impl_.RunComponentFunction(call_opts.get(), &call->request,
&call->response,
[call, call_opts](const Status& s) {
call->ClearCancelCallback();
call->SendResponse(ToGrpcStatus(s));
});
});
Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
RunComponentFunctionRequest, RunComponentFunctionResponse>::
EnqueueRequest(
&service_, cq_.get(),
&grpc::EagerService::AsyncService::RequestRunComponentFunction,
&GrpcEagerServiceImpl::RunComponentFunctionHandler,
/*supports_cancel=*/true);
}
// Called when a new request has been received as part of a StreamingEnqueue
// call.
// StreamingEnqueueHandler gets the request from the call and fills the
// response (also found in call) by invoking the local EagerServiceImpl.
// The local EagerServiceImpl is invoked in a single-threaded thread pool. We
// do this to preserve request order. The local service can parallelize based
// on context_id in request if necessary. Remote contexts are created in async
// mode by default, so the local service impl just puts the request on eager
// executor queue.
void StreamingEnqueueHandler(
StreamingCall<EnqueueRequest, EnqueueResponse>* call) {
call->Ref();
enqueue_streaming_thread_.Schedule([this, call]() {
if (call->RefCountIsOne()) {
// This StreamingCall has already been shut down. Nothing needs to be done.
call->Unref();
return;
}
// NOTE(fishx): Use the address of StreamingCall as the stream_id since we
// reuse the same StreamingCall for multiple requests in the same
// streaming connection.
Status status = local_impl_.Enqueue(
/*call_opts=*/nullptr, &call->request(), call->mutable_response(),
reinterpret_cast<uint64>(static_cast<void*>(call)));
if (status.ok()) {
call->SendResponse();
} else {
call->Finish(ToGrpcStatus(status));
}
call->Unref();
// We do not tell gRPC to accept a new StreamingEnqueue request because
// this method can be called multiple times for a given streaming call.
// The StreamingCall does this per call instead, after a call has been
// opened.
});
}
const WorkerEnv* const env_; // Not owned.
EagerServiceImpl local_impl_;
// A single-threaded thread pool to handle streaming enqueue rpc request.
thread::ThreadPool enqueue_streaming_thread_;
std::unique_ptr<::grpc::Alarm> shutdown_alarm_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
grpc::EagerService::AsyncService service_;
TF_DISALLOW_COPY_AND_ASSIGN(GrpcEagerServiceImpl);
};
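The comment on StreamingEnqueueHandler explains why it uses enqueue_streaming_thread_ instead of env_->compute_pool: a single-threaded pool processes requests in the order they arrive, whereas the multi-threaded compute pool used by the unary handlers gives no ordering guarantee. A minimal Python sketch of that design choice, using ThreadPoolExecutor(max_workers=1) as an assumed analogue of the TF thread pool; ToyStreamingEnqueue is hypothetical.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, List


class ToyStreamingEnqueue:
    """Hypothetical sketch: one worker thread processes every streamed request."""

    def __init__(self) -> None:
        # Plays the role of enqueue_streaming_thread_ (a single-threaded pool).
        self._pool = ThreadPoolExecutor(max_workers=1)
        self.executed: List[Any] = []

    def on_request(self, request: Any) -> Future:
        # Like StreamingEnqueueHandler: hand the request to the single thread.
        return self._pool.submit(self._enqueue, request)

    def _enqueue(self, request: Any) -> str:
        self.executed.append(request)  # stands in for local_impl_.Enqueue(...)
        return f"response:{request}"

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)
```

Because only one thread ever runs _enqueue, the execution order always equals the submission order, which is the property the streaming RPC relies on.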
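5.3 The Service Thread
GrpcServer runs GrpcEagerServiceImpl in its own thread. The listings below are excerpts with most of the code omitted. The service is created in GrpcServer::Init():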
Status GrpcServer::Init(const GrpcServerOptions& opts) {
eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
The thread is started in GrpcServer::Start():
Status GrpcServer::Start() {
mutex_lock l(mu_);
switch (state_) {
case NEW: {
eager_thread_.reset(
env_->StartThread(ThreadOptions(), "TF_eager_service",
[this] { eager_service_->HandleRPCsLoop(); }));
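RPCs are handled in HandleRPCsLoop, which starts by enqueuing pending requests for the RPC methods (only the first one is shown here):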
void GrpcEagerServiceImpl::HandleRPCsLoop() {
#define ENQUEUE_REQUEST(method) \
do { \
Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequest(&service_, cq_.get(), \
&grpc::EagerService::AsyncService::Request##method, \
&GrpcEagerServiceImpl::method##Handler, false); \
} while (0)
ENQUEUE_REQUEST(CreateContext);
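HandleRPCsLoop follows the gRPC asynchronous server model: before a call can be received, the service must post a pending request for that method, and each handler (see the HANDLER macro above) immediately posts a new one for the same method, so a slot is always waiting. The Python sketch below imitates that loop with queue.Queue as a stand-in for the completion queue. ToyEagerService, deliver and the sentinel-based shutdown are assumptions; in TF the handler also schedules the real work on env_->compute_pool rather than running it inline.

```python
import queue
from typing import Dict, List, Tuple

METHODS = ("CreateContext", "UpdateContext", "KeepAlive", "CloseContext")


class ToyEagerService:
    """Hypothetical sketch of the HandleRPCsLoop / re-arm pattern."""

    def __init__(self) -> None:
        self._cq = queue.Queue()  # stands in for the gRPC completion queue
        self.pending: Dict[str, int] = {m: 0 for m in METHODS}
        self.handled: List[Tuple[str, object]] = []

    def _enqueue_request(self, method: str) -> None:
        # Like ENQUEUE_REQUEST: ask for one more incoming call of this method.
        self.pending[method] += 1

    def deliver(self, method: str, payload: object) -> None:
        # Simulates gRPC posting a completed incoming call to the queue.
        self._cq.put((method, payload))

    def shutdown(self) -> None:
        self._cq.put((None, None))  # sentinel that ends the loop

    def handle_rpcs_loop(self) -> None:
        for m in METHODS:
            self._enqueue_request(m)       # prime one slot per method
        while True:
            method, payload = self._cq.get()
            if method is None:
                break
            if self.pending[method] == 0:
                raise RuntimeError(f"no outstanding slot for {method}")
            self.pending[method] -= 1      # this slot has now been consumed
            self.handled.append((method, payload))
            self._enqueue_request(method)  # re-arm, as each *Handler does
```

After any number of calls, every method still has exactly one outstanding slot, which is why two CreateContext calls in a row can both be served.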
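5.4 The Business Logic: EagerServiceImpl
EagerServiceImpl implements the business logic. Only its member variables are shown here; the relevant methods are covered later.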
// A TensorFlow Eager Worker runs ops and supports worker to worker
// Tensor transfer.
//
// See eager_service.proto for more details about each method.
// This class can be wrapped by specific classes that implement rpc transports
// over this (e.g. gRPC).
class EagerServiceImpl {
const WorkerEnv* const env_; // Not owned.
mutex contexts_mu_;
std::unordered_map<uint64, ServerContext*> contexts_
TF_GUARDED_BY(contexts_mu_);
std::unique_ptr<Thread> gc_thread_;
mutex gc_thread_shutdown_mu_;
condition_variable gc_thread_cv_;
bool shutting_down_ TF_GUARDED_BY(gc_thread_shutdown_mu_) = false;
TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl);
};
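The gc_thread_ member, together with the keep_alive_secs field of CreateContextRequest and the KeepAlive RPC, implies a simple lease: a context that has not been communicated with for more than keep_alive_secs is garbage collected. The Python sketch below illustrates only that rule; ToyContextRegistry and its methods are hypothetical, and details of the real ServerContext (reference counting, how often the GC thread wakes up) are left out.

```python
import time
from typing import Callable, Dict, List, Tuple


class ToyContextRegistry:
    """Hypothetical sketch: contexts idle for more than keep_alive_secs expire."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock
        # context_id -> (keep_alive_secs, last_access_time)
        self._contexts: Dict[int, Tuple[float, float]] = {}

    def create(self, context_id: int, keep_alive_secs: float) -> None:
        self._contexts[context_id] = (keep_alive_secs, self._clock())

    def touch(self, context_id: int) -> None:
        # Any RPC on the context (e.g. KeepAlive) refreshes its access time.
        keep_alive, _ = self._contexts[context_id]
        self._contexts[context_id] = (keep_alive, self._clock())

    def gc(self) -> List[int]:
        # What a periodic GC thread (gc_thread_) would do on each wake-up.
        now = self._clock()
        expired = [cid for cid, (keep_alive, last) in self._contexts.items()
                   if now - last > keep_alive]
        for cid in expired:
            del self._contexts[cid]
        return expired

    def __contains__(self, context_id: int) -> bool:
        return context_id in self._contexts
```

Injecting the clock keeps the sketch deterministic: a test can advance time by hand instead of sleeping.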
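5.5 Creating a Remote Context
After receiving a CreateContextRequest, the server first enters GrpcEagerServiceImpl::CreateContextHandler, which then calls EagerServiceImpl::CreateContext. The context_id appears to play a role similar to a session_id: it is used to look up the rendezvous and to name the worker session ("eager_" followed by the context_id). The context takes over the role of the master we analyzed earlier, which is why worker_session appears throughout the code below.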
Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
CreateContextResponse* response) {
{
mutex_lock l(contexts_mu_);
auto context_it = contexts_.find(request->context_id());
if (context_it != contexts_.end()) {
if (request->context_view_id() <
context_it->second->Context()->GetContextViewId()) {
return errors::InvalidArgument("EagerService:CreateContext failed. ",
"Context id: <", request->context_id(),
"> already exists.");
} else {
// For existing context with a stale context_view_id, close the old one
// and recreate with new view id. This is likely due to the worker
// disconnected and then reconnected after one or more cluster updates.
context_it->second->Unref();
contexts_.erase(context_it);
}
}
}
// context_id appears to play a role similar to session_id
auto* r = env_->rendezvous_mgr->Find(request->context_id());
auto session_name =
tensorflow::strings::StrCat("eager_", request->context_id());
// Create the worker session
TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
session_name, request->server_def(), request->cluster_device_attributes(),
true));
int64_t context_id = request->context_id();
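std::function<void()> session_destroyer = [this, context_id, session_name]() {
env_->rendezvous_mgr->Cleanup(context_id);
auto s = env_->session_mgr->DeleteSession(session_name);
if (!s.ok()) {
LOG(WARNING) << "Failed to destroy worker session '" << session_name
<< "' due to " << s.error_message();
}
};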
// Get the worker session
std::shared_ptr<WorkerSession> worker_session;
TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
// Get the DeviceMgr
tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
// Initialize remote tensor communication based on worker session.
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
std::function<Rendezvous*(const int64_t)> rendezvous_creator =
[worker_session, this](const int64_t step_id) {
auto* r = env_->rendezvous_mgr->Find(step_id);
r->Initialize(worker_session.get()).IgnoreError();
return r;
};
// Create the EagerContext
SessionOptions opts;
opts.config = request->server_def().default_session_config();
tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
request->async(), device_mgr, false, r, worker_session->cluster_flr(),
env_->collective_executor_mgr.get());
// Ownership will be transferred to the ServerContext, or else in an error
// case ctx will be deleted by this unref.
core::ScopedUnref unref_ctx(ctx);
// List the remote workers (excluding this worker itself)
std::vector<string> remote_workers;
worker_session->worker_cache()->ListWorkers(&remote_workers);
remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
worker_session->worker_name()),
remote_workers.end());
// Get the EagerClientCache for the remote eager workers
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
&remote_eager_workers));
// Create the DistributedFunctionLibraryRuntime
DistributedFunctionLibraryRuntime* cluster_flr =
eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
// Initialize this context as a remote worker (InitializeRemoteWorker)
auto remote_mgr =
absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);
Status s = ctx->InitializeRemoteWorker(
std::move(remote_eager_workers), worker_session->remote_device_mgr(),
remote_workers, request->context_id(), request->context_view_id(),
std::move(rendezvous_creator), cluster_flr, std::move(remote_mgr),
std::move(session_destroyer));
if (!s.ok()) {
return s;
}
#if !defined(IS_MOBILE_PLATFORM)
// Create the EagerContextDistributedManager (only if the coordination service is enabled)
const auto& config = request->server_def().default_session_config();
const bool enable_coordination =
!config.experimental().coordination_service().empty();
if (enable_coordination) {
auto dist_mgr = std::make_unique<EagerContextDistributedManager>(ctx);
ctx->SetDistributedManager(std::move(dist_mgr));
TF_RETURN_IF_ERROR(ctx->GetDistributedManager()->EnableCoordinationService(
config.experimental().coordination_service(), env_,
request->server_def(), worker_session->worker_cache()));
std::unique_ptr<CoordinationClientCache> client_cache;
TF_RETURN_IF_ERROR(
worker_session->worker_cache()->GetCoordinationClientCache(
&client_cache));
TF_RETURN_IF_ERROR(
ctx->GetDistributedManager()->GetCoordinationServiceAgent()->Initialize(
env_, request->server_def(), std::move(client_cache),
/*error_fn=*/[](Status s) {
LOG(ERROR) << "Coordination agent is set to error: " << s;
}));
}
#endif // !IS_MOBILE_PLATFORM
std::vector<DeviceAttributes> device_attributes;
device_mgr->ListDeviceAttributes(&device_attributes);
for (const auto& da : device_attributes) {
*response->add_device_attributes() = da;
}
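{
mutex_lock l(contexts_mu_);
// The lock was released above, so check again before inserting.
auto context_it = contexts_.find(request->context_id());
if (context_it != contexts_.end()) {
return errors::InvalidArgument("EagerService:CreateContext failed. ",
"Context id: <", request->context_id(),
"> already exists.");
}
contexts_.emplace(request->context_id(),
new ServerContext(ctx, request->keep_alive_secs(), env_));
}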
return Status::OK();
}
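The first block of EagerServiceImpl::CreateContext decides what happens when a context with the same context_id already exists: a request carrying an older context_view_id is rejected, while a request with the same or a newer view replaces the stale context (for example, a worker that reconnected after cluster updates). A minimal Python sketch of just that rule; ToyContextTable is hypothetical and stores only view IDs instead of ServerContext objects.

```python
from typing import Dict


class ToyContextTable:
    """Hypothetical sketch of the context_view_id check in CreateContext."""

    def __init__(self) -> None:
        self._view_ids: Dict[int, int] = {}  # context_id -> context_view_id

    def create_context(self, context_id: int, context_view_id: int) -> None:
        existing = self._view_ids.get(context_id)
        if existing is not None:
            if context_view_id < existing:
                # The request is older than what this worker already has.
                raise ValueError(f"Context id: <{context_id}> already exists.")
            # Same or newer view: the old context is stale, drop it first.
            del self._view_ids[context_id]
        self._view_ids[context_id] = context_view_id

    def view_id(self, context_id: int) -> int:
        return self._view_ids[context_id]
```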
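The worker-side logic is as follows:
Figure 2: Context creation on the worker side
The overall logic is as follows:
Figure 3: Overall context-creation flow
This completes the analysis of the context environment, and with it the foundation for remote distributed execution is in place. Next, we look at how training code is executed on remote workers.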
0x6. FunctionLibraryRuntime
In the code above, the client creates a cluster FunctionLibraryRuntime (cluster_flr) with the following statement:
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, context, worker_session.get());
On the server side, EagerServiceImpl::CreateContext creates one with a similar statement:
DistributedFunctionLibraryRuntime* cluster_flr =
eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
CreateClusterFLR is defined in tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc and returns an EagerClusterFunctionLibraryRuntime:
DistributedFunctionLibraryRuntime* CreateClusterFLR(
const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
return new EagerClusterFunctionLibraryRuntime(
context_id, ctx, worker_session->remote_device_mgr());
}
This brings us to FunctionLibraryRuntime, a core TF concept; DistributedFunctionLibraryRuntime is its distributed counterpart.
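Both the client and the server obtain the runtime through the same factory and only ever see the DistributedFunctionLibraryRuntime interface. The sketch below shows that shape in isolation with hypothetical names (DistributedRuntime, EagerClusterRuntime, CreateClusterRuntime); unlike CreateClusterFLR, which returns a raw pointer, it returns std::unique_ptr so that ownership is explicit.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>

// Hypothetical, heavily simplified stand-in for DistributedFunctionLibraryRuntime.
class DistributedRuntime {
 public:
  virtual ~DistributedRuntime() = default;  // deletion through the base is safe
  virtual std::string Describe() const = 0;
};

// Hypothetical stand-in for EagerClusterFunctionLibraryRuntime: it only
// remembers the context id it was created for.
class EagerClusterRuntime : public DistributedRuntime {
 public:
  explicit EagerClusterRuntime(uint64_t context_id) : context_id_(context_id) {}
  std::string Describe() const override {
    return "eager-cluster:" + std::to_string(context_id_);
  }

 private:
  const uint64_t context_id_;
};

// Factory in the spirit of CreateClusterFLR: callers depend on the interface only.
std::unique_ptr<DistributedRuntime> CreateClusterRuntime(uint64_t context_id) {
  return std::make_unique<EagerClusterRuntime>(context_id);
}
```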
6.1 The DistributedFunctionLibraryRuntime interface
DistributedFunctionLibraryRuntime is the base API interface.
// Used to instantiate and run functions in a distributed system.
class DistributedFunctionLibraryRuntime {
public:
virtual ~DistributedFunctionLibraryRuntime() {}
// Instantiate a function on a remote target specified in `options.target`, by
// sending the name and definition of the function to the remote worker. The
// local handle is filled for the instantiated function data and can be used
// for subsequent run function calls on the remote target.
virtual void Instantiate(
const std::string& function_name,
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* handle,
FunctionLibraryRuntime::DoneCallback done) = 0;
// Run an instantiated remote function (specified by `handle`) with a list of
// input Tensors in `args` and get its output Tensors in `rets`. The input
// tensor data will be sent with the function execution request, and must be
// available on the current caller side.
// opts.runner isn't used for execution.
virtual void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) = 0;
// Run an instantiated remote function (specified by `handle`) with a list of
// input Tensors or RemoteTensorHandles as args and get its output Tensors
// or TensorShapes in `rets`. When using RemoteTensorHandles as function
// inputs or TensorShapes as outputs, the corresponding tensor data will be
// resolved on the remote worker, so it is not required to be locally
// available on the caller side. Using RemoteTensorHandle inputs is not
// supported in TensorFlow v1 runtime.
virtual void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args,
std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) = 0;
// Clean up a previously instantiated function on remote worker.
virtual void CleanUp(uint64 step_id,
FunctionLibraryRuntime::LocalHandle handle,
FunctionLibraryRuntime::DoneCallback done) = 0;
// DeviceMgr with *all* available devices (i.e., local and remote).
virtual DeviceMgr* remote_device_mgr() const = 0;
};
6.2 EagerClusterFunctionLibraryRuntime
EagerClusterFunctionLibraryRuntime is the concrete implementation; it runs functions across processes by making RPCs through the eager service.
// EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run
// functions across processes by making RPCs through eager service.
class EagerClusterFunctionLibraryRuntime
: public DistributedFunctionLibraryRuntime {
public:
EagerClusterFunctionLibraryRuntime(const uint64 context_id, EagerContext* ctx,
DeviceMgr* remote_device_mgr)
: context_id_(context_id),
ctx_(ctx),
remote_device_mgr_(remote_device_mgr) {}
~EagerClusterFunctionLibraryRuntime() override{};
// Register a partition (i.e., component function) of a multi-device function
// on the remote target specified in `options.target`. This should be
// triggered as part of instantiating a multi-device function in
// ProcessFunctionLibraryRuntime.
void Instantiate(const string& function_name,
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* handle,
FunctionLibraryRuntime::DoneCallback done) override;
// Execute the component function specified by handle on its instantiated
// remote target. This should be triggered as part of driving a multi-device
// function execution in ProcessFunctionLibraryRuntime. Running the component
// function remotely is purely asynchronous, and multiple component functions
// with the same remote target are not executed in any particular ordering.
// The main function side must wait for all component functions to finish
// (i.e., the done callbacks triggered) before finishing its execution.
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) override;
// The component function inputs args and outputs rets may refer to remote
// tensors on a remote device, which will be lazily resolved remotely where
// the inputs/outputs are actually consumed.
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) override;
void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
FunctionLibraryRuntime::DoneCallback done) override;
DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; }
private:
const uint64 context_id_;
EagerContext* ctx_;
DeviceMgr* remote_device_mgr_; // not owned.
struct FunctionData {
const string target;
const absl::optional<std::vector<int>> ret_indices;
core::RefCountPtr<EagerClient> eager_client;
std::unique_ptr<EagerOperation> op;
FunctionData(const string& target,
const absl::optional<std::vector<int>>& ret_indices,
EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
: target(target),
ret_indices(ret_indices),
eager_client(core::RefCountPtr<EagerClient>(eager_client)),
op(std::move(op)) {
eager_client->Ref();
}
};
mutable mutex mu_;
std::vector<FunctionData> function_data_ TF_GUARDED_BY(mu_);
};
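One detail in FunctionData is easy to miss: core::RefCountPtr takes over the reference it is constructed from instead of adding one, which is why the constructor calls eager_client->Ref() right after wrapping the raw pointer. The hypothetical RefCounted and AdoptingPtr classes below are a minimal model of that adopt-then-Ref pattern, not TF's own classes.

```cpp
#include <cassert>
#include <utility>

// Minimal intrusive reference count, loosely modeled on core::RefCounted.
class RefCounted {
 public:
  void Ref() { ++count_; }
  // Returns true if this call dropped the last reference and deleted *this.
  bool Unref() {
    if (--count_ == 0) {
      delete this;
      return true;
    }
    return false;
  }
  int count() const { return count_; }

 protected:
  virtual ~RefCounted() = default;

 private:
  int count_ = 1;  // a new object starts with one reference owned by its creator
};

// Smart pointer that ADOPTS an existing reference (no Ref() in the constructor)
// and releases it with Unref(), like core::RefCountPtr.
template <typename T>
class AdoptingPtr {
 public:
  explicit AdoptingPtr(T* p) : p_(p) {}
  AdoptingPtr(AdoptingPtr&& other) noexcept : p_(std::exchange(other.p_, nullptr)) {}
  AdoptingPtr(const AdoptingPtr&) = delete;
  AdoptingPtr& operator=(const AdoptingPtr&) = delete;
  ~AdoptingPtr() {
    if (p_ != nullptr) p_->Unref();
  }
  T* get() const { return p_; }

 private:
  T* p_;
};

// Hypothetical client type standing in for EagerClient.
class FakeClient : public RefCounted {};
```

Without the explicit Ref(), the stored pointer would eventually release a reference it never took, leaving the original owner with a dangling pointer.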
6.2.1 Instantiation
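The Instantiate method registers the function on the remote target: it sends a RegisterFunctionOp (carrying the function definition and its reachable library) through EnqueueAsync, and once the RPC completes it records a FunctionData entry whose index becomes the local handle.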
void EagerClusterFunctionLibraryRuntime::Instantiate(
const string& function_name, const FunctionLibraryDefinition& lib_def,
AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* handle,
FunctionLibraryRuntime::DoneCallback done) {
auto target = options.target;
auto released_op = std::make_unique<EagerOperation>(ctx_);
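Status s =
released_op->Reset(function_name.c_str(), target.c_str(), true, nullptr);
if (!s.ok()) {
done(s);
return;
}
core::RefCountPtr<eager::EagerClient> eager_client;
s = ctx_->GetClient(target, &eager_client);
// Bail out early so eager_client is never dereferenced when the lookup failed.
if (!s.ok()) {
done(s);
return;
}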
const FunctionLibraryDefinition& func_lib_def =
options.lib_def ? *options.lib_def : lib_def;
auto request = std::make_shared<EnqueueRequest>();
auto response = std::make_shared<EnqueueResponse>();
request->set_context_id(context_id_);
RegisterFunctionOp* register_function =
request->add_queue()->mutable_register_function();
*register_function->mutable_function_def() =
*func_lib_def.Find(function_name);
register_function->set_is_component_function(true);
*register_function->mutable_library() =
func_lib_def.ReachableDefinitions(register_function->function_def())
.ToProto();
StripDefaultAttributesInRegisterFunctionOp(register_function);
const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
eager_client->EnqueueAsync(
/*call_opts=*/nullptr, request.get(), response.get(),
[this, request, response, handle, released_op = released_op.release(),
target, ret_indices, eager_client = eager_client.get(),
done](const Status& s) {
{
mutex_lock l(mu_);
*handle = function_data_.size();
function_data_.emplace_back(target, ret_indices, eager_client,
absl::WrapUnique(released_op));
}
done(s);
});
}
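The local handle is assigned only inside the completion callback, and it is simply the index of the new entry in function_data_. A minimal sketch of that flow, assuming a fake client that completes the "RPC" inline; FakeEagerClient, ClusterRuntimeSketch and the string-based Status are hypothetical. As in the excerpt, the entry is recorded whatever the status, and the status is forwarded to done.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <mutex>
#include <string>
#include <vector>

// Hypothetical status type: an empty string means OK.
using Status = std::string;
using DoneCallback = std::function<void(const Status&)>;

// Hypothetical client: pretends to send RegisterFunction and completes inline.
struct FakeEagerClient {
  void EnqueueAsync(const std::string& function_name, DoneCallback done) {
    registered.push_back(function_name);
    done(Status());  // a real RPC would complete later, on another thread
  }
  std::vector<std::string> registered;
};

class ClusterRuntimeSketch {
 public:
  explicit ClusterRuntimeSketch(FakeEagerClient* client) : client_(client) {}

  // Register remotely, then assign the handle inside the completion callback.
  void Instantiate(const std::string& function_name, const std::string& target,
                   size_t* handle, DoneCallback done) {
    client_->EnqueueAsync(
        function_name, [this, function_name, target, handle, done](const Status& s) {
          {
            std::lock_guard<std::mutex> lock(mu_);
            *handle = function_data_.size();  // handle == index of the new entry
            function_data_.push_back({function_name, target});
          }
          done(s);  // report the RPC status to the caller
        });
  }

  // Looks up the target recorded for a handle (throws if the handle is invalid).
  std::string TargetOf(size_t handle) const {
    std::lock_guard<std::mutex> lock(mu_);
    return function_data_.at(handle).target;
  }

 private:
  struct FunctionData {
    std::string name;
    std::string target;
  };
  FakeEagerClient* client_;
  mutable std::mutex mu_;
  std::vector<FunctionData> function_data_;
};
```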
6.2.2 Running a component function
To execute the computation graph, control enters EagerClusterFunctionLibraryRuntime::Run, where RunComponentFunctionAsync issues an RPC to notify the remote worker.
void EagerClusterFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) {
FunctionData* function_data = nullptr;
{
mutex_lock l(mu_);
DCHECK_LT(handle, function_data_.size());  // handle indexes function_data_
function_data = &function_data_[handle];
}
EagerClient* eager_client = function_data->eager_client.get();
EagerOperation* op = function_data->op.get();
auto request = std::make_shared<RunComponentFunctionRequest>();
auto response = std::make_shared<RunComponentFunctionResponse>();
request->set_context_id(context_id_);
eager::Operation* remote_op = request->mutable_operation();
if (function_data->ret_indices.has_value()) {
for (const int ret_index : function_data->ret_indices.value()) {
request->add_output_num(ret_index);
}
}
for (const auto& arg : args) {
if (arg.index() == 0) {
absl::get<Tensor>(arg).AsProtoTensorContent(
remote_op->add_op_inputs()->mutable_tensor());
} else {
remote_op->add_op_inputs()->mutable_remote_handle()->Swap(
absl::get<RemoteTensorHandle*>(arg));
}
}
// The remote component function should use the same op_id as its parent
// multi-device function's in order to get the global unique op_id generated
// by the master context.
if (opts.op_id.has_value()) {
remote_op->set_id(opts.op_id.value());
} else {
remote_op->set_id(kInvalidRemoteOpId);
}
remote_op->set_is_function(true);
remote_op->set_is_component_function(true);
remote_op->set_func_step_id(opts.step_id);
remote_op->set_name(op->Name());
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
remote_op->set_device(function_data->target);
CancellationManager* cm = opts.cancellation_manager;
CancellationToken token = 0;
auto call_opts = std::make_shared<CallOptions>();
if (cm != nullptr) {
token = cm->get_cancellation_token();
const bool already_cancelled = !cm->RegisterCallback(
token,
[call_opts, request, response, done]() { call_opts->StartCancel(); });
if (already_cancelled) {
done(errors::Cancelled("EagerClusterFunctionLibraryRuntime::Run"));
return;
}
}
// Execute component function on remote worker using RunComponentFunction RPC.
// Different from executing remote functions with Enqueue, this method runs
// a function on remote worker without tying up a thread (i.e., pure
// asynchronously).
eager_client->RunComponentFunctionAsync(
call_opts.get(), request.get(), response.get(),
[request, response, rets, call_opts, cm, token,
done = std::move(done)](const Status& s) {
if (cm != nullptr) {
cm->TryDeregisterCallback(token);
}
if (!s.ok()) {
done(s);
return;
}
for (const auto& shape : response->shape()) {
rets->push_back(shape);
}
for (const auto& tensor_proto : response->tensor()) {
Tensor t;
if (t.FromProto(tensor_proto)) {
rets->push_back(std::move(t));
} else {
done(errors::Internal("Could not convert tensor proto: ",
tensor_proto.DebugString()));
return;
}
}
done(Status::OK());
});
}
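Run encodes each FunctionArg according to which alternative of the variant it holds: index 0 is a local Tensor whose content is copied into the request, otherwise it is a RemoteTensorHandle that only names a tensor already living on the worker. The sketch below models that branch with std::variant in place of absl::variant; Tensor, RemoteTensorHandle, OpInput and EncodeInputs are simplified hypothetical stand-ins, and the handle is copied rather than swapped into a proto.

```cpp
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

// Hypothetical stand-ins: a Tensor carries data, a RemoteTensorHandle only
// identifies a tensor that already lives on a remote worker.
struct Tensor {
  std::vector<float> values;
};
struct RemoteTensorHandle {
  int64_t op_id;
  int32_t output_num;
};

// Same shape as FunctionArg: alternative 0 is a local Tensor, alternative 1 a
// pointer to a remote handle (hence the arg.index() == 0 test in Run).
using FunctionArg = std::variant<Tensor, RemoteTensorHandle*>;

// Hypothetical serialized op input: either inline tensor data or a handle.
struct OpInput {
  bool is_inline_tensor;
  std::vector<float> tensor_content;
  RemoteTensorHandle remote_handle;
};

std::vector<OpInput> EncodeInputs(const std::vector<FunctionArg>& args) {
  std::vector<OpInput> inputs;
  for (const auto& arg : args) {
    OpInput in{};
    if (arg.index() == 0) {
      // Local tensor: its data travels inside the request.
      in.is_inline_tensor = true;
      in.tensor_content = std::get<Tensor>(arg).values;
    } else {
      // Remote handle: only the reference is sent; the worker resolves it.
      in.is_inline_tensor = false;
      in.remote_handle = *std::get<RemoteTensorHandle*>(arg);
    }
    inputs.push_back(in);
  }
  return inputs;
}
```

Sending only a handle for remote inputs avoids shipping tensor data back to the caller just to send it out again.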
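The client then sends a RunComponentFunctionRequest to the remote worker, which processes it and returns a RunComponentFunctionResponse. The class relationships are shown below. ClusterFunctionLibraryRuntime is another derived class, but it has little bearing on our analysis.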
Figure 4: DistributedFunctionLibraryRuntime class relationships
// ClusterFunctionLibraryRuntime contains methods to Instantiate and Run
// functions across processes by making RPCs through worker service.
class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
public:
ClusterFunctionLibraryRuntime(WorkerSession* worker_session,
bool create_worker_session_called,
DeviceMgr* remote_device_mgr)
: worker_session_(worker_session),
create_worker_session_called_(create_worker_session_called),
remote_device_mgr_(remote_device_mgr) {}
// ... (remaining members omitted)
};
6.3 Remote worker
On the remote worker, the request first reaches GrpcEagerServiceImpl's RunComponentFunctionHandler, which then calls EagerServiceImpl::RunComponentFunction.
6.3.1 GrpcEagerServiceImpl
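RunComponentFunctionHandler is generated through macros, which we already analyzed in the distributed environment articles. The ENQUEUE_REQUEST macro below hooks the handler into the gRPC completion queue: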
#define ENQUEUE_REQUEST(method) \
do { \
Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequest(&service_, cq_.get(), \
&grpc::EagerService::AsyncService::Request##method, \
&GrpcEagerServiceImpl::method##Handler, false); \
} while (0)
ENQUEUE_REQUEST(RunComponentFunction);
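ENQUEUE_REQUEST relies on token pasting: method##Request, method##Response and method##Handler are glued into identifiers, so a single macro line wires a method's message types and handler together. A self-contained sketch of the same preprocessor technique, assuming a plain std::map instead of a gRPC completion queue; FakeService, g_handlers, REGISTER_METHOD and the message structs are hypothetical, and unlike ENQUEUE_REQUEST the sketch also uses #method to turn the method name into a string key.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

// Hypothetical request/response types for two RPC methods.
struct RunComponentFunctionRequest { int x; };
struct RunComponentFunctionResponse { int y; };
struct KeepAliveRequest { int x; };
struct KeepAliveResponse { int y; };

class FakeService {
 public:
  void RunComponentFunctionHandler() { calls += "RunComponentFunction;"; }
  void KeepAliveHandler() { calls += "KeepAlive;"; }
  std::string calls;
};

// Method name -> handler, standing in for the completion-queue registration.
std::map<std::string, std::function<void(FakeService&)>> g_handlers;

// method##Request / method##Response / method##Handler are pasted into
// identifiers; a typo in a method name therefore fails at compile time.
#define REGISTER_METHOD(method)                                          \
  do {                                                                   \
    static_assert(sizeof(method##Request) > 0, "request type exists");   \
    static_assert(sizeof(method##Response) > 0, "response type exists"); \
    g_handlers[#method] = &FakeService::method##Handler;                 \
  } while (0)

void RegisterAll() {
  REGISTER_METHOD(RunComponentFunction);
  REGISTER_METHOD(KeepAlive);
}
```

The do { ... } while (0) wrapper, also used by ENQUEUE_REQUEST, makes the macro behave like a single statement, so it stays safe inside an unbraced if.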
6.3.2 EagerServiceImpl
EagerServiceImpl::RunComponentFunction does the actual work; its main job is to call EagerLocalExecuteAsync to execute the component function.
void EagerServiceImpl::RunComponentFunction(
CallOptions* call_opts, const RunComponentFunctionRequest* request,
RunComponentFunctionResponse* response, StatusCallback done) {
ServerContext* context = nullptr;
Status s = GetServerContext(request->context_id(), &context);
if (!s.ok()) {
done(s);
return;
}
core::ScopedUnref context_unref(context);
auto& operation = request->operation();
// This codepath should only be triggered for executing component function
if (!operation.is_function() || !operation.is_component_function()) {
done(errors::Internal(
"RunComponentFunction request can only be used to execute "
"component functions."));
return;
}
EagerContext* eager_context = context->Context();
EagerExecutor* eager_executor = &eager_context->Executor();
EagerOperation* op = new EagerOperation(eager_context);
int* num_retvals = new int(0);
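s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor,
op, num_retvals);
if (!s.ok()) {
delete num_retvals;
delete op;
done(s);
return;
}
s = op->SetAttrBool("is_component_function", true);
if (!s.ok()) {
delete num_retvals;
delete op;
done(s);
return;
}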
auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
std::vector<int32> output_nums;
for (const int32_t output_num : request->output_num()) {
output_nums.push_back(output_num);
}
auto cm = std::make_shared<CancellationManager>();
op->SetCancellationManager(cm.get());
call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
context->Ref();
EagerLocalExecuteAsync(
op, retvals->data(), num_retvals,
[op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
call_opts, response, eager_context, context,
done = std::move(done)](const Status& status) {
call_opts->ClearCancelCallback();
auto wrapped_done = [&](const Status& status) {
context->Unref();
done(status);
delete op;
delete num_retvals;
delete retvals;
};
if (!status.ok()) {
wrapped_done(status);
return;
}
// The output device of a component function is the component device
// which is known on the default device of its parent function.
wrapped_done(AddOpRetvalsToResponse(
eager_context, op_id, *num_retvals, output_nums, retvals->data(),
[response] { return response->add_tensor(); },
[response] { return response->add_shape(); }));
});
}
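Cancellation crosses the RPC boundary in two hops: on the caller, Run registers a callback on opts.cancellation_manager that cancels the in-flight call (call_opts->StartCancel()); on the worker, RunComponentFunction installs a cancel callback that calls StartCancel() on the per-request CancellationManager attached to the op. MiniCancellationManager below borrows the method names from the excerpt but is a simplified, hypothetical model, not TF's CancellationManager; chaining two instances imitates the two hops.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <mutex>
#include <utility>

// RegisterCallback returns false once cancellation has started, and
// StartCancel runs every registered callback exactly once.
class MiniCancellationManager {
 public:
  using Token = long;

  Token get_cancellation_token() {
    std::lock_guard<std::mutex> lock(mu_);
    return next_token_++;
  }

  bool RegisterCallback(Token token, std::function<void()> callback) {
    std::lock_guard<std::mutex> lock(mu_);
    if (cancelled_) return false;  // caller must handle "already cancelled"
    callbacks_[token] = std::move(callback);
    return true;
  }

  bool TryDeregisterCallback(Token token) {
    std::lock_guard<std::mutex> lock(mu_);
    return callbacks_.erase(token) > 0;
  }

  void StartCancel() {
    std::map<Token, std::function<void()>> to_run;
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (cancelled_) return;
      cancelled_ = true;
      to_run.swap(callbacks_);
    }
    // Run callbacks outside the lock so they may cancel other managers.
    for (auto& entry : to_run) entry.second();
  }

  bool IsCancelled() const {
    std::lock_guard<std::mutex> lock(mu_);
    return cancelled_;
  }

 private:
  mutable std::mutex mu_;
  bool cancelled_ = false;
  Token next_token_ = 0;
  std::map<Token, std::function<void()>> callbacks_;
};
```

Running the callbacks after releasing the lock lets a callback safely call back into the same manager, which would deadlock if the mutex were still held.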
The resulting end-to-end logic is as follows:
Figure 5: How the remote runtime handles execution
0x7. Summary
Let us summarize what this article has covered:
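- Is local computation multithreaded or multi-process?
  Locally, MirroredStrategy trains with multiple threads: _call_for_each_replica creates _MirroredReplicaThread instances, one thread per device, and they run fn in parallel until every fn has finished.
  The computation in each thread is dispatched to the worker that owns the corresponding device, which may be a remote worker.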
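- How does MirroredStrategy connect with the TF runtime we analyzed earlier?
  To some extent, the Context plays the role that the Master played in the TF1 Session world: it distributes the computation.
  On the remote side, the eager service is a TensorFlow service that creates a remote context and executes the operations dispatched by the Context on its local devices.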
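- How is computation distributed, and how is training code executed remotely?
  EagerClusterFunctionLibraryRuntime runs functions across processes via RPC. To execute a computation graph, the local side enters EagerClusterFunctionLibraryRuntime::Run, where RunComponentFunctionAsync issues an RPC (sending a RunComponentFunctionRequest) to the remote worker.
  On the remote worker, the request first reaches GrpcEagerServiceImpl's RunComponentFunctionHandler and is then passed to EagerServiceImpl::RunComponentFunction.
  EagerServiceImpl::RunComponentFunction does the actual work; its main job is to call EagerLocalExecuteAsync to execute the function.
  After processing, the remote worker returns a RunComponentFunctionResponse.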
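The local threading model from the first point can be sketched in C++ as one thread per device, each running fn, with the caller joining all of them. This illustrates the pattern only: the real _call_for_each_replica is Python, and its _MirroredReplicaThread objects also coordinate with the main thread (for example around merge_call), which this sketch omits. CallForEachReplica and the device strings are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <thread>
#include <vector>

// One thread per device runs fn(replica_id, device); the caller waits for all.
std::vector<double> CallForEachReplica(
    const std::vector<std::string>& devices,
    const std::function<double(int replica_id, const std::string& device)>& fn) {
  std::vector<double> results(devices.size());
  std::vector<std::thread> threads;
  threads.reserve(devices.size());
  for (int i = 0; i < static_cast<int>(devices.size()); ++i) {
    // Each thread writes only its own slot, so no lock is needed.
    threads.emplace_back([&, i] { results[i] = fn(i, devices[i]); });
  }
  for (auto& t : threads) t.join();  // block until every replica has finished
  return results;
}
```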
This concludes our analysis of MirroredStrategy.