[Source Code Analysis] TensorFlow Distributed Environment (8) --- Communication Mechanism

Posted by 羅西的思考 on 2022-04-06


After the computation graph is partitioned across devices, data dependencies may exist between the per-device PartitionGraphs, so TF inserts Send/Recv nodes between them to carry out the data exchange. In distributed mode, Send/Recv exchange data through RpcRemoteRendezvous, so we first need to look at TF's data exchange mechanism: Rendezvous.

So far in distributed machine learning we have run into many different Rendezvous, mostly in elasticity- and communication-related components. Their precise meanings differ slightly, but the basic idea is always the original sense of the French word: a meeting, gathering, or appointment. TensorFlow's Rendezvous is its communication component and exchange mechanism for message transfer.

This article again draws heavily on two experts:

[TensorFlow Internals](https://github.com/horance-liu/tensorflow-internals): although it does not analyze the latest code, anyone interested in TF's internal implementation mechanisms should read it; it is well worth the time.
https://home.cnblogs.com/u/deep-learning-stacks/ 西門宇少: not only TensorFlow; his official account also covers many other cutting-edge areas of the industry.

The other articles in this series are:

[Translation] TensorFlow Distributed Papers: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed Papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

[Source Code Analysis] TensorFlow Distributed Environment (5) --- Session

[Source Code Analysis] TensorFlow Distributed Environment (7) --- Worker Dynamic Logic

1. Mechanism

In distributed mode, cross-device edges are split, and a Send node and a Recv node are inserted at the sending end and the receiving end of the edge respectively.

  • Send and Recv nodes within the same process exchange data through IntraProcessRendezvous.
  • Send and Recv nodes in different processes exchange data through GrpcRemoteRendezvous.

Suppose Worker 0 has two GPUs. After the Send and Recv nodes are inserted, the effect is as follows: the edge from Worker 1 to Worker 0 represents inter-process data exchange through GrpcRemoteRendezvous, the dashed arrows between Worker 0's two GPUs represent intra-process data exchange through IntraProcessRendezvous, and the solid arrows between workers represent data exchange over RPC.

When a step executes and two workers need to exchange data:

  • The producer (Sender) first generates the tensor and puts it into its local table.
  • The consumer (Receiver) sends a RecvTensorRequest message to the producer; the message carries the pair (step_id, rendezvous_key).
  • The worker on the producer side then fetches the corresponding tensor from its local table and returns it via a RecvTensorResponse.

The send/recv data transfer itself is carried out through a derived class of WorkerInterface, and WorkerInterface is built on top of the underlying gRPC communication library.

Figure 1: Send/Receive

1.1 Message Identifier

When we studied PyTorch's distributed support, we saw that every distributed communication needs a globally unique identifier, for example:

  • autograd_message_id is used to identify a pair of send/recv autograd functions. Each send-recv pair is assigned a globally unique autograd_message_id, which is useful for looking up the corresponding function on a remote node during the backward pass.
  • The container is also responsible for maintaining globally unique message ids that associate send/receive autograd function pairs. The format is a 64-bit integer: the upper 16 bits are the worker id and the lower 48 bits are an auto-incrementing integer within the worker.

Similarly, TF needs to assign a unique identifier to every Send/Recv pair, so that when multiple groups of messages are sent in parallel, the messages do not get mixed up. This identifier is the ParsedKey.

1.1.1 Definition

Its definition is as follows:

  • src_device: the sending device.
  • src: the same information as src_device, but represented as a structure.
  • src_incarnation: used for debugging; when a worker restarts this value changes, which makes it possible to distinguish the worker that previously died.
  • dst_device: the receiving device.
  • dst: the same information as dst_device, but represented as a structure.
  • edge_name: the name of the edge; it can be a tensor name or a string with some special meaning.

// Parses the key constructed by CreateKey and parse src/dst device
// names into structures respectively.
struct ParsedKey {
  StringPiece src_device;
  DeviceNameUtils::ParsedName src;
  uint64 src_incarnation = 0;
  StringPiece dst_device;
  DeviceNameUtils::ParsedName dst;
  StringPiece edge_name;

  ParsedKey() {}
  ParsedKey(const ParsedKey& b) { *this = b; }

  ParsedKey& operator=(const ParsedKey& b);
  StringPiece FullKey() const { return buf_; }

 private:
  friend class Rendezvous;
  friend class SendOp;
  friend class RecvOp;
  std::string buf_;
};

1.1.2 Creation

The generated string key has the following format:

src_device ; HexString(src_incarnation) ; dst_device ; name ; frame_iter.frame_id : frame_iter.iter_id

The code is as follows:

/*  static */
string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
                             const string& dst_device, const string& name,
                             const FrameAndIter& frame_iter) {
  // NOTE: ';' is not used in the device name's job name.
  //
  // We include both sender and receiver in the key to facilitate
  // debugging. For correctness, we only need to encode the receiver.
  //
  // "src_incarnation" is used to distinguish a worker when it
  // restarts.
  char buf[strings::kFastToBufferSize];
  return strings::StrCat(
      src_device, ";", strings::Uint64ToHexString(src_incarnation, buf), ";",
      dst_device, ";", name, ";", frame_iter.frame_id, ":", frame_iter.iter_id);
}

The system then uses ParseKey to parse the key and produce a ParsedKey. ParseKey maps the first four fields of the input key and discards the fifth field, frame_iter.frame_id : frame_iter.iter_id. The fields map to their literal meanings, except that edge_name corresponds to name. A short usage sketch follows the code below.

/* static */
Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) {
  if (key.data() == out->buf_.data()) {
    // Caller used our buf_ string directly, so we don't need to copy.  (The
    // SendOp and RecvOp implementations do this, for example).
    DCHECK_EQ(key.size(), out->buf_.size());
  } else {
    // Make a copy that our StringPieces can point at a copy that will persist
    // for the lifetime of the ParsedKey object.
    out->buf_.assign(key.data(), key.size());
  }
  StringPiece s(out->buf_);
  StringPiece parts[5];
  for (int i = 0; i < 5; i++) {
    parts[i] = ConsumeNextPart(&s, ';');
  }
  if (s.empty() &&          // Consumed the whole string
      !parts[4].empty() &&  // Exactly five parts
      DeviceNameUtils::ParseFullName(parts[0], &out->src) &&
      strings::HexStringToUint64(parts[1], &out->src_incarnation) &&
      DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
      !parts[3].empty()) {
    out->src_device = StringPiece(parts[0].data(), parts[0].size());
    out->dst_device = StringPiece(parts[2].data(), parts[2].size());
    out->edge_name = StringPiece(parts[3].data(), parts[3].size());
    return Status::OK();
  }
  return errors::InvalidArgument("Invalid  rendezvous key: ", key);
}
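
To make the key format concrete, here is a small usage sketch that builds a key with CreateKey and parses it back with ParseKey. This is not taken from the TF code base: the device names, the edge name, the frame/iter values and the helper function name are invented purely for illustration.

// Hypothetical round trip of a rendezvous key (assumed to live inside
// namespace tensorflow, like the excerpts above).
void KeyRoundTripSketch() {
  const std::string key = Rendezvous::CreateKey(
      "/job:worker/replica:0/task:0/device:CPU:0", /*src_incarnation=*/1,
      "/job:worker/replica:0/task:1/device:GPU:0", "edge_1_MatMul",
      FrameAndIter(0, 0));
  // key now looks like:
  //   <src_device>;<src_incarnation in hex>;<dst_device>;edge_1_MatMul;0:0

  Rendezvous::ParsedKey parsed;
  TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
  // parsed.src_device == "/job:worker/replica:0/task:0/device:CPU:0"
  // parsed.dst_device == "/job:worker/replica:0/task:1/device:GPU:0"
  // parsed.edge_name  == "edge_1_MatMul"
}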

1.2 Rendezvous

Rendezvous is an abstraction for passing tensors from a producer to a consumer. A rendezvous is a table of channels. Each channel is keyed by a rendezvous key. The key encodes a <producer, consumer> pair, where the producer and the consumer are TensorFlow devices.

The producer calls Send() to send a tensor over a named channel. The consumer calls Recv() to receive a tensor from a specified channel. A sequence of tensors can be passed from the producer to the consumer, and the consumer receives them in the order the producer sent them.

The consumer can safely request the tensor before or after it has been produced. The consumer may either make a blocking call or provide a callback: in both cases it receives the tensor as soon as it is available. The producer never blocks.
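
To illustrate these semantics, here is a minimal local-only sketch. It is not from the article's source excerpts; it assumes the NewLocalRendezvous() factory and the synchronous Recv wrapper that appear later in this section, and a key built as in 1.1.2.

// Sketch: producer/consumer semantics on a process-local rendezvous
// (assumed to live inside namespace tensorflow).
void LocalSendRecvSketch(const std::string& key) {
  Rendezvous* rendez = NewLocalRendezvous();

  Rendezvous::ParsedKey parsed;
  TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));

  Rendezvous::Args args;
  Tensor produced(DT_FLOAT, TensorShape({2}));
  // Send never blocks: the value is simply buffered under "parsed".
  TF_CHECK_OK(rendez->Send(parsed, args, produced, /*is_dead=*/false));

  // Recv could equally have been issued before the Send; here it just picks
  // up the buffered value.
  Tensor received;
  bool is_dead = false;
  TF_CHECK_OK(rendez->Recv(parsed, args, &received, &is_dead));

  rendez->Unref();  // Rendezvous is reference counted.
}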

1.2.1 Interface Class

RendezvousInterface is the interface class; it defines the virtual functions. ParsedKey is also defined here (we omit that part of the code).

class RendezvousInterface {
 public:
  struct Args {
    DeviceContext* device_context = nullptr;
    AllocatorAttributes alloc_attrs;
    CancellationManager* cancellation_manager = nullptr;  // not owned.
  };

  // The caller is a tensor producer and it sends a message (a tensor
  // "val" and a bool "is_dead") under the given "key".
  //
  // {val, is_dead} is bundled as a message sent and received.
  // Typically, is_dead is set by some control flow nodes
  // (e.g., a not-taken branch).  args is passed by Send to the
  // Recv function to communicate any information that the Recv
  // function might need.  This is typically only necessary for
  // Send/Recv on the same worker.
  //
  // Send() never blocks.
  virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val,
                      const bool is_dead) = 0;

  // Callback provided by a tensor consumer waiting on the rendezvous.
  // It will be invoked when the tensor is available, or when a non-OK
  // status arises in the production of that tensor.  It also gets
  // two Rendezvous::Args, one provided by the sender, the other by the
  // receiver, which may be needed when a non-CPU device is in use
  // by either side.
  typedef std::function<void(const Status&, const Args&, const Args&,
                             const Tensor&, const bool)>
      DoneCallback;

  virtual void RecvAsync(const ParsedKey& key, const Args& args,
                         DoneCallback done) = 0;

  // Synchronous wrapper for RecvAsync.
  Status Recv(const ParsedKey& key, const Args& args, Tensor* val,
              bool* is_dead, int64_t timeout_ms);
  Status Recv(const ParsedKey& key, const Args& args, Tensor* val,
              bool* is_dead);

  // Aborts all pending and future Send/Recv with the given "status".
  // StartAbort() does not wait for ongoing calls to finish.
  // REQUIRES: !status.ok()
  virtual void StartAbort(const Status& status) = 0;

 protected:
  virtual ~RendezvousInterface();

  virtual bool is_cross_process() { return false; }
  friend class ProcessFunctionLibraryRuntime;
};

1.2.2 Basic Implementation: Rendezvous

The Rendezvous class provides the most basic implementations of Send, Recv and RecvAsync, and it also provides the ParseKey functionality.

// A reference-counted implementation of RendezvousInterface.
//
// This class is used in cases where a rendezvous may be shared between multiple
// threads with no clear owner.
class Rendezvous : public RendezvousInterface, public core::RefCounted {
 public:
  class Factory {
   public:
    // Default to a factory that evaluates to false.
    Factory() : valid_(false) {}

    Factory(std::function<Status(const int64_t, const DeviceMgr*, Rendezvous**)>
                create_fn,
            std::function<Status(const int64_t)> cleanup_fn)
        : valid_(true),
          create_fn_(std::move(create_fn)),
          cleanup_fn_(std::move(cleanup_fn)) {}

    // If no clean up fn is provided, just put in a dummy.
    // For backwards compatibility.
    explicit Factory(
        std::function<Status(const int64_t, const DeviceMgr*, Rendezvous**)>
            create_fn)
        : valid_(true),
          create_fn_(std::move(create_fn)),
          cleanup_fn_([](const int64_t step_id) { return Status::OK(); }) {}

    explicit operator bool() const { return valid_; }

    Status operator()(const int64_t step_id, const DeviceMgr* device_mgr,
                      Rendezvous** rendez) const {
      return create_fn_(step_id, device_mgr, rendez);
    }

    Status CleanUp(const int64_t step_id) const { return cleanup_fn_(step_id); }

   private:
    bool valid_;
    std::function<Status(const int64_t, const DeviceMgr*, Rendezvous**)>
        create_fn_;
    std::function<Status(const int64_t)> cleanup_fn_;
  };

  // Constructs a rendezvous key for the tensor of "name" sent from
  // "src_device" to "dst_device". The tensor is generated in the frame
  // and iteration specified by "frame_iter".
  static std::string CreateKey(const std::string& src_device,
                               uint64 src_incarnation,
                               const std::string& dst_device,
                               const std::string& name,
                               const FrameAndIter& frame_iter);

  static Status ParseKey(StringPiece key, ParsedKey* out);
};

1.2.3 Cross-Process RemoteRendezvous

RemoteRendezvous inherits from Rendezvous and adds only one pure virtual method, Initialize. Every derived class that does cross-process communication must override this method, because the initialization needs a Session to complete.

RemoteRendezvous can handle the case where either the producer or the consumer is in a remote process, and it adds the functionality to coordinate with remote workers. RemoteRendezvous follows a two-phase initialization strategy: first the objects are constructed, and eventually they are initialized. Clients of RendezvousMgrInterface must guarantee that Initialize is eventually called on the returned RemoteRendezvous.

// RemoteRendezvous follow a 2-part initialization. First the objects are
// constructed. Eventually, they will be initialized. Clients of the
// RendezvousMgrInterface must guarantee to call Initialize on the returned
// RemoteRendezvous eventually.
//
// Partially initialized RemoteRendezvous must respect the Rendezvous interface
// (i.e. Send() must never block), however implementations are not expected to
// actually perform the underlying operations until after the RemoteRendezvous
// has been Initialize'd.
class RemoteRendezvous : public Rendezvous {
 public:
  // Fully construct the RemoteRendezvous.
  virtual Status Initialize(WorkerSession* session) = 0;

 protected:
  bool is_cross_process() override { return true; }
};

1.2.4 BaseRemoteRendezvous

Because cross-process communication can use different protocols, each cross-process Rendezvous has to be implemented according to its own protocol. TF therefore inserts an intermediate layer, BaseRemoteRendezvous, between RemoteRendezvous and the concrete specialized Rendezvous classes. This class bridges the two levels: it provides the common Send and Recv methods, so code is reused as much as possible.

The main member variable of BaseRemoteRendezvous is Rendezvous* local_. The code makes heavy use of BaseRecvTensorCall as a parameter; BaseRecvTensorCall is the abstraction of one communication.

// RemoteRendezvous is a Rendezvous which can handle either
// the producer or consumer being in a remote process.
//
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous().  This class just adds
// functionality to coordinate with remote workers.
class BaseRemoteRendezvous : public RemoteRendezvous {
 public:
  BaseRemoteRendezvous(const WorkerEnv* env, int64_t step_id);

  // Upgrades the BaseRemoteRendezvous to full initialization.
  Status Initialize(WorkerSession* session) override;

  // Forwards to local_, where the Tensor "val" will be buffered and
  // any waiting callback stored.
  Status Send(const ParsedKey& key, const Rendezvous::Args& args,
              const Tensor& val, const bool is_dead) override;

  // This method is called only by the RecvOp.  It tests to see
  // whether the value will be produced by a local or remote device
  // and handles accordingly.  In the local case it forwards to
  // local_, in the remote case it initiates an RPC request.
  void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
                 DoneCallback done) override;

  void StartAbort(const Status& status) override;

  // This method is called only by the local Worker, forwarded through
  // the same method on RendezvousMgr.  This occurs when the Worker
  // has received a RecvTensor request, either locally or over the
  // network.  In either case it needs to retrieve a locally buffered
  // value from local_, and give it to its caller.
  //
  // Runs "done" as soon as the tensor for "parsed" is available or an error
  // is detected.
  //
  // REQUIRES: "parsed" is one that will be Saved into the local rendezvous.
  void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done);

 protected:
  virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                                   const Rendezvous::Args& args,
                                   DoneCallback done) = 0;

  // Returns true if "src" and "dst" are located in the same worker,
  // and hence may use a local rendezvous.
  virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
                            DeviceNameUtils::ParsedName dst);

  // If aborted, aborts "call". Otherwise, adds "call" into active_.
  void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args);

  // Removes "call" from active_ if "call" is in active_.
  void DeregisterCall(BaseRecvTensorCall* call);

  WorkerSession* session();

  bool is_initialized();

  ~BaseRemoteRendezvous() override;

  const WorkerEnv* const env_;  // Not owned.
  const int64_t step_id_;

 private:
  Rendezvous* local_;  // Owns a Ref on this object.

  mutable mutex mu_;

  // Status given by StartAbort() if any.
  Status status_ TF_GUARDED_BY(mu_);

  WorkerSession* session_ TF_GUARDED_BY(mu_);  // Not owned.

  // Data structures to handle calls when partially initialized.
  struct DeferredCall {
    const ParsedKey parsed;
    DoneCallback done;

    DeferredCall(const ParsedKey& parsed, DoneCallback done);
  };
  std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(mu_);

  typedef std::function<void()> InactiveCallback;

  std::unordered_map<BaseRecvTensorCall*, InactiveCallback> active_
      TF_GUARDED_BY(mu_);

  bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) {
    return session_ != nullptr;
  }

  // If "is_src" is true, checks that the rendezvous key "parsed"'s
  // source is in this process. If "is_src" is false, checks that the
  // rendezvous key "parsed"'s destination is in this process.
  Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src);

  // Callback handling the case when a rendezvous has been
  // accomplished in local_ and the consumer is local to this process.
  // Tensor "in" will be copied into "out". The key "parsed" encodes
  // the src and dst devices.
  void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
                          const Rendezvous::Args& in_args,
                          const Rendezvous::Args& out_args, const Tensor& in,
                          Tensor* out, StatusCallback done);

  // Must be called only if fully initialized.
  void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);

  TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
};

class BaseRecvTensorCall {
 public:
  BaseRecvTensorCall() {}
  virtual ~BaseRecvTensorCall() {}
  virtual void Start(std::function<void()> recv_done) = 0;
  virtual void StartAbort(const Status& s) = 0;
  virtual Status status() const = 0;
 private:
  TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
};

At construction time a local Rendezvous is created; this local Rendezvous handles the basic work.

BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
                                           int64_t step_id)
    : env_(env),
      step_id_(step_id),
      local_(NewLocalRendezvous()),
      session_(nullptr) {}

Rendezvous* NewLocalRendezvous() { return new LocalRendezvousWrapper; }

LocalRendezvousWrapper is defined as follows:

class LocalRendezvousWrapper : public Rendezvous {
 public:
  LocalRendezvousWrapper() : impl_(this) {}

  Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
              const bool is_dead) override {
    return impl_.Send(key, send_args, val, is_dead);
  }

  void RecvAsync(const ParsedKey& key, const Args& recv_args,
                 DoneCallback done) override {
    impl_.RecvAsync(key, recv_args, std::move(done));
  }

  void StartAbort(const Status& status) override { impl_.StartAbort(status); }

 private:
  LocalRendezvous impl_;

  TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousWrapper);
};

Next let's look at the BaseRemoteRendezvous initialization method, which does the basic configuration, such as setting the session.

Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
  std::vector<DeferredCall> deferred_calls;
  {
    mutex_lock l(mu_);
    if (session_ != nullptr) {
      if (session_->worker_name() == session->worker_name()) {
        return Status::OK();
      }
      Status s = errors::Internal(
          "Double init! Worker names would have changed from: ",
          session_->worker_name(), " -> ", session->worker_name());
      return s;
    }
    session_ = session;
    std::swap(deferred_calls, deferred_calls_);
  }
  for (auto& call : deferred_calls) {
    RecvLocalAsyncInternal(call.parsed, std::move(call.done));
  }
  return Status::OK();
}

1.2.5 RpcRemoteRendezvous

RpcRemoteRendezvous is the gRPC implementation of RemoteRendezvous.

class RpcRemoteRendezvous : public BaseRemoteRendezvous {
 public:
  RpcRemoteRendezvous(const WorkerEnv* env, int64_t step_id)
      : BaseRemoteRendezvous(env, step_id) {}

 protected:
  void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                           const Rendezvous::Args& args,
                           DoneCallback done) override;

 private:
  ~RpcRemoteRendezvous() override {}

  TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
};

The derived class corresponding to BaseRecvTensorCall is RpcRecvTensorCall.

// Used only to retrieve tensors from remote processes.
class RpcRecvTensorCall : public BaseRecvTensorCall {
 public:
  RpcRecvTensorCall() : wi_(nullptr), dst_device_(nullptr) {}

  void Init(WorkerInterface* wi, int64_t step_id, StringPiece key,
            AllocatorAttributes alloc_attrs, Device* dst_device,
            const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
    wi_ = wi;
    alloc_attrs_ = alloc_attrs;
    dst_device_ = dst_device;
    recv_args_ = recv_args;
    done_ = std::move(done);
    req_.set_step_id(step_id);
    req_.set_rendezvous_key(key.data(), key.size());
    req_.set_request_id(GetUniqueRequestId());
  }

  void Reset() {
    // The RpcRemoteRendezvous using this object is responsible for calling
    // ReleaseWorker() before Reset().

    alloc_attrs_ = AllocatorAttributes();
    dst_device_ = nullptr;
    // We don't clear opts_ and assume that Init will set up the state for
    // opts_ appropriately.
    req_.Clear();
    resp_.Clear();
    {
      mutex_lock l(mu_);
      status_ = Status::OK();
    }
    done_ = nullptr;
  }

  ~RpcRecvTensorCall() override {
    // Since only the RpcRecvTensorFreeList will delete an
    // RpcRecvTensorCall, we require that ReleaseWorker() has been called before
    // the user releases a Call object to the free list.
    CHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
        << "Leaking WorkerInterface in RpcRecvTensorCall destructor.";
  }

  void Start(std::function<void()> recv_done) override {
    StartRTCall(std::move(recv_done));
  }

  void StartAbort(const Status& s) override {
    {
      mutex_lock l(mu_);
      status_.Update(s);
    }
    opts_.StartCancel();
  }

  Status status() const override {
    mutex_lock l(mu_);
    return status_;
  }

  void ReleaseWorker(WorkerCacheInterface* worker_cache) {
    DCHECK_NE(static_cast<WorkerInterface*>(nullptr), wi_)
        << "RpcRecvTensorCall::ReleaseWorker() called twice.";
    worker_cache->ReleaseWorker(src_worker_, wi_);
    wi_ = nullptr;
  }

  const Tensor& tensor() const { return resp_.tensor(); }

  bool is_dead() const { return resp_.metadata().is_dead(); }

  Device* dst_device() const { return dst_device_; }
  const Rendezvous::Args& recv_args() const { return recv_args_; }
  const Rendezvous::DoneCallback& done() const { return done_; }

 private:
  friend class RpcRemoteRendezvous;

  // Start the main RecvTensor call, checking for an async abort.
  void StartRTCall(std::function<void()> recv_done) {
    resp_.InitAlloc(dst_device_, alloc_attrs_);
    auto abort_checked = std::make_shared<Notification>();
    auto cb = [this, abort_checked,
               recv_done = std::move(recv_done)](const Status& s) {
      // Make sure the Rendezvous abort checking is finished before running the
      // callback, which might destroy the current call object.
      abort_checked->WaitForNotification();
      if (!s.ok()) {
        mutex_lock l(mu_);
        status_.Update(s);
      }
      recv_done();
    };
    wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));

    // NOTE: Check if the rendezvous was aborted after sending out the RPC. The
    // ordering is important because StartAbort could be called right before
    // the RecvTensorAsync request registers its RPC cancellation to opts_.
    // In that case, the previous StartAbort would not trigger the
    // cancellation of this call.
    Status s;
    {
      mutex_lock l(mu_);
      s = status_;
    }
    if (!s.ok()) {
      opts_.StartCancel();
    }
    // Notify that the abort check has finished.
    abort_checked->Notify();
  }

  string src_worker_;
  string src_rel_device_;
  WorkerInterface* wi_;  // Not owned.
  AllocatorAttributes alloc_attrs_;
  Device* dst_device_;
  CallOptions opts_;
  RecvTensorRequest req_;
  TensorResponse resp_;
  Rendezvous::Args recv_args_;
  Rendezvous::DoneCallback done_;

  mutable mutex mu_;
  Status status_ TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
};

The logical relationships so far are as follows:

Figure 2: Rendezvous logical relationships

1.3 Management Class

RendezvousMgr is mainly responsible for creating and destroying RemoteRendezvous. It keeps track of a set of local rendezvous instances; all tensors sent by this worker are buffered in the RendezvousMgr until they are received. Each globally unique "step_id" corresponds to one local rendezvous instance managed by the RendezvousMgr.

1.3.1 Interface

RendezvousMgrInterface is the interface class.

// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received.  Each global unique "step_id"
// corresponds to one local rendezvous instance managed by a
// RendezvousMgr.
//
// E.g.,
//   Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
//   fork execution of an graph executor using "rendez"  on thread 1;
//   fork execution of another graph executor using "rendez" on thread 2;
//   ...
//   join threads 1 and 2;
//
// In the example above, execution in thread 1 and 2 communicates with
// each other by send/recv operations through the "rend".
//
// Tensors sent and recved through rendezvous managed by this
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
class RendezvousMgrInterface {
 public:
  RendezvousMgrInterface() {}
  virtual ~RendezvousMgrInterface() {}

  // Returns Rendezvous supporting send and recv among workers in the
  // "step_id".  The caller takes ownership of one reference on the
  // returned Rendezvous instance.
  //
  // Note: the caller must guarantee to eventually call Initialize on the
  // returned RemoteRendezvous
  virtual RemoteRendezvous* Find(int64_t step_id) = 0;

  // Finds the local rendezvous instance for the "step_id".  Runs
  // "done" when the tensor for "key" is produced or an error occurs.
  //
  // This method is used by the rpc handler of RecvTensor.
  virtual void RecvLocalAsync(int64_t step_id,
                              const Rendezvous::ParsedKey& parsed,
                              Rendezvous::DoneCallback done) = 0;

  // Synchronous wrapper for RecvLocalAsync.
  virtual Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed,
                           Tensor* val, bool* is_dead) = 0;

  // Removes rendezvous for "step_id".
  //
  // TODO(zhifengc): Have a background thread in worker that
  // periodically calls CleanupAll().
  virtual void Cleanup(int64_t step_id) = 0;
};

1.3.2 BaseRendezvousMgr

BaseRendezvousMgr implements the basic functionality, such as looking up a Rendezvous by step_id; a brief usage sketch follows the class definition below.

class BaseRendezvousMgr : public RendezvousMgrInterface {
 public:
  explicit BaseRendezvousMgr(const WorkerEnv* worker_env);

  ~BaseRendezvousMgr() override;

  // Returns Rendezvous supporting send and recv among workers in the
  // "step_id".  The caller takes ownership of one reference on the
  // returned Rendezvous instance.
  //
  // Note: the caller must guarantee to eventually call Initialize on the
  // returned RemoteRendezvous
  RemoteRendezvous* Find(int64_t step_id) override;

  // Finds the local rendezvous instance for the "step_id".  Runs
  // "done" when the tensor for "key" is produced or an error occurs.
  //
  // This method is used by the rpc handler of RecvTensor.
  void RecvLocalAsync(int64_t step_id, const Rendezvous::ParsedKey& parsed,
                      Rendezvous::DoneCallback done) override;

  // Synchronous wrapper for RecvLocalAsync.
  Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed,
                   Tensor* val, bool* is_dead) override;

  // Removes rendezvous for "step_id".
  void Cleanup(int64_t step_id) override;

 protected:
  virtual BaseRemoteRendezvous* Create(int64_t step_id,
                                       const WorkerEnv* worker_env) = 0;

 private:
  // Maps step_id to rendezvous.
  typedef absl::flat_hash_map<int64_t, BaseRemoteRendezvous*> Table;

  // Not owned.
  const WorkerEnv* const worker_env_;

  mutex mu_;
  Table table_ TF_GUARDED_BY(mu_);

  BaseRemoteRendezvous* FindOrCreate(int64_t step_id);

  TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
};
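
Pulling these interfaces together, the typical per-step lifecycle looks roughly like the sketch below. This is only an illustration assembled from the declarations above (error handling is omitted, and the worker_env / worker_session / step_id parameters are assumptions), not actual TF code:

// Sketch: per-step use of the RendezvousMgr (assumed to live inside
// namespace tensorflow).
void StepRendezvousLifecycleSketch(const WorkerEnv* worker_env,
                                   WorkerSession* worker_session,
                                   int64_t step_id) {
  // Find() hands back one owned reference to the per-step rendezvous.
  RemoteRendezvous* rendez = worker_env->rendezvous_mgr->Find(step_id);
  // Second phase of the 2-part initialization.
  TF_CHECK_OK(rendez->Initialize(worker_session));

  // ... during the step, executors call rendez->Send(...) / rendez->RecvAsync(...),
  // and incoming RecvTensor RPCs are served through RecvLocalAsync ...

  rendez->Unref();                               // drop the caller's reference
  worker_env->rendezvous_mgr->Cleanup(step_id);  // remove the per-step instance
}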

2. Usage

When we walked through execution earlier, we already saw Rendezvous being used in a few places; let's now pick a few scenarios and analyze them.

2.1 Worker Receive

Let's first look at the worker on the receiving side.

2.1.1 DoRunGraph

The Worker receives tensors in its DoRunGraph method.

void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {

  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector,
       device_profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          // Receive the tensors
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }
      });
}

The RecvOutputs method is shown below: it looks up a Rendezvous by step_id and then receives the messages.

Status GraphMgr::RecvOutputs(const int64_t step_id, NamedTensors* out) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
  rendezvous->Unref();
  size_t output_size = 0;
  for (auto& p : *out) {
    output_size += p.second.AllocatedBytes();
  }
  return s;
}

The flow is shown in the figure below, in the numbered order; step 3 returns a Rendezvous, and RecvOutputsFromRendezvous is a free (global) function; a rough sketch of it follows.
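
RecvOutputsFromRendezvous itself is not listed in this article. Roughly speaking, such a helper parses each key and performs a blocking Recv for it; the following is only a hedged sketch of that idea (not the actual TF source), using the synchronous Recv wrapper from section 1.2.1:

// Rough sketch of a blocking "receive all outputs" helper (assumed to live
// inside namespace tensorflow); NamedTensors is treated here as a plain map
// from rendezvous key to Tensor.
Status RecvOutputsFromRendezvousSketch(RendezvousInterface* rendezvous,
                                       std::map<std::string, Tensor>* out,
                                       const Rendezvous::Args& args) {
  for (auto& p : *out) {
    const std::string& key = p.first;
    Rendezvous::ParsedKey parsed;
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
    bool is_dead = false;
    TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, &p.second, &is_dead));
    if (is_dead) {
      return errors::InvalidArgument("The tensor returned for ", key,
                                     " was not valid.");
    }
  }
  return Status::OK();
}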

2.1.2 DoPartialRunGraph

DoPartialRunGraph calls RecvOutputsAsync to do the receiving.

void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  const string& graph_handle = request->graph_handle();

  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };

  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before we start doing anything, we set the RPC cancellation.
  opts->SetCancelCallback([this, cm, step_id]() {
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, the request will need to start the
  // executors.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, session.get(), request->exec_opts(),
        nullptr /* collector */, nullptr /* response */, cm, in,
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
  }

  // This calls RecvOutputsAsync to receive the tensors
  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the resp.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}

RecvOutputsAsync in turn calls RecvOutputsFromRendezvousAsync.

void GraphMgr::RecvOutputsAsync(const int64_t step_id, NamedTensors* out,
                                StatusCallback done) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor>* received_keys = new std::vector<Tensor>;
  keys.reserve(out->size());
  received_keys->reserve(out->size());
  for (const auto& p : *out) {
    keys.push_back(p.first);
    received_keys->push_back(p.second);
  }
  RecvOutputsFromRendezvousAsync(
      rendezvous, nullptr, {}, keys, received_keys,
      [done, rendezvous, received_keys, out, keys](const Status s) {
        rendezvous->Unref();
        size_t output_size = 0;
        for (int i = 0, end = keys.size(); i < end; ++i) {
          (*out)[keys[i]] = (*received_keys)[i];
          output_size += (*out)[keys[i]].AllocatedBytes();
        }
        metrics::RecordGraphOutputTensors(output_size);
        delete received_keys;
        done(s);
      });
}

The flow is shown in the figure below, in the numbered order; step 3 returns a Rendezvous, and RecvOutputsFromRendezvousAsync is a free (global) function.

2.2 GraphMgr Send

Tensors are sent in ExecuteAsync.

void GraphMgr::ExecuteAsync(const string& handle, const int64_t step_id,
                            WorkerSession* session, const ExecutorOpts& opts,
                            StepStatsCollector* collector,
                            MutableRunGraphResponseWrapper* response,
                            CancellationManager* cancellation_manager,
                            const NamedTensors& in, StatusCallback done) {

  if (s.ok()) {
    // Send the tensors
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  }

  // Execute the subgraph
  StartParallelExecutors(
      handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
      cancellation_manager, session, start_time_usecs,
      [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
       step_id](const Status& s) {
      });
}

SendTensorsToRendezvous is as follows:

Status SendTensorsToRendezvous(
    RendezvousInterface* rendezvous, DeviceContext* device_context,
    const std::vector<AllocatorAttributes>& alloc_attrs,
    const std::vector<string>& keys, gtl::ArraySlice<Tensor> tensors_to_send) {

  Rendezvous::ParsedKey parsed;
  for (int i = 0; i < keys.size(); ++i) {
    Rendezvous::Args rendez_args;
    rendez_args.device_context = device_context;
    if (!alloc_attrs.empty()) {
      rendez_args.alloc_attrs = alloc_attrs[i];
    }
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
    TF_RETURN_IF_ERROR(
        rendezvous->Send(parsed, rendez_args, tensors_to_send[i], false));
  }
  return Status::OK();
}

Next we analyze receiving and sending in detail.

3. Send

Let's first look at the send flow. The Send path involves no cross-process transfer, so it is the same as Send in the purely local case: the tensor is simply placed into the worker's local table, no network transfer is involved, and the call is non-blocking.

3.1 BaseRemoteRendezvous

The Send method delegates to local_->Send to do the work.

Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
                                  const Rendezvous::Args& args,
                                  const Tensor& val, const bool is_dead) {

  WorkerSession* sess = nullptr;
  {
    tf_shared_lock l(mu_);
    if (!status_.ok()) return status_;
    sess = session_;
  }

  if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
    return errors::InvalidArgument(
        "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
        sess->worker_name());
  }

  // Buffers "val" and "device_context" in local_.
  return local_->Send(parsed, args, val, is_dead);
}

3.2 LocalRendezvous

LocalRendezvous::Send inserts the tensor into the local table.

Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key,
                             const Rendezvous::Args& send_args,
                             const Tensor& val, const bool is_dead) {
  uint64 key_hash = KeyHash(key.FullKey());

  if (is_dead) {
    static auto* rendezvous_dead_values_sent = monitoring::Counter<2>::New(
        "/tensorflow/core/rendezvous_dead_values_sent",
        "The number of dead values sent between a pair of devices.",
        "send_device", "recv_device");
    rendezvous_dead_values_sent
        ->GetCell(string(key.src_device), string(key.dst_device))
        ->IncrementBy(1);
  }

  mu_.lock();
  if (!status_.ok()) {
    // Rendezvous has been aborted.
    Status s = status_;
    mu_.unlock();
    return s;
  }

  ItemQueue* queue = &table_[key_hash];
  if (queue->head == nullptr || queue->head->type == Item::kSend) {
    // There is no waiter for this message. Append the message
    // into the queue. The waiter will pick it up when arrives.
    // Only send-related fields need to be filled.
    queue->push_back(new Item(send_args, val, is_dead));
    mu_.unlock();
    return Status::OK();
  }

  // There is an earliest waiter to consume this message.
  Item* item = queue->head;

  // Delete the queue when the last element has been consumed.
  if (item->next == nullptr) {
    table_.erase(key_hash);
  } else {
    queue->head = item->next;
  }
  mu_.unlock();

  // Notify the waiter by invoking its done closure, outside the
  // lock.
  DCHECK_EQ(item->type, Item::kRecv);
  (*item->recv_state.waiter)(Status::OK(), send_args, item->args, val, is_dead);
  delete item;
  return Status::OK();
}

The logic at this point is shown below; here Worker 0 refers to a worker role, not the Worker class.

Figure 3: Send logic

4. Receive

The sending side has now placed the prepared tensor into its local table. The receiving side must fetch the tensor from the sender's table, and this is where cross-process transfer comes in. The receive process is:

  • The Recv side is the client: it assembles the ParsedKey for the tensor it needs and sends a request to the Send side, with the ParsedKey carried in the request.
  • The Send side is the server: upon receiving the request, it immediately looks up the tensor the client needs in its local table, wraps it into a response, and sends it back to the Recv side.

The key point here is that the data transfer is initiated by the recv side: it actively sends a request to the Send side to trigger the communication, which differs from the pattern we usually see. The Worker supports both synchronous and asynchronous calls; we will follow the asynchronous path. To give an overall picture up front, the whole send/receive flow is shown below (dashed lines indicate returning the tensor), followed by a small sketch of what the consumer's request carries.

Figure 4: Overall send/receive logic
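
Before diving into the code, here is a small sketch of what the consumer's request carries, based on the fields set in RpcRecvTensorCall::Init (section 1.2.5). The helper name and the way the parsed key is obtained are assumptions for illustration only:

// Sketch: building the Recv side's request (assumed to live inside
// namespace tensorflow).
RecvTensorRequest MakeRecvTensorRequestSketch(
    int64_t step_id, const Rendezvous::ParsedKey& parsed) {
  RecvTensorRequest req;
  req.set_step_id(step_id);           // which step's rendezvous to look in
  req.set_rendezvous_key(parsed.FullKey().data(), parsed.FullKey().size());
  req.set_request_id(GetUniqueRequestId());  // lets the server de-duplicate retries
  return req;
}
// The Send side answers with a RecvTensorResponse that wraps the tensor plus
// an is_dead flag; on the client it is deserialized into a TensorResponse.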

4.1 Client

The client-side logic is as follows:

4.1.1 RecvOutputsFromRendezvousAsync

The free function RecvOutputsFromRendezvousAsync ends up calling rendezvous->RecvAsync.

void RecvOutputsFromRendezvousAsync(
    RendezvousInterface* rendezvous, DeviceContext* device_context,
    const std::vector<AllocatorAttributes>& alloc_attrs,
    const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
    StatusCallback done) {
  if (keys.empty()) {
    done(Status::OK());
    return;
  }

  received_tensors->reserve(keys.size());
  std::vector<
      std::tuple<string, Tensor*, Rendezvous::ParsedKey, AllocatorAttributes>>
      arguments;
  for (int i = 0; i < keys.size(); ++i) {
    Rendezvous::ParsedKey parsed;
    Status s = Rendezvous::ParseKey(keys[i], &parsed);
    received_tensors->push_back(Tensor());
    if (!s.ok()) {
      done(s);
      return;
    }
    AllocatorAttributes alloc_attr;
    if (!alloc_attrs.empty()) {
      alloc_attr = alloc_attrs[i];
    }
    arguments.emplace_back(keys[i], &((*received_tensors)[i]), parsed,
                           alloc_attr);
  }

  auto status_cb = new ReffedStatusCallback(std::move(done));
  for (auto& p : arguments) {
    const string& key = std::get<0>(p);
    Tensor* val = std::get<1>(p);
    Rendezvous::ParsedKey parsed = std::get<2>(p);
    Rendezvous::Args rendez_args;
    rendez_args.device_context = device_context;
    rendez_args.alloc_attrs = std::get<3>(p);
    status_cb->Ref();
    rendezvous->RecvAsync(
        parsed, rendez_args,
        [val, key, status_cb](const Status& s,
                              const Rendezvous::Args& send_args,
                              const Rendezvous::Args& recv_args,
                              const Tensor& v, const bool is_dead) {
          Status status = s;
          if (status.ok()) {
            *val = v;
            if (is_dead) {
              status = errors::InvalidArgument("The tensor returned for ", key,
                                               " was not valid.");
            }
          }
          status_cb->UpdateStatus(status);
          status_cb->Unref();
        });
  }
  status_cb->Unref();
}

4.1.2 BaseRemoteRendezvous

Because the source and destination are not in the same process, the call goes to RecvFromRemoteAsync.

void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
                                     const Rendezvous::Args& recv_args,
                                     DoneCallback done) {
  Status s = ValidateDevices(parsed, false /*!is_src*/);

  profiler::ScopedMemoryDebugAnnotation op_annotation("RecvAsync", step_id_);
  // Are src and dst in the same worker?
  if (IsSameWorker(parsed.src, parsed.dst)) { // src and dst are in the same worker
    // Recv the tensor from local_.
    local_->RecvAsync(
        parsed, recv_args,
        [this, parsed, done](
            const Status& status, const Rendezvous::Args& send_args,
            const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {

          Tensor* out = new Tensor;
          StatusCallback final_callback = [done, send_args, recv_args, out,
                                           is_dead](const Status& s) {
            done(s, send_args, recv_args, *out, is_dead);
            delete out;
          };

          if (status.ok()) {
            SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                               std::move(final_callback));
          } else {
            final_callback(status);
          }
        });
    return;
  } else { // src and dst are in different workers
    RecvFromRemoteAsync(parsed, recv_args, std::move(done));
  }
}

4.1.3 RpcRemoteRendezvous

RpcRemoteRendezvous validates the arguments, prepares an RpcRecvTensorCall, and then starts it with call->Start(); Start() in turn calls StartRTCall(). RpcRecvTensorCall derives from the abstract base class BaseRecvTensorCall; it is the abstraction of one gRPC call and encapsulates the complicated downstream call chain. The key point is the following two lines, which show how the call is set up with the corresponding worker:

WorkerInterface* rwi = worker_cache->GetOrCreateWorker(call->src_worker_);

call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
             recv_args, std::move(done));

The full code is as follows:

void RpcRemoteRendezvous::RecvFromRemoteAsync(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
    DoneCallback done) {
  CHECK(is_initialized());
  Status s;

  // Prepare a RecvTensor call that can handle being aborted.
  // Create a call object
  RpcRecvTensorCall* call = get_call_freelist()->New();

  // key.src_device identifies a remote device.
  if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
                                        &call->src_rel_device_)) {
    s = errors::Internal(parsed.src_device,
                         " is invalid remote source device.");
  }
  WorkerSession* sess = session();
  std::shared_ptr<WorkerCacheInterface> worker_cache =
      sess->GetSharedWorkerCache();
  // The worker will be released in a subsequent call to
  // sess->worker_cache()->ReleaseWorker() (if the call has not yet been
  // initialized) or call->ReleaseWorker() (if it has been initialized).
  
  // Get the corresponding worker
  WorkerInterface* rwi = worker_cache->GetOrCreateWorker(call->src_worker_);

  Device* dst_device;
  if (s.ok()) {
    s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
  }
  if (!s.ok()) {
    if (rwi != nullptr) {
      sess->worker_cache()->ReleaseWorker(call->src_worker_, rwi);
    }
    get_call_freelist()->Release(call);
    done(s, Args(), recv_args, Tensor{}, false);
    return;
  }

  // Initialize the call with the worker
  call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
             recv_args, std::move(done));

  // Record "call" in active_ so that it can be aborted cleanly.
  RegisterCall(call, recv_args);

  // Start "call".
  Ref();
  call->Start([this, call, worker_cache]() {
    // Removes "call" from active_. Prevent StartAbort().
    DeregisterCall(call);
    // If StartAbort was called prior to DeregisterCall, then the
    // current status should be bad.
    Status s = call->status();
    // NOTE: *session() can potentially be deleted before we return from
    // call->done()(...), so we must release the worker before calling the
    // callback.
    call->ReleaseWorker(session()->worker_cache());
    call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
    get_call_freelist()->Release(call);
    Unref();
  });
}

4.1.4 RpcRecvTensorCall

The Start method of RpcRecvTensorCall is shown below; it simply forwards to StartRTCall.

void RpcRecvTensorCall::Start(std::function<void()> recv_done) {
  StartRTCall(std::move(recv_done));
}

In RpcRecvTensorCall::StartRTCall, the worker's RecvTensorAsync is called to perform the transfer, which in practice is GrpcRemoteWorker::RecvTensorAsync.

// Start the main RecvTensor call, checking for an async abort.
void RpcRecvTensorCall::StartRTCall(std::function<void()> recv_done) {
  resp_.InitAlloc(dst_device_, alloc_attrs_);
  auto abort_checked = std::make_shared<Notification>();
  auto cb = [this, abort_checked,
             recv_done = std::move(recv_done)](const Status& s) {
    // Make sure the Rendezvous abort checking is finished before running the
    // callback, which might destroy the current call object.
    abort_checked->WaitForNotification();
    if (!s.ok()) {
      mutex_lock l(mu_);
      status_.Update(s);
    }
    recv_done();
  };
  wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));

  // NOTE: Check if the rendezvous was aborted after sending out the RPC. The
  // ordering is important because StartAbort could be called right before
  // the RecvTensorAsync request registers its RPC cancellation to opts_.
  // In that case, the previous StartAbort would not trigger the
  // cancellation of this call.
  Status s;
  {
    mutex_lock l(mu_);
    s = status_;
  }
  if (!s.ok()) {
    opts_.StartCancel();
  }
  // Notify that the abort check has finished.
  abort_checked->Notify();
}

4.1.5 GrpcRemoteWorker

A condensed version of the RecvTensorAsync method is shown below; with it we are back in the familiar Worker flow.

void GrpcRemoteWorker::RecvTensorAsync(CallOptions* call_opts,
                                       const RecvTensorRequest* request,
                                       TensorResponse* response,
                                       StatusCallback done) {
  // "callback" wraps "done"; the wrapping (logging, etc.) is omitted in this
  // condensed version.
  IssueRequest(request, response, recvtensor_, callback, call_opts);
}

At this point we have completed the right half of the figure below, as marked by the circles in the figure.

4.2 Server

Now we come to the server side, which is actually the tensor sender. The logic after it receives a RecvTensorRequest is as follows:

4.2.1 GrpcWorkerService

GrpcWorkerServiceThread::HandleRPCsLoop contains a for loop that enqueues 1000 handlers and configures GrpcWorkerMethod::kRecvTensor to be handled via EnqueueRecvTensorRequestRaw(). These handlers are registered in advance to speed up processing, and after EnqueueRecvTensorRequestRaw handles a message it calls EnqueueRequestForMethod to enqueue another handler.

void GrpcWorkerServiceThread::HandleRPCsLoop() {
  // TODO(ncteisen): This may require performance engineering. We can
  // change the number of threads, the number of handlers per thread,
  // or even decide to specialize certain threads to certain methods.
  SETUP_FOR_REQUEST(GetStatus, 1, false);
  SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
  SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
  SETUP_FOR_REQUEST(CleanupAll, 1, false);
  SETUP_FOR_REQUEST(RegisterGraph, 1, false);
  SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
  SETUP_FOR_REQUEST(Logging, 1, false);
  SETUP_FOR_REQUEST(Tracing, 1, false);
  SETUP_FOR_REQUEST(CompleteGroup, 10, true);
  SETUP_FOR_REQUEST(CompleteInstance, 10, true);
  SETUP_FOR_REQUEST(GetStepSequence, 10, true);
  SETUP_FOR_REQUEST(RecvBuf, 500, true);
  SETUP_FOR_REQUEST(RunGraph, 100, true);
  SETUP_FOR_REQUEST(CleanupGraph, 100, false);
  SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);

  // TODO(ncteisen): Determine a better policy for enqueuing the
  // appropriate number of each request type.
  for (int i = 0;
       i < gtl::FindWithDefault(
               queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
               1000);
       ++i) {
    EnqueueRecvTensorRequestRaw(); // register a handler
  }

  void* tag;
  bool ok;

  while (cq_->Next(&tag, &ok)) {
    UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
        static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
    CHECK(callback_tag);
    callback_tag->OnCompleted(this, ok);
  }
}

Here another handler is enqueued, and GrpcWorkerServiceThread::RecvTensorHandlerRaw is set up to continue handling GrpcWorkerMethod::kRecvTensor.

void EnqueueRecvTensorRequestRaw() {
  mutex_lock l(shutdown_mu_);
  if (!is_shutdown_) {
    Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
         RecvTensorRequest, ::grpc::ByteBuffer>::
        EnqueueRequestForMethod(
            worker_service_, cq_.get(),
            static_cast<int>(GrpcWorkerMethod::kRecvTensor),
            &GrpcWorkerServiceThread::RecvTensorHandlerRaw,
            true /* supports cancel*/);
  }
}

4.2.2 GrpcWorkerServiceThread

GrpcWorkerServiceThread is the server-side thread class that handles requests. Here it simply calls GrpcWorker to continue the processing, using a WorkerCall as the argument. WorkerCall is an alias for the class that handles a single gRPC request and response on the server side.

using WorkerCall =
    Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
         RequestMessage, ResponseMessage>;

The code is as follows:

void GrpcWorkerServiceThread::RecvTensorHandlerRaw(
    WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
  Schedule([this, call]() {
    CallOptions* call_opts = new CallOptions;
    call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });

    worker_->GrpcRecvTensorAsync(
        call_opts, &call->request, &call->response,
        [call, call_opts](const Status& s) {
          call->ClearCancelCallback();
          delete call_opts;
          if (!s.ok()) {
            VLOG(3) << "Bad response from RecvTensor:" << s;
          }
          call->SendResponse(ToGrpcStatus(s));
        });
  });
  EnqueueRecvTensorRequestRaw();
}

4.2.3 GrpcWorker

GrpcWorker is the worker that actually handles the request logic; it is the server-side counterpart of GrpcRemoteWorker. The logic of GrpcWorker::GrpcRecvTensorAsync is:

  • Obtain the rendezvous and use rendezvous_mgr->RecvLocalAsync to look up the tensor the client needs in the local table.
  • Call grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response) to encode the tensor.
  • In the callback, call CopyDeviceToHost to copy the tensor from GPU to CPU.
  • Finally, send it back to the client over gRPC.

// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
// buffers for a response object, to avoid extra protocol buffer serialization
// overhead we generate our response directly into a ::grpc::ByteBuffer object
void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                     const RecvTensorRequest* request,
                                     ::grpc::ByteBuffer* response,
                                     StatusCallback done) {

  const int64_t request_id = request->request_id();
  const int64_t step_id = request->step_id();

  bool cache_enabled = (response_cache_ != nullptr && request_id != 0);

  auto do_response = [response, done, cache_enabled](const Tensor& tensor,
                                                     bool is_dead,
                                                     const Status& status) {
    if (status.ok()) {
      grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response);
    }
    done(status);
  };

  // If response cache is enabled and the response cache already contains the
  // request, we delegate this retry request to the response cache. Otherwise,
  // we add the request to the response cache and start the computation to
  // retrieve the requested data.
  if (cache_enabled &&
      response_cache_->QueueRequest(request_id, step_id, do_response)) {
    return;
  }

  auto rendezvous_done = [this, request_id, do_response, cache_enabled](
                             const Tensor& tensor, bool is_dead,
                             const Status& status) {
    if (cache_enabled) {
      // Data is ready. Process all pending requests in the response cache.
      response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
    } else {
      do_response(tensor, is_dead, status);
    }
  };

  auto fail = [&rendezvous_done](const Status& status) {
    rendezvous_done(Tensor(), false, status);
  };

  Status s = recent_request_ids_.TrackUnique(
      request_id, "RecvTensor (GrpcWorker)", *request);

  const string& key = request->rendezvous_key();
  Rendezvous::ParsedKey parsed;
  s = Rendezvous::ParseKey(key, &parsed);
  Device* src_dev = nullptr;
  if (s.ok()) {
    s = PrepareRecvTensor(parsed, &src_dev);
  }

  // Request the tensor associated with the rendezvous key.
  // Note that we log the cancellation here but do not abort the current step.
  // gRPC can generate cancellations in response to transient network failures,
  // and aborting the step eliminates the opportunity for client side retries.
  // Repeated client failures will eventually cause the step to be aborted by
  // the client.
  opts->SetCancelCallback(
      [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
  env_->rendezvous_mgr->RecvLocalAsync(
      step_id, parsed,
      [opts, rendezvous_done, src_dev, request](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& val,
          const bool is_dead) {
        opts->ClearCancelCallback();
        if (status.ok()) {
          // DMA can only be used for Tensors that do not fall into
          // the following three odd edge cases: 1) a zero-size
          // buffer, 2) a dead tensor which has an uninit value, and
          // 3) the tensor has the on_host allocation attribute,
          // i.e. it's in CPU RAM *independent of its assigned
          // device type*.
          const bool on_host = send_args.alloc_attrs.on_host();
          {
            // Non-DMA cases.
            if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
              DeviceContext* send_dev_context = send_args.device_context;
              AllocatorAttributes alloc_attrs;
              alloc_attrs.set_gpu_compatible(true);
              alloc_attrs.set_on_host(true);
              Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
              Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
              // "val" is on an accelerator device. Uses the device_context to
              // fill the copy on host.
              StatusCallback copy_ready = [rendezvous_done, copy,
                                           is_dead](const Status& s) {
                // The value is now ready to be returned on the wire.
                rendezvous_done(*copy, is_dead, s);
                delete copy;
              };

              CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
                               src_dev, copy, send_dev_context, copy_ready);
              return;
            }
          }
        }

        rendezvous_done(val, is_dead, status);
      });
}

4.2.4 BaseRendezvousMgr

BaseRendezvousMgr::RecvLocalAsync looks up the tensor in the local table.

void BaseRendezvousMgr::RecvLocalAsync(int64_t step_id,
                                       const Rendezvous::ParsedKey& parsed,
                                       Rendezvous::DoneCallback done) {
  auto rendez = FindOrCreate(step_id);
  auto done_cb = [rendez, done = std::move(done)](
                     const Status& s, const Rendezvous::Args& send_args,
                     const Rendezvous::Args& recv_args, const Tensor& v,
                     bool dead) {
    rendez->Unref();
    done(s, send_args, recv_args, v, dead);
  };
  rendez->RecvLocalAsync(parsed, std::move(done_cb));
}

4.2.5 BaseRemoteRendezvous

Ultimately this ends up in RecvLocalAsyncInternal, whose key call is local_->RecvAsync.

void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
                                          DoneCallback done) {
  // Test whether the rendezvous is initialized using a shared lock, to avoid
  // the need for exclusive access in the common case.
  if (TF_PREDICT_FALSE(!is_initialized())) {
    mutex_lock l(mu_);
    if (!is_initialized_locked()) {
      // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
      // remote worker) before the RunStep (or PartialRunStep) RPC from the
      // master arrives. RecvLocalAsync thus buffers the arguments until after
      // the RemoteRendezvous is Initialize()'d, when it completes the
      // rendezvous logic. At some point after Initialize() is called, a Tensor
      // is produced locally that will then be sent in response to the incoming
      // RPC.
      DeferredCall call(parsed, std::move(done));
      deferred_calls_.push_back(call);
      return;
    }
  }
  RecvLocalAsyncInternal(parsed, std::move(done));
}

void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
                                                  DoneCallback done) {
  Status s = ValidateDevices(parsed, true /* is_src */);
  if (!s.ok()) {
    done(s, Args(), Args(), Tensor(), false);
    return;
  }
  local_->RecvAsync(parsed, Args(), std::move(done));
}

4.2.6 LocalRendezvous

LocalRendezvous::RecvAsync performs the actual read of the tensor from the local table.

void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
                                const Rendezvous::Args& recv_args,
                                Rendezvous::DoneCallback done) {
  uint64 key_hash = KeyHash(key.FullKey());

  mu_.lock();
  if (!status_.ok()) {
    // Rendezvous has been aborted.
    Status s = status_;
    mu_.unlock();
    done(s, Rendezvous::Args(), recv_args, Tensor(), false);
    return;
  }

  ItemQueue* queue = &table_[key_hash];
  if (queue->head == nullptr || queue->head->type == Item::kRecv) {
    // There is no message to pick up.
    // Only recv-related fields need to be filled.
    CancellationManager* cm = recv_args.cancellation_manager;
    CancellationToken token = CancellationManager::kInvalidToken;
    bool already_cancelled = false;
    if (cm != nullptr) {
      // Increment the refcount when cancellation manager is present, to make
      // sure the rendezvous outlives the recv and its cancel callbacks.
      // This refcount is dropped in exactly one of the following cases:
      // (1) Recv registers cancellation callback to cm, and then cm is
      //     cancelled, unref in the cancellation callback;
      // (2) Recv registers cancellation callback to cm, but cm is already
      //     cancelled, unref in the already_cancelled check;
      // (3) Recv is successful, and item done callback finishes deregistering
      //     the cancellation callback, unref in the item done callback;
      // (4) Recv is successful, but the item done callback fails to deregister
      //     the cancellation callback because cm already StartCancel, in this
      //     case the cancellation callback will be invoked by the cm anyway,
      //     unref in the cancellation callback.
      if (rc_owner_) rc_owner_->Ref();
      token = cm->get_cancellation_token();
      already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
        Item* item = nullptr;
        {
          mutex_lock l(mu_);
          ItemQueue* queue = &table_[key_hash];
          // Find an item in the queue with a cancellation token that matches
          // token, and remove it.
          if (queue->head != nullptr && queue->head->type == Item::kRecv) {
            for (Item *prev = nullptr, *curr = queue->head; curr != nullptr;
                 prev = curr, curr = curr->next) {
              if (curr->recv_state.cancellation_token == token) {
                item = curr;
                if (queue->head->next == nullptr) {
                  // We have a single-element queue, so we can erase it from
                  // the table.
                  table_.erase(key_hash);
                } else {
                  // Remove the current item from the queue.
                  if (curr == queue->head) {
                    DCHECK_EQ(prev, nullptr);
                    queue->head = curr->next;
                  } else {
                    DCHECK_NE(prev, nullptr);
                    prev->next = curr->next;
                  }
                  if (queue->tail == curr) {
                    queue->tail = prev;
                  }
                }
                break;
              }
            }
          }
        }

        if (item != nullptr) {
          (*item->recv_state.waiter)(
              StatusGroup::MakeDerived(
                  errors::Cancelled("RecvAsync is cancelled.")),
              Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
          delete item;
        }
        // Unref case (1) and (4)
        if (rc_owner_) rc_owner_->Unref();
      });
    }
    if (already_cancelled) {
      mu_.unlock();
      // Unref case (2)
      if (rc_owner_) rc_owner_->Unref();
      done(StatusGroup::MakeDerived(
               errors::Cancelled("RecvAsync is cancelled.")),
           Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
      return;
    }

    // TODO(b/143786186): Investigate moving the allocation of Item outside
    // the lock.
    if (cm != nullptr) {
      // NOTE(mrry): We must wrap done with code that deregisters the
      // cancellation callback before calling the done callback, because the
      // cancellation manager may no longer be live after done is called.
      queue->push_back(new Item(
          recv_args,
          [this, cm, token, done = std::move(done)](
              const Status& s, const Rendezvous::Args& send_args,
              const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
            // TryDeregisterCallback returns true when the cancellation callback
            // is successfully deregistered. If it fails because the CM already
            // StartAbort, Unref will happen inside the cancellation callback
            // when called by the CM.
            if (cm->TryDeregisterCallback(token)) {
              // Unref case (3)
              if (this->rc_owner_) this->rc_owner_->Unref();
            }
            done(s, send_args, recv_args, v, dead);
          },
          token));
    } else {
      queue->push_back(new Item(recv_args, std::move(done), token));
    }

    mu_.unlock();
    return;
  }

  // A message has already arrived and is queued in the table under
  // this key.  Consumes the message and invokes the done closure.
  Item* item = queue->head;

  // Delete the queue when the last element has been consumed.
  if (item->next == nullptr) {
    table_.erase(key_hash);
  } else {
    queue->head = item->next;
  }
  mu_.unlock();

  // Invoke done() without holding the table lock.
  DCHECK_EQ(item->type, Item::kSend);
  done(Status::OK(), item->args, recv_args, *item->send_state.value,
       item->send_state.is_dead);
  delete item;
}

This completes all the logic of the earlier figures. Alternatively, we can look at it from another angle, as shown in the figure below:

0xFF References

TensorFlow架構與設計:概述

TensorFlow核心剖析

TensorFlow架構與設計:OP本質論

[譯] TensorFlow 白皮書

2017TensorFlow開發者峰會

https://jcf94.com/2018/02/28/2018-02-28-tfunpacking3/

TensorFlow 拆包(五):Distributed

TensorFlow Architecture

『深度長文』Tensorflow程式碼解析(五)

什麼是in-graph replication和between-graph replication?

[騰訊機智] TensorFlow原始碼解析(1): 建立會話

05tensorflow分散式會話

第八節,配置分散式TensorFlow

TensorFlow 分散式(Distributed TensorFlow)

tensorflow原始碼解析之distributed_runtime

Distributed TensorFlow: A Gentle Introduction

一文說清楚Tensorflow分散式訓練必備知識

TensorFlow中的Placement啟發式演算法模組——Placer

TensorFlow的圖切割模組——Graph Partitioner

TensorFlow中的通訊機制——Rendezvous(一)本地傳輸

TensorFlow分散式採坑記

TensorFlow技術內幕(九):模型優化之分散式執行

Tensorflow架構流程
