[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

Posted by 羅西的思考 on 2022-03-21

Before diving into the various distributed Strategies in TensorFlow, we first need to look at the foundation of distribution: the distributed environment. Only with a solid grasp of this foundation can later analysis proceed with the fewest obstacles and the greatest efficiency. This article introduces the static architecture of the Worker (and a set of related concepts).

Other articles in this series:

[Translation] TensorFlow Distributed, Paper: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed, Paper: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

1. Inheritance Relationships

1.1 The Worker Role

The TensorFlow Worker class is the entity that executes computation. Its main responsibilities are:

  • Receive requests from the Master.
  • Manage WorkerSessions.
  • Process registered subgraphs, for example splitting a subgraph a second time according to the devices available on its own node.
  • Execute the registered subgraphs on each device.
  • Support worker-to-worker tensor transfers, and so on. The transport depends on where the two endpoints sit relative to each other: cudaMemcpyAsync between CPU and GPU, DMA between local GPUs, and gRPC or RDMA between remote workers.
  • After execution finishes, fetch the results from the sink node that terminates the computation graph.

See protobuf/worker_service.proto for more details about each method.

1.2 Interface

Access to the WorkerService goes through WorkerInterface. WorkerInterface is the worker's interface class, the interface for talking to the TensorFlow Worker service. It mainly:

  • Defines asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These virtual functions correspond one-to-one with the GrpcWorkerMethods supported by GrpcWorkerService and with the protobuf definitions.
  • Defines synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through a call like CallAndWait(&ME::CreateWorkerSessionAsync, request, response); a minimal sketch of this pattern follows the list.
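
Below is a minimal sketch of that synchronous-over-asynchronous pattern (the full CallAndWait template and the real synchronous methods are listed in section 3.3.1): the caller blocks on a Notification that the asynchronous callback signals.

Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
                           CreateWorkerSessionResponse* response) {
  Status ret;
  Notification n;
  // The asynchronous virtual does the actual work; the callback records the
  // result and wakes the blocked caller.
  CreateWorkerSessionAsync(request, response, [&ret, &n](const Status& s) {
    ret = s;
    n.Notify();
  });
  n.WaitForNotification();  // block until the RPC completes
  return ret;
}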

1.3 WorkerInterface Derived Classes

As shown in the figure below, WorkerInterface has three implementations.

  • Worker : this class can be subclassed to provide specialized implementations of particular methods for different transport mechanisms. For example, GrpcWorker specializes the RecvTensorAsync() method to support a more efficient gRPC data structure for handling large binary data.
  • GrpcWorker : derived from Worker, it is the Worker role in local mode. If the Master and the Worker are both local, it can be called directly, without RPC network transport.
  • GrpcRemoteWorker : in distributed mode, when the Worker is remote, the local side uses GrpcRemoteWorker to access the remote Worker.
    • GrpcRemoteWorker is a gRPC client; it accesses the GrpcWorkerService on the remote Worker through a stub.
    • GrpcWorkerService implements every interface defined by WorkerService, but the actual work is forwarded to the local GrpcWorker.

The relationships are illustrated below:

Figure 1: Worker logical relationships

2. GrpcRemoteWorker

GrpcRemoteWorker is effectively a local proxy for a remote Worker.

  • The local Master partitions the computation graph and, depending on whether a partition is local or remote, calls either the local Worker or a GrpcRemoteWorker to execute that partition's subgraph.
  • The local GrpcRemoteWorker is created in GetOrCreateWorker in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc.
  • GrpcRemoteWorker sends gRPC requests to the remote side through IssueRequest.
  • After the remote GrpcWorkerService daemon receives a request, it calls its local Worker to handle it and returns the result when done. A usage sketch follows this list.
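
As a hedged sketch of the calling pattern described above (the task name is a placeholder; GetOrCreateWorker, ReleaseWorker, and the synchronous CreateWorkerSession wrapper are all shown later in this article), a master-side caller might do roughly the following:

// Sketch: the worker cache hands back either the local GrpcWorker or a
// GrpcRemoteWorker proxy, depending on whether "target" names the local task.
void CreateSessionOnWorker(WorkerCacheInterface* worker_cache,
                           const string& session_handle) {
  const string target = "/job:worker/replica:0/task:1";  // placeholder task name
  WorkerInterface* wi = worker_cache->GetOrCreateWorker(target);
  if (wi == nullptr) return;  // no channel found for this task

  CreateWorkerSessionRequest req;
  req.set_session_handle(session_handle);
  CreateWorkerSessionResponse resp;
  // For a GrpcRemoteWorker this becomes the gRPC call
  // "/tensorflow.WorkerService/CreateWorkerSession", issued via IssueRequest.
  Status s = wi->CreateWorkerSession(&req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "CreateWorkerSession failed: " << s;
  }

  // Workers must be returned through the cache, never deleted directly.
  worker_cache->ReleaseWorker(target, wi);
}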

2.1 Definition

The GrpcRemoteWorker code is shown below; we omit parts of it, such as the implementation of the DeleteWorkerSessionAsync method.

class GrpcRemoteWorker : public WorkerInterface {
 public:
  explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                            ::grpc::CompletionQueue* completion_queue,
                            thread::ThreadPool* callback_threadpool,
                            WorkerCacheLogger* logger, const string& target)
      : channel_(std::move(channel)),
        stub_(channel_),
        cq_(completion_queue),
        callback_threadpool_(callback_threadpool),
        getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
        createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
        deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
        registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
        deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
        rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
        cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
        cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
        recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
        recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
        logging_(Method(GrpcWorkerMethod::kLogging)),
        tracing_(Method(GrpcWorkerMethod::kTracing)),
        completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
        instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
        getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
        markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
        logger_(logger),
        target_(target) {}

  ~GrpcRemoteWorker() override {}

  void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                CreateWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, createworkersession_, std::move(done));
  }

  void RegisterGraphAsync(const RegisterGraphRequest* request,
                          RegisterGraphResponse* response,
                          StatusCallback done) override {
    IssueRequest(request, response, registergraph_, std::move(done));
  }

  void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
                     RunGraphResponse* response, StatusCallback done) override {
    IssueRequest(request, response, rungraph_, std::move(done), call_opts);
  }
  void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
                     MutableRunGraphResponseWrapper* response,
                     StatusCallback done) override {
    IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
                 rungraph_, std::move(done), call_opts);
  }

 private:
  // Utility method for issuing a generic asynchronous request. The
  // given callback, done, will be called when the RPC completes.
  void IssueRequest(const protobuf::Message* request,
                    protobuf::Message* response, const ::grpc::string& method,
                    StatusCallback done, CallOptions* call_opts = nullptr,
                    bool fail_fast = true) {
    new RPCState<protobuf::Message>(
        &stub_, cq_, method, *request, response, std::move(done), call_opts,
        callback_threadpool_, MaxRetries(), fail_fast, &target_);
  }

  void IssueRequest(const protobuf::Message* request, TensorResponse* response,
                    const ::grpc::string& method, StatusCallback done,
                    CallOptions* call_opts = nullptr) {
    new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
                                 std::move(done), call_opts,
                                 callback_threadpool_, MaxRetries(),
                                 /*fail_fast=*/true, &target_);
  }

  // Helper function for initializing the RpcMethod objects below.
  const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }

  // Helper function for configuring max GRPC retries. Defaults to 0 (no
  // retries).
  const int64_t MaxRetries() {
    int64_t max_retries = -1;
    TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
    return max_retries;
  }

  SharedGrpcChannelPtr channel_;
  ::grpc::GenericStub stub_;
  ::grpc::CompletionQueue* cq_;
  thread::ThreadPool* callback_threadpool_;

  const ::grpc::string getstatus_;
  const ::grpc::string createworkersession_;
  const ::grpc::string deleteworkersession_;
  const ::grpc::string registergraph_;
  const ::grpc::string deregistergraph_;
  const ::grpc::string rungraph_;
  const ::grpc::string cleanupgraph_;
  const ::grpc::string cleanupall_;
  const ::grpc::string recvtensor_;
  const ::grpc::string recvbuf_;
  const ::grpc::string logging_;
  const ::grpc::string tracing_;
  const ::grpc::string completegroup_;
  const ::grpc::string instancesource_;
  const ::grpc::string getstepsequence_;
  const ::grpc::string markrecvfinished_;

  // Support for logging.
  WorkerCacheLogger* logger_;
  const string target_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
};

2.2 Construction

The factory function is as follows:

WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                     ::grpc::CompletionQueue* completion_queue,
                                     thread::ThreadPool* callback_threadpool,
                                     WorkerCacheLogger* logger,
                                     const string& target) {
  return new GrpcRemoteWorker(std::move(channel), completion_queue,
                              callback_threadpool, logger, target);
}

The actual call site is in the worker cache, in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc, which decides based on the target which kind of Worker to return.

WorkerInterface* GetOrCreateWorker(const string& target) override {
  if (target == local_target_) {
    return local_worker_;
  } else {
    SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
    if (!channel) {
      return nullptr;
    }
    size_t index = AssignWorkerToThread(target);
    return NewGrpcRemoteWorker(
        channel, worker_env_->GetCompletionQueue(index),
        worker_env_->GetThreadPool(), &logger_, target);
  }
}

2.3 Sending Requests

Next, let's see how a request is sent. CreateWorkerSessionAsync actually sends the request identified by the string stored in createworkersession_.

  void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                CreateWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, createworkersession_, std::move(done));
  }

IssueRequest appeared in the definition above; it is listed again here. You can see it invokes the remote method named by the method argument, which in our case is createworkersession_.

void IssueRequest(const protobuf::Message* request,
                  protobuf::Message* response, const ::grpc::string& method,
                  StatusCallback done, CallOptions* call_opts = nullptr,
                  bool fail_fast = true) {
  new RPCState<protobuf::Message>(
      &stub_, cq_, method, *request, response, std::move(done), call_opts,
      callback_threadpool_, MaxRetries(), fail_fast, &target_);
}

createworkersession_ is configured in the constructor.

explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                          ::grpc::CompletionQueue* completion_queue,
                          thread::ThreadPool* callback_threadpool,
                          WorkerCacheLogger* logger, const string& target)
    : channel_(std::move(channel)),
      createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)), // configured here

GrpcWorkerMethodName is defined in tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc. It returns the concrete string, i.e. the method name on the remote GrpcWorker. As you can see, CreateWorkerSessionAsync ultimately calls "/tensorflow.WorkerService/CreateWorkerSession".

// Names of worker methods.
enum class GrpcWorkerMethod {
  kGetStatus,
  kCreateWorkerSession,
  kDeleteWorkerSession,
  kRegisterGraph,
  kDeregisterGraph,
  kRunGraph,
  kCleanupGraph,
  kCleanupAll,
  kRecvTensor,
  kRecvBuf,
  kLogging,
  kTracing,
  kCompleteGroup,
  kCompleteInstance,
  kGetStepSequence,
  kMarkRecvFinished,
};

const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
  switch (id) {
    case GrpcWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case GrpcWorkerMethod::kCreateWorkerSession:
      return "/tensorflow.WorkerService/CreateWorkerSession";
    case GrpcWorkerMethod::kDeleteWorkerSession:
      return "/tensorflow.WorkerService/DeleteWorkerSession";
    case GrpcWorkerMethod::kRegisterGraph:
      return "/tensorflow.WorkerService/RegisterGraph";
    case GrpcWorkerMethod::kDeregisterGraph:
      return "/tensorflow.WorkerService/DeregisterGraph";
    case GrpcWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case GrpcWorkerMethod::kCleanupGraph:
      return "/tensorflow.WorkerService/CleanupGraph";
    case GrpcWorkerMethod::kCleanupAll:
      return "/tensorflow.WorkerService/CleanupAll";
    case GrpcWorkerMethod::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
    case GrpcWorkerMethod::kRecvBuf:
      return "/tensorflow.WorkerService/RecvBuf";
    case GrpcWorkerMethod::kLogging:
      return "/tensorflow.WorkerService/Logging";
    case GrpcWorkerMethod::kTracing:
      return "/tensorflow.WorkerService/Tracing";
    case GrpcWorkerMethod::kCompleteGroup:
      return "/tensorflow.WorkerService/CompleteGroup";
    case GrpcWorkerMethod::kCompleteInstance:
      return "/tensorflow.WorkerService/CompleteInstance";
    case GrpcWorkerMethod::kGetStepSequence:
      return "/tensorflow.WorkerService/GetStepSequence";
    case GrpcWorkerMethod::kMarkRecvFinished:
      return "/tensorflow.WorkerService/MarkRecvFinished";
  }
  // Shouldn't be reached.
  LOG(FATAL) << "Invalid id: this line shouldn't be reached.";
  return "invalid id";
}

3. Worker Service

WorkerService is a gRPC service that defines a TensorFlow service. A WorkerService executes dataflow graphs on a set of local devices on behalf of a MasterService. A WorkerService keeps track of multiple "registered graphs". Each registered graph is a subgraph of a client's computation graph, containing only the nodes that should run on this worker (plus any additional nodes needed for inter-process communication via the RecvTensor method).

Based on the ClusterSpec, the Master locates the other Server instances in the cluster and treats them as Workers. The Master then distributes subgraphs to these Worker nodes and schedules them to carry out the actual subgraph computation. If there are data dependencies between Workers, they interact through inter-process communication. Whether the Master is calling a Worker or Workers are calling each other, the interface specification defined by WorkerService must be followed. All WorkerService interfaces are defined in the worker_service.proto file.

service WorkerService {
  // See worker.proto for details.
  rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);

  // See worker.proto for details.
  rpc CreateWorkerSession(CreateWorkerSessionRequest)
      returns (CreateWorkerSessionResponse);

  // See worker.proto for details.
  rpc DeleteWorkerSession(DeleteWorkerSessionRequest)
      returns (DeleteWorkerSessionResponse);

  // See worker.proto for details.
  rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);

  // See worker.proto for details.
  rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);

  // See worker.proto for details.
  rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);

  // See worker.proto for details.
  rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);

  // See worker.proto for details.
  rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);

  // See worker.proto for details.
  rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
    // RecvTensor Method
  }

  // See worker.proto for details.
  rpc Logging(LoggingRequest) returns (LoggingResponse);

  // See worker.proto for details.
  rpc Tracing(TracingRequest) returns (TracingResponse);

  // See worker.proto for details.
  rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) {}

  // See worker.proto for details.
  rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);

  // See worker.proto for details.
  rpc CompleteGroup(CompleteGroupRequest) returns (CompleteGroupResponse);

  // See worker.proto for details.
  rpc CompleteInstance(CompleteInstanceRequest)
      returns (CompleteInstanceResponse);
}

3.3.1 WorkerInterface

Similar to MasterService, access to the WorkerService goes through WorkerInterface. WorkerInterface is the worker's interface class, the interface for talking to the TensorFlow Worker service. It mainly:

  • Defines asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These virtual functions correspond one-to-one with the GrpcWorkerMethods supported by GrpcWorkerService and with the protobuf definitions.
  • Defines synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through a call like CallAndWait(&ME::CreateWorkerSessionAsync, request, response).

We first list its asynchronous interface below.

// Interface for talking with the TensorFlow Worker service.
class WorkerInterface {
 public:
  virtual void GetStatusAsync(CallOptions* opts,
                              const GetStatusRequest* request,
                              GetStatusResponse* response, bool fail_fast,
                              StatusCallback done) = 0;

  virtual void CreateWorkerSessionAsync(
      const CreateWorkerSessionRequest* request,
      CreateWorkerSessionResponse* response, StatusCallback done) = 0;

  virtual void DeleteWorkerSessionAsync(
      CallOptions* opts, const DeleteWorkerSessionRequest* request,
      DeleteWorkerSessionResponse* response, StatusCallback done) = 0;

  virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
                                  RegisterGraphResponse* response,
                                  StatusCallback done) = 0;

  virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                    DeregisterGraphResponse* response,
                                    StatusCallback done) = 0;

  virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                             MutableRunGraphResponseWrapper* response,
                             StatusCallback done) = 0;

  virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
                             RunGraphResponse* response, StatusCallback done) {
    RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
    MutableRunGraphResponseWrapper* wrapped_response =
        new NonOwnedProtoRunGraphResponse(response);
    RunGraphAsync(opts, wrapped_request, wrapped_response,
                  [wrapped_request, wrapped_response,
                   done = std::move(done)](const Status& s) {
                    done(s);
                    delete wrapped_request;
                    delete wrapped_response;
                  });
  }

  virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
                                 CleanupGraphResponse* response,
                                 StatusCallback done) = 0;

  virtual void CleanupAllAsync(const CleanupAllRequest* request,
                               CleanupAllResponse* response,
                               StatusCallback done) = 0;

  virtual void RecvTensorAsync(CallOptions* opts,
                               const RecvTensorRequest* request,
                               TensorResponse* response,
                               StatusCallback done) = 0;

  virtual void LoggingAsync(const LoggingRequest* request,
                            LoggingResponse* response, StatusCallback done) = 0;

  virtual void TracingAsync(const TracingRequest* request,
                            TracingResponse* response, StatusCallback done) = 0;

  virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                            RecvBufResponse* response, StatusCallback done) = 0;

  virtual void CompleteGroupAsync(CallOptions* opts,
                                  const CompleteGroupRequest* request,
                                  CompleteGroupResponse* response,
                                  StatusCallback done) = 0;

  virtual void CompleteInstanceAsync(CallOptions* ops,
                                     const CompleteInstanceRequest* request,
                                     CompleteInstanceResponse* response,
                                     StatusCallback done) = 0;

  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                    GetStepSequenceResponse* response,
                                    StatusCallback done) = 0;
};

WorkerInterface also provides a synchronous interface, so that a Master or Worker can call methods on a remote WorkerService as if they were local functions. The synchronous interface is implemented on top of the asynchronous one, using the CallAndWait adapter to wrap the asynchronous calls. In addition, to prevent outside code from deleting a WorkerInterface instance illegally, a few restrictions are in place: the destructor is protected, WorkerCacheInterface is a friend class, and WorkerCacheInterface::ReleaseWorker is responsible for deleting WorkerInterface instances. Below are the synchronous interface, some basic helper functions, and the member variables.

// Interface for talking with the TensorFlow Worker service.
class WorkerInterface {
 public:

  virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
    return new MutableProtoRunGraphRequest;
  }

  virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
    return new OwnedProtoRunGraphResponse;
  }

  Status GetStatus(const GetStatusRequest* request,
                   GetStatusResponse* response) {
    Status ret;
    Notification n;
    GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                   [&ret, &n](const Status& s) {
                     ret = s;
                     n.Notify();
                   });
    n.WaitForNotification();
    return ret;
  }

  Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
                             CreateWorkerSessionResponse* response) {
    return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
  }

  Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
                             DeleteWorkerSessionResponse* response) {
    return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request,
                                  response);
  }

  Status RegisterGraph(const RegisterGraphRequest* request,
                       RegisterGraphResponse* response) {
    return CallAndWait(&ME::RegisterGraphAsync, request, response);
  }

  Status DeregisterGraph(const DeregisterGraphRequest* request,
                         DeregisterGraphResponse* response) {
    return CallAndWait(&ME::DeregisterGraphAsync, request, response);
  }

  Status CleanupGraph(const CleanupGraphRequest* request,
                      CleanupGraphResponse* response) {
    return CallAndWait(&ME::CleanupGraphAsync, request, response);
  }

  Status CleanupAll(const CleanupAllRequest* request,
                    CleanupAllResponse* response) {
    return CallAndWait(&ME::CleanupAllAsync, request, response);
  }

  Status Logging(const LoggingRequest* request, LoggingResponse* response) {
    return CallAndWait(&ME::LoggingAsync, request, response);
  }

  Status Tracing(const TracingRequest* request, TracingResponse* response) {
    return CallAndWait(&ME::TracingAsync, request, response);
  }

  Status GetStepSequence(const GetStepSequenceRequest* request,
                         GetStepSequenceResponse* response) {
    return CallAndWait(&ME::GetStepSequenceAsync, request, response);
  }

 protected:
  // Instances of WorkerInterface must be deleted by a call to
  // WorkerCacheInterface::ReleaseWorker().
  virtual ~WorkerInterface() {}
  friend class WorkerCacheInterface;

  // NOTE: This should only be called by implementations of this
  // interface whose CreateRunGraphResponse() method returns a
  // proto-based wrappers for the RunGraphResponse message.
  RunGraphResponse* get_proto_from_wrapper(
      MutableRunGraphResponseWrapper* wrapper) {
    return wrapper->get_proto();
  }

 private:
  typedef WorkerInterface ME;

  template <typename Method, typename Req, typename Resp>
  Status CallAndWait(Method func, const Req* req, Resp* resp) {
    Status ret;
    Notification n;
    (this->*func)(req, resp, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }

  template <typename Method, typename Req, typename Resp>
  Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) {
    CallOptions call_opts;
    Status ret;
    Notification n;
    (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};

3.3.2 Sorting Out the Concepts

The WorkerService interface involves quite a few concepts, so we need to sort them out carefully.

As mentioned earlier, the Client and the Master cooperate through the session_handle / MasterSession pair, and the Master and the Workers cooperate through MasterSession and WorkerSession; a MasterSession centrally manages the multiple WorkerSessions subordinate to it. We need to clarify the relationships among several concepts (a small sketch follows this list):

  • session_handle : lets a MasterSession centrally manage the WorkerSessions under it. It corresponds one-to-one with a MasterSession and is generated when the MasterSession is created. It is returned to the Client in CreateSessionResponse and sent to the Workers in CreateWorkerSessionRequest, so the whole chain from Client to Master to Worker is uniquely identified by this session_handle.
  • graph_handle : generated by GraphMgr::Register when a subgraph is registered, and returned to the Master in RegisterGraphResponse. The subgraph is identified by this graph_handle; within the cluster, the pair (session_handle, graph_handle) uniquely identifies a particular subgraph.
  • step_id : because the Master has multiple Workers run the computation concurrently, it broadcasts RunGraph to all of them. To distinguish the different steps, the Master generates a globally unique step_id for each RunStep and carries it to the Workers in the RunGraphRequest message.
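
To make the relationship concrete, here is a small sketch of how the three identifiers end up on a single RunGraphRequest on the master side (field names follow the RunGraphRequest proto quoted later in this chapter; the surrounding control flow is simplified):

// Sketch: the master addresses one execution of one registered partition with
// the triple (session_handle, graph_handle, step_id).
RunGraphRequest MakeRunGraphRequest(const string& session_handle,
                                    const string& graph_handle,
                                    int64_t step_id) {
  RunGraphRequest req;
  req.set_session_handle(session_handle);  // which MasterSession / WorkerSession pair
  req.set_graph_handle(graph_handle);      // which registered subgraph on that worker
  req.set_step_id(step_id);                // which run of that subgraph
  return req;
}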

Let's trace graph_handle. GraphMgr::Register is where the graph_handle is generated.

Status GraphMgr::Register(
    const string& handle, const GraphDef& gdef, WorkerSession* session,
    const GraphOptions& graph_options, const DebugOptions& debug_options,
    const ConfigProto& config_proto, int64_t collective_graph_key,
    DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
  Item* item = new Item;
  Status s = InitItem(handle, gdef, session, graph_options, debug_options,
                      config_proto, collective_graph_key, cluster_flr, item);
  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
    *graph_handle =
        strings::Printf("%016llx", static_cast<long long>(++next_id_));
    item->handle = *graph_handle;
    CHECK(table_.insert({*graph_handle, item}).second);
  }
  return Status::OK();
}

RegisterGraphResponse returns the graph_handle to the Master.

message RegisterGraphResponse {
  // If the registration succeeds, returns an opaque graph_handle to
  // the master. The master calls RunGraph with graph_handle to
  // compute different steps.
  string graph_handle = 1;
}

Each partitioned subgraph carries a graph_handle.

// Graph partitioned into per-location subgraphs.
struct Part {
  // Worker name.
  string name;

  // Maps feed names to rendezvous keys. Empty most of the time.
  std::unordered_map<string, string> feed_key;

  // Maps rendezvous keys to fetch names. Empty most of the time.
  std::unordered_map<string, string> key_fetch;

  // The interface to the worker. Owned.
  WorkerInterface* worker = nullptr;

  // After registration with the worker, graph_handle identifies
  // this partition on the worker.
  string graph_handle;

  Part() : feed_key(3), key_fetch(3) {}
};

When registration returns, the graph_handle is set on the subgraph.

Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
    const PartitionOptions& popts,
    std::unordered_map<string, GraphDef> graph_partitions) {
  partitions_.reserve(graph_partitions.size());
  Status s;
  for (auto& name_def : graph_partitions) {
    partitions_.emplace_back();
    Part* part = &partitions_.back();
    part->name = name_def.first;
    TrackFeedsAndFetches(part, name_def.second, popts);
    part->worker = worker_cache_->GetOrCreateWorker(part->name);
    if (part->worker == nullptr) {
      s = errors::NotFound("worker ", part->name);
      break;
    }
  }
  if (!s.ok()) {
    for (Part& part : partitions_) {
      worker_cache_->ReleaseWorker(part.name, part.worker);
      part.worker = nullptr;
    }
    return s;
  }
  struct Call {
    RegisterGraphRequest req;
    RegisterGraphResponse resp;
    Status status;
  };
  const int num = partitions_.size();
  gtl::InlinedVector<Call, 4> calls(num);
  BlockingCounter done(num);
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    Call* c = &calls[i];
    c->req.set_session_handle(session_handle_);
    c->req.set_create_worker_session_called(!should_deregister_);
    c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
    StripDefaultAttributes(*OpRegistry::Global(),
                           c->req.mutable_graph_def()->mutable_node());
    *c->req.mutable_config_proto() = session_opts_.config;
    *c->req.mutable_graph_options() = session_opts_.config.graph_options();
    *c->req.mutable_debug_options() =
        callable_opts_.run_options().debug_options();
    c->req.set_collective_graph_key(collective_graph_key_);

    auto cb = [c, &done](const Status& s) {
      c->status = s;
      done.DecrementCount();
    };
    part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
  }
  done.Wait();
  for (int i = 0; i < num; ++i) {
    Call* c = &calls[i];
    s.Update(c->status);
    partitions_[i].graph_handle = c->resp.graph_handle();
  }
  return s;
}

When used later, the graph_handle uniquely identifies a subgraph.

// Asynchronously deregisters subgraphs on the workers, without waiting for the
// result.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
  struct Call {
    DeregisterGraphRequest req;
    DeregisterGraphResponse resp;
  };
  for (Part& part : partitions_) {
    // The graph handle may be empty if we failed during partition registration.
    if (!part.graph_handle.empty()) {
      Call* c = new Call;
      c->req.set_session_handle(session_handle_);
      c->req.set_create_worker_session_called(!should_deregister_);
      c->req.set_graph_handle(part.graph_handle);
      // NOTE(mrry): We must capture worker_cache_ since this
      // could be deleted before the callback is called.
      WorkerCacheInterface* worker_cache = worker_cache_;
      const string name = part.name;
      WorkerInterface* w = part.worker;
      CHECK_NOTNULL(w);
      auto cb = [worker_cache, c, name, w](const Status& s) {
         delete c;
        worker_cache->ReleaseWorker(name, w);
      };
      w->DeregisterGraphAsync(&c->req, &c->resp, cb);
    }
  }
}

3.3.4 WorkerInterface Derived Classes

As shown in the figure below, WorkerInterface has two implementations.

  • GrpcWorker : the Worker role in local mode. If the Master and the Worker are both local, it can be called directly, without RPC network transport.
  • GrpcRemoteWorker : in distributed mode, when the Worker is remote, the local side uses GrpcRemoteWorker to access the remote Worker.
    • GrpcRemoteWorker is a gRPC client; it accesses the GrpcWorkerService on the remote Worker through a stub.
    • GrpcWorkerService implements every interface defined by WorkerService, but the actual work is forwarded to the local GrpcWorker.

The relationships are illustrated below:

Figure 1: WorkerInterface derived classes

3.3.5 Usage

When the Server initializes, it creates the Worker Service with the following code.

  // Create the GrpcWorker and the corresponding GrpcWorkerService
  worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                  : NewGrpcWorker(&worker_env_, config);
  worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
                                         opts.worker_service_options);


Concretely, this returns a GrpcWorkerService.

// Returns an implementation of WorkerService rpc service.
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
    GrpcWorker* worker, ::grpc::ServerBuilder* builder,
    GrpcWorkerServiceOptions options) {
  return std::unique_ptr<AsyncServiceInterface>(
      new GrpcWorkerService(worker, builder, options));
}


Inside GrpcServer, the worker_thread_ thread runs GrpcWorkerService's HandleRPCsLoop method.

worker_thread_.reset(
    env_->StartThread(ThreadOptions(), "TF_worker_service",
                      [this] { worker_service_->HandleRPCsLoop(); }));


3.3.6 Definition

GrpcWorkerService is defined as follows. Because it needs to act as a daemon handling incoming gRPC requests, its constructor creates a number of threads to serve requests; HandleRPCsLoop then starts these threads and joins them.

class GrpcWorkerService : public AsyncServiceInterface {
 public:
  GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                    GrpcWorkerServiceOptions options)
      : is_shutdown_(false) {
    builder->RegisterService(&worker_service_);

    for (int i = 0; i < options.num_serving_threads; i++) {
      threads_.emplace_back(
          new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
                                      cache_.get(), &worker_service_));
    }
  }

  // This method blocks forever handling requests from the completion queue.
  void HandleRPCsLoop() override {
    for (auto& worker_thread : threads_) {
      worker_thread->Start();
    }
    for (auto& worker_thread : threads_) {
      worker_thread->Join();
    }
  }

 private:
  grpc::WorkerService::AsyncService worker_service_;
  std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;

  std::unique_ptr<GrpcResponseCache> cache_;
  mutex service_shutdown_mu_;
  bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};

3.3.7 Threads

The actual loop and request handling are done inside the threads; cq_ is the gRPC completion queue.

// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
// requests.  Each thread operates on an independent completion queue.
class GrpcWorkerServiceThread {
 public:
  explicit GrpcWorkerServiceThread(
      GrpcWorker* worker, ::grpc::ServerBuilder* builder,
      std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
      grpc::WorkerService::AsyncService* worker_service)
      : worker_(worker),
        queue_depth_(queue_depth),
        cache_(cache),
        worker_service_(worker_service),
        is_shutdown_(false) {
    cq_ = builder->AddCompletionQueue();
  }

  void Start() {
    thread_.reset(
        worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
                                         [this]() { HandleRPCsLoop(); }));
  }
};

Main loop

GrpcWorkerServiceThread::HandleRPCsLoop is the thread's main loop, similar to the master service. It first prepares pending queues for the gRPC calls; these call requests correspond one-to-one with the GrpcWorkerMethods discussed later, and the handler code for each method is covered below.

// Add one or more completion queue entries for each worker method, then
// begin servicing requests from the completion queue.
void GrpcWorkerServiceThread::HandleRPCsLoop() {
  // TODO(ncteisen): This may require performance engineering. We can
  // change the number of threads, the number of handlers per thread,
  // or even decide to specialize certain threads to certain methods.
  SETUP_FOR_REQUEST(GetStatus, 1, false);
  SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
  SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
  SETUP_FOR_REQUEST(CleanupAll, 1, false);
  SETUP_FOR_REQUEST(RegisterGraph, 1, false);
  SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
  SETUP_FOR_REQUEST(Logging, 1, false);
  SETUP_FOR_REQUEST(Tracing, 1, false);
  SETUP_FOR_REQUEST(CompleteGroup, 10, true);
  SETUP_FOR_REQUEST(CompleteInstance, 10, true);
  SETUP_FOR_REQUEST(GetStepSequence, 10, true);
  SETUP_FOR_REQUEST(RecvBuf, 500, true);
  SETUP_FOR_REQUEST(RunGraph, 100, true);
  SETUP_FOR_REQUEST(CleanupGraph, 100, false);
  SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);

  // TODO(ncteisen): Determine a better policy for enqueuing the
  // appropriate number of each request type.
  for (int i = 0;
       i < gtl::FindWithDefault(
               queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
               1000);
       ++i) {
    EnqueueRecvTensorRequestRaw();
  }

  void* tag;
  bool ok;

  while (cq_->Next(&tag, &ok)) {
    UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
        static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
    CHECK(callback_tag);
    callback_tag->OnCompleted(this, ok);
  }
}

gRPC requests

Request handling is similar to the master. Each request invokes a business handler, namely GrpcWorkerServiceThread::method##Handler as defined by the macros below.

#define ENQUEUE_REQUEST(method, supports_cancel)                             \
  do {                                                                       \
    mutex_lock l(shutdown_mu_);                                              \
    if (!is_shutdown_) {                                                     \
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,       \
           method##Request, method##Response>::                              \
          EnqueueRequestForMethod(                                           \
              worker_service_, cq_.get(),                                    \
              static_cast<int>(GrpcWorkerMethod::k##method),                 \
              &GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
    }                                                                        \
  } while (0)

#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)              \
  for (int i = 0;                                                              \
       i < gtl::FindWithDefault(queue_depth_,                                  \
                                static_cast<int>(GrpcWorkerMethod::k##method), \
                                default_depth);                                \
       ++i) {                                                                  \
    ENQUEUE_REQUEST(method, supports_cancel);                                  \
  }

Each RPC must be registered as an asynchronous service, which is done with gRPC's built-in AddMethod and MarkMethodAsync interfaces.

WorkerService::AsyncService::AsyncService() {
  for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
    AddMethod(new ::grpc::internal::RpcServiceMethod(
        GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
        ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
    ::grpc::Service::MarkMethodAsync(i);
  }
}

Handler & thread pool

The concrete handlers are configured through a macro, shown below. Depending on the configuration (the may_block_on_compute_pool flag), the closure runs either via SchedClosure or on the compute_pool->Schedule thread pool; this is where the modules aggregated in the worker env come into play.

  // Handle all non-cancellable simple methods with a standard wrapper.
  // The boolean may_block_on_compute_pool indicates whether or not the
  // operation may block on activities (such as op execution) that run on the
  // compute pool.
#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
  HANDLE_CALL(CreateWorkerSession, false);
  HANDLE_CALL(DeleteWorkerSession, true);
  HANDLE_CALL(CleanupAll, false);
  HANDLE_CALL(RegisterGraph, false);
  HANDLE_CALL(DeregisterGraph, false);
  HANDLE_CALL(CleanupGraph, false);
  HANDLE_CALL(Logging, false);
  HANDLE_CALL(Tracing, false);

#undef HANDLE_CALL

Messages & methods

GrpcWorkerMethod defines exactly which methods the worker has.

// Names of worker methods.
enum class GrpcWorkerMethod {
  kGetStatus,
  kCreateWorkerSession,
  kDeleteWorkerSession,
  kRegisterGraph,
  kDeregisterGraph,
  kRunGraph,
  kCleanupGraph,
  kCleanupAll,
  kRecvTensor,
  kRecvBuf,
  kLogging,
  kTracing,
  kCompleteGroup,
  kCompleteInstance,
  kGetStepSequence,
  kMarkRecvFinished,
};

Which method name each of these enum values corresponds to is determined by GrpcWorkerMethodName.

const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
  switch (id) {
    case GrpcWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case GrpcWorkerMethod::kCreateWorkerSession:
      return "/tensorflow.WorkerService/CreateWorkerSession";
    case GrpcWorkerMethod::kDeleteWorkerSession:
      return "/tensorflow.WorkerService/DeleteWorkerSession";
    case GrpcWorkerMethod::kRegisterGraph:
      return "/tensorflow.WorkerService/RegisterGraph";
    case GrpcWorkerMethod::kDeregisterGraph:
      return "/tensorflow.WorkerService/DeregisterGraph";
    case GrpcWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case GrpcWorkerMethod::kCleanupGraph:
      return "/tensorflow.WorkerService/CleanupGraph";
    case GrpcWorkerMethod::kCleanupAll:
      return "/tensorflow.WorkerService/CleanupAll";
    case GrpcWorkerMethod::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
    case GrpcWorkerMethod::kRecvBuf:
      return "/tensorflow.WorkerService/RecvBuf";
    case GrpcWorkerMethod::kLogging:
      return "/tensorflow.WorkerService/Logging";
    case GrpcWorkerMethod::kTracing:
      return "/tensorflow.WorkerService/Tracing";
    case GrpcWorkerMethod::kCompleteGroup:
      return "/tensorflow.WorkerService/CompleteGroup";
    case GrpcWorkerMethod::kCompleteInstance:
      return "/tensorflow.WorkerService/CompleteInstance";
    case GrpcWorkerMethod::kGetStepSequence:
      return "/tensorflow.WorkerService/GetStepSequence";
    case GrpcWorkerMethod::kMarkRecvFinished:
      return "/tensorflow.WorkerService/MarkRecvFinished";
  }
  // Shouldn't be reached.
  return "invalid id";
}

AsyncService calls GrpcWorkerMethodName to register the methods with gRPC.

WorkerService::AsyncService::AsyncService() {
  for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
    AddMethod(new ::grpc::internal::RpcServiceMethod(
        GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
        ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
    ::grpc::Service::MarkMethodAsync(i);
  }
}

Business logic

The actual business processing is done by calling into the Worker.

void GetStepSequenceHandler(
    WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
  Schedule([this, call]() {
    worker_->GetStepSequenceAsync(
        &call->request, &call->response, [call](const Status& s) {
          call->SendResponse(ToGrpcStatus(s));
        });
  });
  ENQUEUE_REQUEST(GetStepSequence, true);
}

From a thread perspective, the logic is as follows, assuming three threads here. The Server's worker_thread_ starts GrpcWorkerService::HandleRPCsLoop(), whose job is to start two GrpcWorkerServiceThreads; each GrpcWorkerServiceThread serves gRPC requests and performs the business processing inside GrpcWorkerServiceThread::HandleRPCsLoop. Note that both GrpcWorkerService and GrpcWorkerServiceThread have a method named HandleRPCsLoop.

Figure 2: Thread perspective

3.3.8 Business Logic

CreateWorkerSession

The CreateWorkerSessionRequest message carries the session_handle corresponding to the MasterSession; after receiving the message, the Worker creates a WorkerSession. Within a cluster, whenever the MasterSession creates WorkerSessions it passes along its own session_handle, so each WorkerSession knows through the session_handle which MasterSession it belongs to, and the MasterSession instance can centrally manage all the WorkerSessions subordinate to it.

GrpcWorker manages WorkerSessions through SessionMgr, which can locate a WorkerSession either by master task name or by session_handle.

class SessionMgr {

  WorkerEnv* const worker_env_;  // Not owned.
  std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
  std::shared_ptr<WorkerSession> legacy_session_;
  const WorkerCacheFactory worker_cache_factory_;

  // A map from session identifier to internal session structure.
  std::map<string, std::shared_ptr<WorkerSession>> sessions_ TF_GUARDED_BY(mu_);

  // Incarnation and WorkerSession handle associated with a master task.
  struct MasterAssociatedSession {
    const int64_t master_incarnation;
    const string session_handle;
  };
  // A map from master task name to its associated worker sessions.
  std::unordered_multimap<string, MasterAssociatedSession>
      master_to_associated_sessions_ TF_GUARDED_BY(mu_);
};

The messages are shown below; note that CreateWorkerSessionResponse returns nothing:

message CreateWorkerSessionRequest {
  // Sessions are identified by a given handle.
  string session_handle = 1;

  // Defines the configuration of a TensorFlow worker.
  ServerDef server_def = 2;

  // If true, any resources such as Variables used in the session will not be
  // shared with other sessions.
  bool isolate_session_state = 3;

  // The device attributes of all the devices in the cluster.
  repeated DeviceAttributes cluster_device_attributes = 4;

  // The master task name from which the request is sent.
  string master_task = 5;

  // The incarnation ID of the master task local CPU device.
  // If the target worker already has a WorkerSession created previously with
  // the same master task name but a different incarnation, it usually indicates
  // that the previous master failed before deleting the WorkerSession on the
  // worker. To prevent memory leaks, the worker should garbage collect the old
  // WorkerSessions.
  int64 master_incarnation = 6;
}

message CreateWorkerSessionResponse {}

Figure 3: CreateWorkerSession

As mentioned above, the GrpcWorker handlers for these messages are generated with the macro.

#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
  HANDLE_CALL(CreateWorkerSession, false);
  HANDLE_CALL(DeleteWorkerSession, true);
  HANDLE_CALL(CleanupAll, false);
  HANDLE_CALL(RegisterGraph, false);
  HANDLE_CALL(DeregisterGraph, false);
  HANDLE_CALL(CleanupGraph, false);
  HANDLE_CALL(Logging, false);
  HANDLE_CALL(Tracing, false);

RegisterGraph

The RegisterGraphRequest message carries the session_handle corresponding to the MasterSession and the subgraph's graph_def. After the Worker receives the message and finishes registering/initializing the subgraph, it returns the subgraph's graph_handle to the Master.

For each session, after the master has placed every node on a device, it partitions the whole graph into many subgraphs. All the nodes in a subgraph are on the same worker, but potentially on many devices owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The master registers the subgraphs with the workers before running any steps. A successful registration returns a graph handle to be used in later RunGraph requests.

////////////////////////////////////////////////////////////////////////////////
//
// RegisterGraph method request/response messages
//
// For each session, after the master placed every node on a device,
// it partitions the whole graph into many subgraphs. All the nodes in
// a subgraph were in the same worker, but potentially on many devices
// owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The
// master registers subgraphs for a worker before running any steps. A
// successful registration returns a graph handle to be used in latter
// RunGraph requests.
//
////////////////////////////////////////////////////////////////////////////////

message RegisterGraphRequest {
  // Subgraphs are scoped within one session.
  string session_handle = 1;

  // Set to true if CreateWorkerSession was called for session_handle.
  bool create_worker_session_called = 6;

  // "graph_def" has the subgraph of nodes for this worker, with each node
  // having its device_name filled in.
  GraphDef graph_def = 2;

  // True iff the graph (before partitioning) contains control flow nodes.
  //
  // As of 01/11/2015, this is no longer set by clients.
  bool has_control_flow = 3 [deprecated = true];

  // Configuration options for the session in which this graph was created.
  GraphOptions graph_options = 4;

  // Field(s) used by TensorFlow Debugger (tfdbg).
  DebugOptions debug_options = 5;

  // If graph_def contains any collective ops this must be a positive
  // integer used to coordinate execution with other graphs.  All
  // graphs in a distributed execution with the same
  // collective_graph_key will coordinate to use the same step_id
  // concurrently so that BufRendezvous entries will make the correct
  // values accessible.
  int64 collective_graph_key = 7;

  // ConfigProto from the session in which this graph was created.
  // Contains additional parameters beyond graph_options, including
  // the name of the requested executor.
  ConfigProto config_proto = 8;
}

message RegisterGraphResponse {
  // If the registration succeeds, returns an opaque graph_handle to
  // the master. The master calls RunGraph with graph_handle to
  // compute different steps.
  string graph_handle = 1;
}

Figure 4: RegisterGraph

DeregisterGraph

When the computation graph is no longer needed (for example, the overall graph is rescheduled and its nodes are re-placed), the Master uses the graph's graph_handle to deregister it. If the Master restarts, the Worker automatically deregisters the corresponding graph_handle according to a TTL-based policy.

////////////////////////////////////////////////////////////////////////////////
//
// DeregisterGraph method request/response messages
//
// The master deregisters the given graph_handle when the graph is no
// longer needed (e.g., the overall graph is re-scheduled and nodes
// are re-placed).
//
// The worker deregisters a graph_handle automatically according to on
// a TTL-base policy in case of master restarts.
//
////////////////////////////////////////////////////////////////////////////////

message DeregisterGraphRequest {
  // The session_handle used when registering the graph. If session_handle is
  // empty, a single global namespace is used.
  string session_handle = 2;

  // Set to true if CreateWorkerSession was called for session_handle.
  bool create_worker_session_called = 3;

  // REQUIRED: graph_handle must be returned by a RegisterGraph call
  // to the same WorkerService.
  string graph_handle = 1;
}

message DeregisterGraphResponse {
  // TODO(mrry): Optionally add summary stats for the graph.
}

Figure 5: DeregisterGraph

RunGraph

The Master uses RunGraphRequest to execute all subgraphs registered under a graph_handle.

The Master generates a globally unique step_id to distinguish different runs of the graph computation. Subgraphs can communicate with each other (e.g. send/recv ops) using the step_id to distinguish tensors produced by different runs.

In the RunGraphRequest message, send denotes the tensors fed into the subgraph, and recv_key names the tensors to fetch from the subgraph. RunGraphResponse returns the list of tensors corresponding to the recv_keys. A hedged sketch of this feed/fetch plumbing follows.
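
Here is a hedged sketch of that plumbing (the rendezvous keys are placeholders; in the real master code they come from the feed_key / key_fetch maps filled in during partitioning, as shown in the Part struct earlier):

// Sketch: feed one tensor into a registered subgraph and name one tensor to
// fetch back. The keys below are placeholders for real rendezvous keys.
void AddFeedAndFetch(RunGraphRequest* req, const TensorProto& input) {
  NamedTensorProto* send = req->add_send();    // tensor fed into the subgraph
  send->set_name("feed_key_placeholder");
  *send->mutable_tensor() = input;
  req->add_recv_key("fetch_key_placeholder");  // tensor to fetch after the run
}

void ReadFetches(const RunGraphResponse& resp) {
  // The worker returns one entry in "recv" for each recv_key in the request.
  for (const NamedTensorProto& t : resp.recv()) {
    LOG(INFO) << "fetched tensor " << t.name();
  }
}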

Figure 6: RunGraph

////////////////////////////////////////////////////////////////////////////////
//
// RunGraph request / response messages
//
// The worker executes all subgraphs registered under graph_handle.
// RunGraph returns after the execution finishes or an error is
// encountered.
// A sequence of RunGraphRequests with is_partial may be sent to RunGraph for
// partial graph execution.
//
////////////////////////////////////////////////////////////////////////////////

// Options specific to the execution of a single step.
message ExecutorOpts {
  bool record_costs = 1;
  bool record_timeline = 3;
  bool record_partition_graphs = 4;
  bool report_tensor_allocations_upon_oom = 5;
}

message RunGraphRequest {
  // session_handle is the master-generated unique id for this session.
  // If session_handle is non-empty, it must be the same as used when
  // registering the graph. If it is empty, a single global namespace is used to
  // search for the graph_handle.
  string session_handle = 8;

  // Set to true if CreateWorkerSession was called for session_handle.
  bool create_worker_session_called = 10;

  // REQUIRED: graph_handle must be returned by a RegisterGraph call
  // to the same WorkerService.
  string graph_handle = 1;

  // A unique ID to distinguish different runs of the same graph.
  //
  // The master generates a global unique step_id to distinguish
  // different runs of the graph computation. Subgraphs communicate
  // (e.g., send/recv ops) with each other using step_id to
  // distinguish tensors generated by different runs.
  int64 step_id = 2;

  // Options for this step.
  ExecutorOpts exec_opts = 5;

  // Runs the graph.
  //
  // Sends the tensors in "send" into the graph before the run and
  // fetches the keys into RunGraphResponse.recv after the run.
  repeated NamedTensorProto send = 3;
  repeated string recv_key = 4;

  // True if the RunGraphRequest is a partial run request.
  bool is_partial = 6;
  // True if this is the last partial run request in a sequence of requests.
  bool is_last_partial_run = 7;

  // If true then some errors, e.g., execution errors that have long
  // error messages, may return an OK RunGraphResponse with the actual
  // error saved in the status_code/status_error_message fields of the
  // response body. This is a workaround since the RPC subsystem may
  // truncate long metadata messages.
  bool store_errors_in_response_body = 9;

  // Unique identifier for this request. Every RunGraphRequest must have a
  // unique request_id, and retried RunGraphRequests must have the same
  // request_id. If request_id is zero, retry detection is disabled.
  //
  // Retried RunGraphRequests are problematic because they may issue a
  // RecvTensor that will have no corresponding sender and will wait forever.
  // Workers use request_ids to reject retried RunGraph requests instead of
  // waiting forever.
  int64 request_id = 11;

  // Next: 12
}

message RunGraphResponse {
  // A list of tensors corresponding to those requested by
  // RunGraphRequest.recv_key.
  repeated NamedTensorProto recv = 1;

  // If the request asked for execution stats, the cost graph, or the partition
  // graphs, these are returned here.
  // TODO(suharshs): Package these in a RunMetadata instead.
  StepStats step_stats = 2;
  CostGraphDef cost_graph = 3;
  repeated GraphDef partition_graph = 4;

  // If store_errors_in_response_body is true in the request, then
  // optionally the server may return an OK status for the RPC and
  // fill the true status into the fields below, to allow for messages
  // that are too long to fit in metadata.
  error.Code status_code = 5;
  string status_error_message = 6;
}

RecvTensor

During execution, two Workers may exchange data. The producer simply puts the prepared tensor into the rendezvous, while the consumer actively issues a RecvTensorRequest; in that request, step_id identifies which step it is, and rendezvous_key identifies the channel from which the tensor is to be received.

A single RecvTensor request retrieves one tensor from the channel, and multiple tensors can be sent and received over the same channel through multiple RecvTensor requests. In the end, the producer's tensor is returned to the consumer via RecvTensorResponse.

Figure 7: RecvTensor

////////////////////////////////////////////////////////////////////////////////
//
// RecvTensor method request/response messages
//
////////////////////////////////////////////////////////////////////////////////

message RecvTensorRequest {
  // The step in which the tensor will be produced.
  //
  // REQUIRED: This must eventually correspond to the step_id passed
  // into a RunGraph call on the same WorkerService.
  int64 step_id = 1;

  // A key identifying the channel to receive tensors from. A RecvTensor request
  // retrieves one tensor from the channel, but multiple tensors can be sent and
  // received over the same channel with multiple RecvTensor requests. See
  // rendezvous.h for details.
  string rendezvous_key = 2;

  // If true, use an out-of-band DMA mechanism to transfer the
  // received tensor.
  bool dma_ok = 3;

  // Optional information on client-side device locality.
  DeviceLocality client_locality = 4;

  // Optional information on server-side device locality.
  DeviceLocality server_locality = 5;

  // Optional information needed by the RPC subsystem.
  google.protobuf.Any transport_options = 6;

  // Unique identifier for this request. Every RecvTensorRequest must have a
  // unique request_id, and retried RecvTensorRequests must have the same
  // request_id. If request_id is zero, retry detection and response cache
  // are disabled.
  //
  // Retried RecvTensorRequests are problematic because a RecvTensor with no
  // corresponding sender will wait forever, and the tensor may have been
  // delivered to a previous retry. Workers use request_ids to reject retried
  // RecvTensor requests instead of waiting forever.
  int64 request_id = 7;
}

message RecvTensorResponse {
  // The tensor as a proto.
  TensorProto tensor = 1;

  // If true, this tensor was the output of a dead node, and the
  // content is invalid.
  bool is_dead = 2;

  // The time at which tensor was available and started to be returned.
  int64 send_start_micros = 3;

  // Optional additional information about how to receive the tensor,
  // e.g. in the event that RecvTensorRequest.dma_ok was true.
  google.protobuf.Any transport_options = 4;

  // Whether the receiver should send a MarkRecvFinishedRequest to the sender
  // to ack the message.
  bool require_ack = 5;
}

4. Worker

The Worker class mainly provides the WorkerEnv and the PartialRunMgr. It can be subclassed to provide specialized implementations of particular methods for different transport mechanisms; for example, GrpcWorker specializes the RecvTensorAsync method to support a more efficient gRPC data structure for handling large binary data.

class Worker : public WorkerInterface {
 protected:
  WorkerEnv* const env_;  // Not owned.
  RecentRequestIds recent_request_ids_;

 private:
  PartialRunMgr partial_run_mgr_;

  CancellationManager cancellation_manager_;

  TF_DISALLOW_COPY_AND_ASSIGN(Worker);
};

Let's look at one method as an example; we will cover the others as we encounter them later.

void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(Status::OK());
}

5. GrpcWorker

GrpcWorker is the remote Worker that GrpcRemoteWorker talks to; it is also the object called by GrpcWorkerService, and it implements the business logic. Its definition is below, where we can see that it overrides several methods.

class GrpcWorker : public Worker {
 public:
  GrpcWorker(WorkerEnv* env, const ConfigProto& config);

  // Specialized version of RecvTensor for gRPC, which avoids a copy.
  virtual void GrpcRecvTensorAsync(CallOptions* opts,
                                   const RecvTensorRequest* request,
                                   ::grpc::ByteBuffer* response,
                                   StatusCallback done);

  void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
                    StatusCallback done) override;

  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                    RecvBufResponse* response, StatusCallback done) override;

  void CleanupGraphAsync(const CleanupGraphRequest* request,
                         CleanupGraphResponse* response,
                         StatusCallback done) override;

  WorkerEnv* env();

  void EnableResponseCache();

  void RemoveCacheEntryForId(int64 request_id);

 private:
  std::unique_ptr<GrpcResponseCache> response_cache_;
  const int32 recv_buf_max_chunk_;
};

This concludes our introduction to the static structure of the Worker; the Worker's concrete functionality will be covered later, in the Session part of this series.
