[Source Code Analysis] TensorFlow Distributed Environment (7) --- Worker Dynamic Logic

Posted by 羅西的思考 on 2022-04-01

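In the previous article, the Master issued gRPC calls to remote workers at several points in its flow: every function in the GrpcRemoteWorker class calls IssueRequest() to start an asynchronous gRPC call. In that flow GrpcRemoteWorker sent two requests, RegisterGraphAsync and RunGraphAsync. This article looks at how GrpcWorkerService handles them.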

As before, this article draws heavily on the work of two experts:

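Other articles in this series:

[Translation] TensorFlow Distributed Papers: "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed Papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

[Source Code Analysis] TensorFlow Distributed Environment (5) --- Session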

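1. Overview

1.1 Review

First, let us review how the concepts introduced so far relate to each other.

  • The Client builds the complete computation graph (FullGraph), but this full graph cannot be executed in parallel as it stands, so it has to be partitioned and optimized.
  • The Master processes the full graph, for example by pruning it, to produce the ClientGraph (the minimal dependency subgraph that can be executed). It then splits the ClientGraph into multiple PartitionGraphs according to the available Workers and registers these PartitionGraphs with the corresponding Workers.
  • After receiving a registration request, a Worker splits the PartitionGraph it received once more according to its local device set, and starts an Executor on each device to run that device's PartitionGraph.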

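1.2 What's new

Next, an outline of the Worker flow. When the flow reaches a particular Worker node and that node receives a RegisterGraphRequest, the message carries the session_handle assigned by MasterSession and the subgraph graph_def (in GraphDef form). GraphDef is the result of serializing the computation graph built by the Client with Protocol Buffers, and it contains all of the graph's metadata. It can be converted into a Graph with ConvertGraphDefToGraph. A Graph holds not only the graph's metadata but also the additional information needed at run time.

The Worker splits the computation graph further into multiple PartitionGraphs according to its local device set, assigns one PartitionGraph to each device, and starts an Executor on each compute device to wait for subsequent run commands. The Executor class is TensorFlow's abstraction of a session executor. It provides the virtual method RunAsync, which runs a local graph asynchronously, and Run, a synchronous wrapper around it.

When the Worker node receives RunGraphAsync, the devices start executing. The Worker calls session->graph_mgr()->ExecuteAsync on the WorkerSession, which in turn calls StartParallelExecutors, where an ExecutorBarrier is created for num_units executors. Each time a compute device finishes its PartitionGraph, it reports to the ExecutorBarrier through the callback obtained from barrier->Get(). Once every device has reported, the barrier invokes its completion callback. In the code shown in section 3.3 nothing blocks in a wait() call; the whole path is driven by callbacks.

We now go through this flow step by step.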

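2. Registering subgraphs

When a worker node receives a RegisterGraphRequest, the request first arrives at GrpcWorkerService; the method actually invoked is "/tensorflow.WorkerService/RegisterGraph". The corresponding code is below. Once the macro is expanded, it is simply RegisterGraphHandler: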

#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

HANDLE_CALL(RegisterGraph, false);
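
To make the macro expansion easier to follow, here is a minimal sketch of the two preprocessor features HANDLE_CALL relies on: `method##Handler` pastes tokens into a new function name, and `#method` turns the argument into a string literal. ToyWorker, TOY_HANDLE_CALL and the Ping method are hypothetical names invented for this illustration; they are not part of TensorFlow, and the scheduling and gRPC parts of the real macro are left out.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a worker: each method just records its own name.
struct ToyWorker {
  std::vector<std::string> log;
  bool RegisterGraph() { log.push_back("RegisterGraph"); return true; }
  bool Ping() { log.push_back("Ping"); return true; }
};

// Generates a function named <method>Handler that calls worker-><method>()
// and reports the method name obtained through stringification.
#define TOY_HANDLE_CALL(method)                          \
  std::string method##Handler(ToyWorker* worker) {       \
    bool ok = worker->method();                          \
    return std::string(#method) + (ok ? ":ok" : ":bad"); \
  }

TOY_HANDLE_CALL(RegisterGraph)  // defines RegisterGraphHandler
TOY_HANDLE_CALL(Ping)           // defines PingHandler
```

Calling RegisterGraphHandler(&worker) in this sketch invokes worker.RegisterGraph(), which mirrors how HANDLE_CALL(RegisterGraph, false) produces RegisterGraphHandler and forwards to worker_->RegisterGraph.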

2.1 GrpcWorker

RegisterGraph is in fact a method of WorkerInterface, which internally forwards to RegisterGraphAsync.

Status WorkerInterface::RegisterGraph(const RegisterGraphRequest* request,
                     RegisterGraphResponse* response) {
  return CallAndWait(&ME::RegisterGraphAsync, request, response);
}
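
The body of CallAndWait is not shown in this article. A plausible shape for such a helper, sketched here under that assumption, is to start the asynchronous method, block on a one-shot notification, and return the status delivered to the callback. Every name below (ToyStatus, ToyNotification, ToyCallAndWait, ToyRegisterGraphAsync) is hypothetical and only illustrates the pattern.

```cpp
#include <cassert>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>

// Toy status type (hypothetical): an empty message means success.
struct ToyStatus {
  std::string error;
  bool ok() const { return error.empty(); }
};

using ToyStatusCallback = std::function<void(const ToyStatus&)>;

// Minimal one-shot notification built on a mutex and a condition variable.
class ToyNotification {
 public:
  void Notify() {
    std::lock_guard<std::mutex> lock(mu_);
    notified_ = true;
    cv_.notify_all();
  }
  void WaitForNotification() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return notified_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool notified_ = false;
};

// Turns an asynchronous call into a synchronous one: start it, then block
// until its completion callback has delivered a status.
ToyStatus ToyCallAndWait(
    const std::function<void(ToyStatusCallback)>& async_fn) {
  ToyStatus result;
  ToyNotification n;
  async_fn([&result, &n](const ToyStatus& s) {
    result = s;
    n.Notify();
  });
  n.WaitForNotification();
  return result;
}

// Toy async method that completes inline, the way Worker::RegisterGraphAsync
// calls done(s) before returning.
void ToyRegisterGraphAsync(bool fail, ToyStatusCallback done) {
  done(fail ? ToyStatus{"graph rejected"} : ToyStatus{});
}
```

The same wrapper also works when the callback fires later on another thread, because the waiter only returns once notified_ has been set.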

RegisterGraphAsync ends up in the Worker implementation. It first looks up the WorkerSession by session_handle and then calls into that session's GraphMgr.

GraphMgr* WorkerSession::graph_mgr() const { return graph_mgr_.get(); }

RegisterGraphAsync is as follows:

void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Register(
        request->session_handle(), request->graph_def(), session.get(),
        request->graph_options(), request->debug_options(),
        request->config_proto(), request->collective_graph_key(),
        session->cluster_flr(), response->mutable_graph_handle());
  }
  done(s);
}

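2.2 GraphMgr

GraphMgr keeps track of the set of computation graphs registered with a TensorFlow worker. Each registered graph is identified by a handle, graph_handle, which GraphMgr generates and returns to the caller. After a successful registration, the caller uses the graph handle to run the graph. Each run is distinguished from the others by a globally unique ID, "step_id", generated by the caller. As long as the "step_id" values differ, multiple runs can use the same graph concurrently and independently, and multiple threads can call GraphMgr methods concurrently.

2.2.1 Definition

GraphMgr is defined as follows: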

class GraphMgr {
 private:
  typedef GraphMgr ME;

  struct ExecutionUnit {
    std::unique_ptr<Graph> graph = nullptr;
    Device* device = nullptr;               // not owned.
    Executor* root = nullptr;               // not owned.
    FunctionLibraryRuntime* lib = nullptr;  // not owned.
    // Build the cost model if this value is strictly positive.
    int64_t build_cost_model = 0;
  };

  struct Item : public core::RefCounted {
    ~Item() override;

    // Session handle.
    string session;

    // Graph handle.
    string handle;

    std::unique_ptr<FunctionLibraryDefinition> lib_def;
    // Owns the FunctionLibraryRuntime objects needed to execute functions, one
    // per device.
    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
    // A graph is partitioned over multiple devices.  Each partition
    // has a root executor which may call into the runtime library.
    std::vector<ExecutionUnit> units;

    // Used to deregister a cost model when cost model is required in graph
    // manager.
    GraphMgr* graph_mgr;

    int64_t collective_graph_key;
  };

  const WorkerEnv* worker_env_;  // Not owned.
  const DeviceMgr* device_mgr_;

  CostModelManager cost_model_manager_;

  // Owned.
  mutex mu_;
  int64_t next_id_ TF_GUARDED_BY(mu_) = 0;

  // If true, blocks until device has finished all queued operations in a step.
  bool sync_on_finish_ = true;

  // Table mapping graph handles to registered graphs.
  //
  // TODO(zhifengc): If the client does not call Deregister, we'll
  // lose memory over time. We should implement a timeout-based
  // mechanism to gc these graphs.
  std::unordered_map<string, Item*> table_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
};

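In terms of how these classes relate: registering a graph means inserting a new Item into GraphMgr's table_ member, and running a graph means executing a specific Item.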
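
The register-then-look-up pattern can be sketched in a few lines. This is a heavily reduced illustration, not GraphMgr itself: ToyItem and ToyGraphMgr are hypothetical names, the Item carries only the two handles, and reference counting is left out. The handle format follows the "%016llx" pattern used by GraphMgr::Register.

```cpp
#include <cassert>
#include <cstdio>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical, heavily reduced Item: only the two handles are kept.
struct ToyItem {
  std::string session;  // session handle supplied by the caller
  std::string handle;   // graph handle generated at registration
};

class ToyGraphMgr {
 public:
  // Registers a graph for `session_handle` and returns a fresh graph handle.
  std::string Register(const std::string& session_handle) {
    std::lock_guard<std::mutex> lock(mu_);
    char buf[32];
    std::snprintf(buf, sizeof(buf), "%016llx",
                  static_cast<unsigned long long>(++next_id_));
    ToyItem item{session_handle, buf};
    table_.emplace(item.handle, item);
    return item.handle;
  }

  // Looks up a registered graph; returns nullptr for an unknown handle.
  const ToyItem* Find(const std::string& graph_handle) const {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = table_.find(graph_handle);
    return it == table_.end() ? nullptr : &it->second;
  }

 private:
  mutable std::mutex mu_;
  long long next_id_ = 0;
  std::unordered_map<std::string, ToyItem> table_;
};
```

Two registrations for the same session produce two distinct graph handles, which is why a run request has to name the graph handle and not just the session.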

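2.2.2 Registering a graph

The registration code is below. Most of the work is delegated to InitItem; Register itself only generates the graph handle and inserts the Item into table_. We therefore look at InitItem next.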

Status GraphMgr::Register(
    const string& handle, const GraphDef& gdef, WorkerSession* session,
    const GraphOptions& graph_options, const DebugOptions& debug_options,
    const ConfigProto& config_proto, int64_t collective_graph_key,
    DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
  Item* item = new Item;
  Status s = InitItem(handle, gdef, session, graph_options, debug_options,
                      config_proto, collective_graph_key, cluster_flr, item);
  if (!s.ok()) {
    item->Unref();
    return s;
  }

  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
    *graph_handle =
        strings::Printf("%016llx", static_cast<long long>(++next_id_));
    item->handle = *graph_handle;
    CHECK(table_.insert({*graph_handle, item}).second);
  }
  return Status::OK();
}

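InitItem does the following:

  • Given a graph definition "gdef" of a session, it creates the executors.

  • If a node in "gdef" is shared by other graphs in the "session", the same op kernel is reused. For example, a params node is typically shared by multiple graphs in a session.

  • If "gdef" is assigned to multiple devices, extra nodes (for example, send/recv nodes) may be added. The names of the extra nodes are generated by calling "new_name(old_name)".

  • On success, "executors" is filled with one executor per device, and the caller takes ownership of the returned executors.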

// Creates executors given a graph definition "gdef" of a "session".
// If a node in "gdef" is shared by other graphs in "session", the
// same op kernel is reused. E.g., typically a params node is shared
// by multiple graphs in a session.
//
// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
// send/recv nodes) maybe added. The extra nodes' name are generated
// by calling "new_name(old_name)".
//
// "executors" are filled with one executor per device if success and
// the caller takes the ownership of returned executors.
Status GraphMgr::InitItem(
    const string& handle, const GraphDef& gdef, WorkerSession* session,
    const GraphOptions& graph_options, const DebugOptions& debug_options,
    const ConfigProto& config_proto, int64_t collective_graph_key,
    DistributedFunctionLibraryRuntime* cluster_flr, Item* item) {
  item->session = handle;
  item->collective_graph_key = collective_graph_key;
  item->lib_def.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));

  TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));

  // We don't explicitly Validate the graph def because ConvertGraphDefToGraph
  // does that below.
  item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_, worker_env_->env, /*config=*/&config_proto,
      gdef.versions().producer(), item->lib_def.get(),
      graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [this, session](const int64_t step_id, const DeviceMgr*,
                          Rendezvous** r) -> Status {
            auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
            TF_RETURN_IF_ERROR(remote_r->Initialize(session));
            *r = remote_r;
            return Status::OK();
          },
          [this](const int64_t step_id) {
            this->worker_env_->rendezvous_mgr->Cleanup(step_id);
            return Status::OK();
          }}));

  // Constructs the graph out of "gdef".
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = true;
  opts.validate_nodes = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));

  // Splits "graph" into multiple subgraphs by device names.
  std::unordered_map<string, GraphDef> partitions;
  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;  // partition by each node's assigned device
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_G", next_id_++);
  };
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* device = nullptr;
    Status s = device_mgr_->LookupDevice(name, &device);
    if (s.ok()) {
      return device->attributes().incarnation();
    } else {
      return PartitionOptions::kIllegalIncarnation;
    }
  };
  popts.flib_def = item->lib_def.get();
  popts.control_flow_added = true;
  popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
  TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
  if (popts.scheduling_for_recvs) {
    TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
  }

  std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
  // Convert each partition's GraphDef into a Graph.
  for (auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        device_opts, std::move(partition.second), device_graph.get()));
    partition_graphs.emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.flib_def = item->lib_def.get();
  optimization_options.partition_graphs = &partition_graphs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  LocalExecutorParams params;

  item->units.reserve(partitions.size());
  item->graph_mgr = this;
  const auto& optimizer_opts = graph_options.optimizer_options();
  GraphOptimizer optimizer(optimizer_opts);
  for (auto& p : partition_graphs) {
    const string& device_name = p.first;
    std::unique_ptr<Graph>& subgraph = p.second;
    item->units.resize(item->units.size() + 1);
    ExecutionUnit* unit = &(item->units.back());

    // Find the device.
    Status s = device_mgr_->LookupDevice(device_name, &unit->device);
    if (!s.ok()) {
      // Remove the empty unit from the item as the item destructor wants all
      // units to have valid devices.
      item->units.pop_back();
      return s;
    }

    // Check whether the graph needs to be rewritten.
    // Give the device an opportunity to rewrite its subgraph.
    TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));

    // Top-level nodes in the graph uses the op segment to cache
    // kernels. Therefore, as long as the executor is alive, we need
    // to ensure the kernels cached for the session are alive.
    auto opseg = unit->device->op_segment();
    opseg->AddHold(handle);

    // Function library runtime.
    FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());

    // Create the executor.
    // Construct the root executor for the subgraph.
    params.device = unit->device;
    params.function_library = lib;
    params.create_kernel =
        [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
          // NOTE(mrry): We must not share function kernels (implemented
          // using `CallOp`) between subgraphs, because `CallOp::handle_`
          // is tied to a particular subgraph. Even if the function itself
          // is stateful, the `CallOp` that invokes it is not.
          if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
            return lib->CreateKernel(props, kernel);
          }
          auto create_fn = [lib, &props](OpKernel** kernel) {
            return lib->CreateKernel(props, kernel);
          };
          // Kernels created for subgraph nodes need to be cached.  On
          // cache miss, create_fn() is invoked to create a kernel based
          // on the function library here + global op registry.
          return opseg->FindOrCreate(handle, props->node_def.name(), kernel,
                                     create_fn);
        };
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
        delete kernel;
      }
    };

    // Optimize the graph.
    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                       GraphOptimizer::Options());

    TF_RETURN_IF_ERROR(
        EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                          unit->device->name(), subgraph.get()));
    unit->graph = std::move(subgraph);
    unit->build_cost_model = graph_options.build_cost_model();
    if (unit->build_cost_model > 0) {
      skip_cost_models_ = false;
    }
    TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
  }
  return Status::OK();
}

One point to note above is that SplitByDevice is used to partition the graph a second time, this time by device.

// NOTE: node->device_name() is not set by GraphConstructor.  We
// expects that NodeDef in GraphDef given to workers fully specifies
// device names.
static string SplitByDevice(const Node* node) {
  return node->assigned_device_name();
}

inline const std::string& Node::assigned_device_name() const {
  return graph_->get_assigned_device_name(*this);
}
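
As a rough illustration of what partitioning by assigned device means, the sketch below groups toy nodes by device name and counts the edges that cross devices, which is where a real partitioner has to insert send/recv nodes. ToyNode, ToyPartitioning and ToySplitByDevice are hypothetical names; the real Partition function does far more (control flow, naming, incarnation checks) than this sketch.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy node: a name, its assigned device, and input node names.
struct ToyNode {
  std::string name;
  std::string assigned_device;
  std::vector<std::string> inputs;
};

struct ToyPartitioning {
  // Device name -> names of the nodes placed on that device.
  std::map<std::string, std::vector<std::string>> partitions;
  // Number of edges whose two endpoints sit on different devices.
  int cross_device_edges = 0;
};

ToyPartitioning ToySplitByDevice(const std::vector<ToyNode>& nodes) {
  ToyPartitioning result;
  std::map<std::string, std::string> device_of;
  // First pass: bucket every node by its assigned device.
  for (const ToyNode& n : nodes) {
    device_of[n.name] = n.assigned_device;
    result.partitions[n.assigned_device].push_back(n.name);
  }
  // Second pass: count edges that leave the producer's device.
  for (const ToyNode& n : nodes) {
    for (const std::string& in : n.inputs) {
      auto it = device_of.find(in);
      if (it != device_of.end() && it->second != n.assigned_device) {
        ++result.cross_device_edges;
      }
    }
  }
  return result;
}
```

With two nodes on a CPU device feeding one node on a GPU device, the sketch yields two partitions and two cross-device edges.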

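Registration therefore works roughly like this: the information sent by the Master is used to build an Item, which is registered in GraphMgr, and an ExecutionUnit is created in the Item for each device partition. The session handle is stored in item->session, while graph_handle is generated from GraphMgr's next_id_ counter.

Once the subgraphs are registered, they can be run.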

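3. Running subgraphs

The Master uses RunGraphRequest to run all subgraphs registered under graph_handle. The Master generates a globally unique step_id to distinguish different run steps of the graph computation. Subgraphs use step_id when communicating with each other (for example, through send/recv operations), so that tensors produced by different runs can be told apart.

In the RunGraphRequest message, send holds the tensors fed into the subgraph, and recv_key names the tensors to fetch from the subgraph. RunGraphResponse returns the list of Tensors corresponding to recv_key.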
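
The role of step_id can be pictured with a toy table keyed by (step_id, key): inputs are sent in under a key, outputs are fetched by key, and two steps never see each other's values. This is only a sketch of the idea. ToyRendezvousMgr is a hypothetical class that stores plain doubles, whereas the real rendezvous deals with tensors, devices and asynchronous receives.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>

// Toy rendezvous table (hypothetical): values are addressed by the pair
// (step_id, key), so two steps can reuse the same key without clashing.
class ToyRendezvousMgr {
 public:
  void Send(std::int64_t step_id, const std::string& key, double value) {
    table_[std::make_pair(step_id, key)] = value;
  }

  // Copies the value into *out and returns true if it was sent in this step.
  bool Recv(std::int64_t step_id, const std::string& key, double* out) const {
    auto it = table_.find(std::make_pair(step_id, key));
    if (it == table_.end()) return false;
    *out = it->second;
    return true;
  }

  // Drops everything that belongs to a finished step.
  void Cleanup(std::int64_t step_id) {
    for (auto it = table_.begin(); it != table_.end();) {
      if (it->first.first == step_id) {
        it = table_.erase(it);
      } else {
        ++it;
      }
    }
  }

 private:
  std::map<std::pair<std::int64_t, std::string>, double> table_;
};
```

Sending "x" in step 1 and in step 2 stores two independent values, and cleaning up step 1 leaves step 2 untouched.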

3.1 Service

The request first arrives at GrpcWorkerService; the method invoked is "/tensorflow.WorkerService/RunGraph", and the corresponding code is:

void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
  // Use Schedule to put the computation task on the thread pool queue.
  Schedule([this, call]() {
    CallOptions* call_opts = new CallOptions;
    ProtoRunGraphRequest* wrapped_request =
        new ProtoRunGraphRequest(&call->request);
    NonOwnedProtoRunGraphResponse* wrapped_response =
        new NonOwnedProtoRunGraphResponse(&call->response);
    call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
    worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
                           [call, call_opts, wrapped_request,
                            wrapped_response](const Status& s) {
                             call->ClearCancelCallback();
                             delete call_opts;
                             delete wrapped_request;
                             delete wrapped_response;
                             call->SendResponse(ToGrpcStatus(s));
                           });
  });
  ENQUEUE_REQUEST(RunGraph, true);
}

This puts the computation task on the thread pool queue; the actual business logic lives in Worker::RunGraphAsync.

void Schedule(std::function<void()> f) {
  worker_->env()->compute_pool->Schedule(std::move(f));
}

3.2 GrpcWorker

RunGraphAsync has two execution paths; we follow DoRunGraph.

void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  if (request->store_errors_in_response_body()) {
    done = [response, done](const Status& status) {
      response->set_status(status);
      done(Status::OK());
    };
  }
  if (request->is_partial()) {
    DoPartialRunGraph(opts, request, response, std::move(done));  // partial runs, not covered here
  } else {
    DoRunGraph(opts, request, response, std::move(done));  // analyzed below
  }
}
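
The first branch of RunGraphAsync wraps the done callback so that an execution error travels in the response body while the caller of done always sees success. A minimal sketch of that wrapping, with hypothetical types (ToyStatus, ToyResponse, ToyDone, StoreErrorsInBody) standing in for the real ones:

```cpp
#include <cassert>
#include <functional>
#include <string>

// Toy status type (hypothetical): an empty message means success.
struct ToyStatus {
  std::string error;
  bool ok() const { return error.empty(); }
};

// Toy response that can carry a status in its body.
struct ToyResponse {
  ToyStatus status;
};

using ToyDone = std::function<void(const ToyStatus&)>;

// Wraps `done` so that the real status is recorded in the response body
// and the original callback always observes success.
ToyDone StoreErrorsInBody(ToyResponse* response, ToyDone done) {
  return [response, done](const ToyStatus& status) {
    response->status = status;
    done(ToyStatus{});
  };
}
```

A caller that opts into this mode has to inspect the status stored in the response, since the status delivered through the callback no longer reflects the failure.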

DoRunGraph mainly calls session->graph_mgr()->ExecuteAsync to run the computation graph.

void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64_t step_id = request->step_id();
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  DeviceProfilerSession* device_profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If timeline was requested, assume we want hardware level tracing.
    device_profiler_session = DeviceProfilerSession::Create().release();
  }
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    cm->StartCancel();
    AbortStep(step_id);
  });
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete device_profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector,
       device_profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          // Receive the output tensors.
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }

        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (device_profiler_session) {
          device_profiler_session->CollectData(response->mutable_step_stats())
              .IgnoreError();
        }

        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }

        if (collector) collector->Finalize();
        delete collector;
        delete device_profiler_session;
        delete out;
        done(s);
      });
}

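3.3 GraphMgr

ExecuteAsync calls StartParallelExecutors to carry out the parallel computation. The logic is roughly:

  • Look up the subgraph (Item) by its handle;
  • Prepare the cost graph output for the subgraph;
  • Obtain a rendezvous for the step and initialize it with this session; all later communication goes through this rendezvous, which in turn relies on the session;
  • Send the input tensors to the rendezvous;
  • Call StartParallelExecutors to run the subgraphs.
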
void GraphMgr::ExecuteAsync(const string& handle, const int64_t step_id,
                            WorkerSession* session, const ExecutorOpts& opts,
                            StepStatsCollector* collector,
                            MutableRunGraphResponseWrapper* response,
                            CancellationManager* cancellation_manager,
                            const NamedTensors& in, StatusCallback done) {
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish or RunGraphDone.
      [step_id] {
        return profiler::TraceMeEncode(
            "RunGraph", {{"id", step_id}, {"_r", 1} /*root_event*/});
      },
      profiler::ContextType::kTfExecutor, step_id,
      profiler::TraceMeLevel::kInfo);
  
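  // Lookup an item. Holds one ref while executing.
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter != table_.end()) {
      item = iter->second;
      item->Ref();
    }
  }

  // Bail out if no graph is registered under this handle.
  if (item == nullptr) {
    done(errors::Aborted("Graph handle is not found: ", handle));
    return;
  }

  // Expose the cost graph and, if requested, the partition graphs.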
  CostGraphDef* cost_graph = nullptr;
  if (response != nullptr) {
    cost_graph = response->mutable_cost_graph();
    if (opts.record_partition_graphs()) {
      for (const ExecutionUnit& unit : item->units) {
        GraphDef graph_def;
        unit.graph->ToGraphDef(&graph_def);
        response->AddPartitionGraph(graph_def);
      }
    }
  }

  // Obtain the rendezvous for this step.
  RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  // Initialize the rendezvous with this session. All later communication goes
  // through this rendezvous, which in turn relies on the session.
  Status s = rendezvous->Initialize(session);
  CollectiveExecutor::Handle* ce_handle =
      item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
          ? new CollectiveExecutor::Handle(
                worker_env_->collective_executor_mgr->FindOrCreate(step_id),
                true)
          : nullptr;
  // Sends values specified by the caller.
  // Send the input tensors to the rendezvous.
  size_t input_size = 0;
  if (s.ok()) {
    std::vector<string> keys;
    std::vector<Tensor> tensors_to_send;
    keys.reserve(in.size());
    tensors_to_send.reserve(in.size());
    for (auto& p : in) {
      keys.push_back(p.first);
      tensors_to_send.push_back(p.second);
      input_size += p.second.AllocatedBytes();
    }
    // Send the tensors.
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  }

  if (!s.ok()) {
    done(s);
    delete ce_handle;
    item->Unref();
    rendezvous->Unref();
    return;
  }

  // Run the partitioned subgraphs.
  StartParallelExecutors(
      handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
      cancellation_manager, session, start_time_usecs,
      [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
       step_id](const Status& s) {
        profiler::TraceMeConsumer activity(
            // From TraceMeProducer in GraphMgr::ExecuteAsync.
            [step_id] {
              return profiler::TraceMeEncode("RunGraphDone", {{"id", step_id}});
            },
            profiler::ContextType::kTfExecutor, step_id,
            profiler::TraceMeLevel::kInfo);
        done(s);
        metrics::RecordGraphInputTensors(input_size);
        metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
                                     start_time_usecs);
        rendezvous->Unref();
        item->Unref();
        delete ce_handle;
      });
}

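In short, ExecuteAsync uses handle to look up the Item and, through it, the computation graph. The session is used for communication and execution, and step_id is tied to communication; see the code above for details.

StartParallelExecutors creates an ExecutorBarrier for num_units executors. Each time a compute device finishes its PartitionGraph, it reports to the ExecutorBarrier through the callback obtained from barrier->Get(). Once every device has reported, the barrier invokes its completion callback, which builds the cost model, calls done, and deletes the step container. Nothing blocks in a wait() call here.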

void GraphMgr::StartParallelExecutors(
    const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous,
    CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector,
    CostGraphDef* cost_graph, CancellationManager* cancellation_manager,
    WorkerSession* session, int64_t start_time_usecs, StatusCallback done) {
  const int num_units = item->units.size();
  ScopedStepContainer* step_container = new ScopedStepContainer(
      step_id,
      [this](const string& name) { device_mgr_->ClearContainers({name}); });

  ExecutorBarrier* barrier =
      new ExecutorBarrier(num_units, rendezvous,
                          [this, item, collector, cost_graph, step_container,
                           done](const Status& s) {
                            BuildCostModel(item, collector, cost_graph);
                            done(s);
                            delete step_container;
                          });
  Executor::Args args;
  args.step_id = step_id;
  args.rendezvous = rendezvous;
  args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
  args.cancellation_manager = cancellation_manager;
  args.stats_collector = collector;
  args.step_container = step_container;
  args.sync_on_finish = sync_on_finish_;
  args.start_time_usecs = start_time_usecs;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, handle);
  }
  thread::ThreadPool* pool = worker_env_->compute_pool;
  using std::placeholders::_1;
  // Line below is equivalent to this code, but does one less indirect call:
  //  args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
  auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
  for (const auto& unit : item->units) {
    thread::ThreadPool* device_thread_pool =
        unit.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner =
          std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
    }
    unit.root->RunAsync(args, barrier->Get());
  }
}
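
The barrier's role in StartParallelExecutors can be illustrated with a single-threaded countdown sketch: it is created with the number of executors, hands out a completion callback, and fires the final callback exactly once after the last executor reports in. StepStatus and ToyExecutorBarrier are hypothetical names. The real ExecutorBarrier must additionally be thread-safe, since executors finish on thread pool threads, and this sketch makes no attempt at that.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>

// Toy status (hypothetical): an empty message means success.
struct StepStatus {
  std::string error;
  bool ok() const { return error.empty(); }
};

// Single-threaded sketch of a countdown barrier over `num` executors.
class ToyExecutorBarrier {
 public:
  using Done = std::function<void(const StepStatus&)>;

  ToyExecutorBarrier(int num, Done done)
      : pending_(num), done_(std::move(done)) {}

  // Returns the callback an executor should invoke when it finishes.
  Done Get() {
    return [this](const StepStatus& s) { WhenDone(s); };
  }

  int pending() const { return pending_; }

 private:
  void WhenDone(const StepStatus& s) {
    // Keep the first error so the final status reflects it.
    if (status_.ok() && !s.ok()) status_ = s;
    // The last executor to report triggers the final callback.
    if (--pending_ == 0) done_(status_);
  }

  int pending_;
  Done done_;
  StepStatus status_;
};
```

In this sketch the final callback receives the first error reported by any executor, or a success status if all of them succeeded.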

3.4 Summary

We summarize registering and running subgraphs in one figure.

Figure 1. Registering and running subgraphs

4. Conclusion

The whole distributed computation flow is summarized in one figure:

Figure 2. The distributed computation flow
