[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

Posted by 羅西的思考 on 2022-03-23


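Next we look at the caching mechanism. Why cache at all? A cluster contains many workers. The Master interacts with Workers, and Workers interact with each other, so it pays to cache both the Worker objects and their gRPC channels. Caching shows up almost everywhere in TensorFlow's distributed runtime.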

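Other articles in this series:

[Translation] TensorFlow Distributed: the paper "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed: the paper "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic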

1. WorkerCache

WorkerCache hands out WorkerInterface instances, and those instances give access to the remote WorkerService. The typical WorkerInterface implementation is GrpcRemoteWorker.

1.1 Usage

When MasterEnv is initialized, WorkerCacheFactory is installed into master_env_.worker_cache_factory:

master_env_.worker_cache_factory =
    [this](const WorkerCacheFactoryOptions& options,
           WorkerCacheInterface** worker_cache) {
      return WorkerCacheFactory(options, worker_cache);
    };
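The snippet above stores the factory as a std::function inside the environment. A lambda captures `this` and forwards to the server's member function. The sketch below isolates that pattern. ToyServer, ToyCache, and the simplified Status, WorkerCacheFactoryOptions, WorkerCacheInterface, and MasterEnv are hypothetical stand-ins, not TensorFlow types.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>

// Hypothetical stand-ins for the TensorFlow types; only their shape matters.
struct Status {
  bool ok_ = true;
  bool ok() const { return ok_; }
};

struct WorkerCacheFactoryOptions {
  const std::string* job_name = nullptr;
  int task_index = 0;
};

struct WorkerCacheInterface {
  virtual ~WorkerCacheInterface() = default;
  virtual std::string Describe() const = 0;
};

// The env stores only a callable; it does not know how caches are built.
struct MasterEnv {
  std::function<Status(const WorkerCacheFactoryOptions&,
                       WorkerCacheInterface**)>
      worker_cache_factory;
};

// Plays the role of GrpcServer: it owns the real factory logic.
class ToyServer {
 public:
  void InitEnv(MasterEnv* env) {
    // Capture `this` so the env can call back into the server later.
    env->worker_cache_factory = [this](const WorkerCacheFactoryOptions& options,
                                       WorkerCacheInterface** worker_cache) {
      return WorkerCacheFactory(options, worker_cache);
    };
  }

 private:
  struct ToyCache : WorkerCacheInterface {
    std::string name;
    std::string Describe() const override { return name; }
  };

  Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                            WorkerCacheInterface** worker_cache) {
    auto* cache = new ToyCache;
    cache->name = "/job:" + *options.job_name + "/replica:0/task:" +
                  std::to_string(options.task_index);
    *worker_cache = cache;  // the caller takes ownership
    return Status{};
  }
};
```

The indirection lets the Master build caches through its environment without depending on how the server constructs them.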

Later, Master::CreateSession contains the following (abridged) code. It shows how worker_cache (a WorkerCacheInterface instance) is obtained from the factory and how it is then used:

void Master::CreateSession(const CreateSessionRequest* req,
                           CreateSessionResponse* resp, MyClosure done) {
  SchedClosure([this, req, resp, done]() {
      Status status;
      WorkerCacheInterface* worker_cache = nullptr;
      std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
          new std::vector<std::unique_ptr<Device>>());
      // ... (abridged)

      // Configure the options
      WorkerCacheFactoryOptions worker_cache_factory_options;
      string grpc_protocol("grpc");
      worker_cache_factory_options.protocol = &grpc_protocol;
      worker_cache_factory_options.rpc_options = &req->config().rpc_options();

      // Create the worker_cache
      // Create the worker cache from the computed server_def.
      status = env_->worker_cache_factory(worker_cache_factory_options,
                                          &worker_cache);

      // Use worker_cache for the subsequent steps
      status =
          DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
                                         worker_cache, remote_devices.get());
      // ... (abridged)
  });
}

1.2 Configuration

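WorkerCacheFactoryOptions is essentially equivalent to ServerDef. It carries the ClusterDef, job_name, task_index, and related information. It stores only pointers into the ServerDef, so, as the code comment notes, the ServerDef must outlive it.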

// Options passed to the worker_cache_factory function.
struct WorkerCacheFactoryOptions {
  const ClusterDef* cluster_def = nullptr;
  const string* job_name = nullptr;
  int task_index;
  const string* protocol = nullptr;
  const RPCOptions* rpc_options = nullptr;

  WorkerCacheFactoryOptions() {}

  // Construct from a ServerDef proto.
  //
  // Note: server_def must outlive WorkerCacheFactoryOptions!
  WorkerCacheFactoryOptions(const ServerDef& server_def) {
    if (server_def.has_cluster() && !server_def.job_name().empty()) {
      cluster_def = &server_def.cluster();
      job_name = &server_def.job_name();
      task_index = server_def.task_index();
      protocol = &server_def.protocol();
      rpc_options = &server_def.default_session_config().rpc_options();
    }
  }
};

1.3 Factory

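WorkerCacheFactory is a function that does the following:

  • Calls ParseChannelSpec to obtain a GrpcChannelSpec instance. GrpcChannelSpec plays the role of a ClusterSpec and holds the basic cluster configuration.
  • Calls NewGrpcChannelCache to obtain a GrpcChannelCache, channel_cache. GetChannelCreationFunction is used at this step.
  • Checks that the port configured for the local task matches the port the server actually bound (bound_port_).
  • Calls NewGrpcWorkerCacheWithLocalWorker(channel_cache, ...) to obtain worker_cache.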
Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                                      WorkerCacheInterface** worker_cache) {

  // Obtain the GrpcChannelSpec
  GrpcChannelSpec channel_spec;
  TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));

  // Obtain the GrpcChannelCache
  std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
      channel_spec, GetChannelCreationFunction(), *options.rpc_options));

  string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
                                       "/task:", options.task_index);

  const string host_port = channel_cache->TranslateTask(name_prefix);
  int requested_port;

  auto colon_index = host_port.find_last_of(':');
  if (!strings::safe_strto32(host_port.substr(colon_index + 1),
                             &requested_port)) {
    return errors::Internal("Could not parse port for local server from \"",
                            host_port, "\".");
  }
  if (requested_port != bound_port_) {
    return errors::InvalidArgument("Requested port ", requested_port,
                                   " differs from expected port ", bound_port_);
  }
  // Obtain the worker cache
  *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
      channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
  return Status::OK();
}
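WorkerCacheFactory contains two small steps that are easy to overlook. First, the local task name is built as /job:<job>/replica:0/task:<index>. Second, the port is taken from the text after the last ':' of the translated host:port and compared with bound_port_. The sketch below covers both steps. It uses std::strtol as a stand-in for strings::safe_strto32. MakeNamePrefix, ParsePort, and PortMatches are hypothetical names, and the 0 to 65535 range check is an addition of this sketch.

```cpp
#include <cassert>
#include <cerrno>
#include <cstdlib>
#include <string>

// Builds the local target name, mirroring the StrCat in WorkerCacheFactory.
std::string MakeNamePrefix(const std::string& job_name, int task_index) {
  return "/job:" + job_name + "/replica:0/task:" + std::to_string(task_index);
}

// Parses the port after the last ':' in "host:port"; returns false on failure.
// Hypothetical stand-in for strings::safe_strto32 plus a range check.
bool ParsePort(const std::string& host_port, int* port) {
  auto colon = host_port.find_last_of(':');
  if (colon == std::string::npos) return false;
  std::string digits = host_port.substr(colon + 1);
  if (digits.empty()) return false;
  char* end = nullptr;
  errno = 0;
  long value = std::strtol(digits.c_str(), &end, 10);
  if (errno != 0 || *end != '\0' || value < 0 || value > 65535) return false;
  *port = static_cast<int>(value);
  return true;
}

// The port listed for this task must equal the port the server bound.
bool PortMatches(const std::string& host_port, int bound_port) {
  int requested = 0;
  return ParsePort(host_port, &requested) && requested == bound_port;
}
```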

1.3.1 ParseChannelSpec

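ParseChannelSpec produces the GrpcChannelSpec instance. GrpcChannelSpec plays the role of a ClusterSpec and holds the basic cluster configuration. For every job, ParseChannelSpec builds a task index → host:port map. The entry for the local task is replaced with host_name_:bound_port_. A task listed with two addresses is rejected with InvalidArgument.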

Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
                                    GrpcChannelSpec* channel_spec) {
  for (const auto& job : options.cluster_def->job()) {
    std::map<int, string> host_ports;
    for (const auto& task : job.tasks()) {
      string& host_port = host_ports[task.first];
      if (!host_port.empty()) {
        return errors::InvalidArgument("JobDef for job \"", job.name(),
                                       "\" specified two addresses for task \"",
                                       task.first, "\": ", host_port, " and ",
                                       task.second);
      }
      if (job.name() == *options.job_name && task.first == options.task_index) {
        host_port = strings::StrCat(host_name_, ":", bound_port_);
      } else {
        host_port = task.second;
      }
    }
    TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
  }
  return Status::OK();
}
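The per-job map construction in ParseChannelSpec can be sketched without protobuf. JobDef here is a hypothetical struct. Its tasks are stored as a vector of (index, address) pairs so that the duplicate-address check can actually trigger. my_host_port stands in for host_name_:bound_port_, and JobHostPorts stands in for GrpcChannelSpec.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical simplified job definition: (task_index, address) pairs.
struct JobDef {
  std::string name;
  std::vector<std::pair<int, std::string>> tasks;
};

// Stand-in for GrpcChannelSpec: job name -> (task index -> host:port).
using JobHostPorts = std::map<std::string, std::map<int, std::string>>;

// Returns false if any task in a job is given two addresses.
bool ParseChannelSpecSketch(const std::vector<JobDef>& cluster,
                            const std::string& my_job, int my_task,
                            const std::string& my_host_port,
                            JobHostPorts* out) {
  for (const JobDef& job : cluster) {
    std::map<int, std::string> host_ports;
    for (const auto& task : job.tasks) {
      std::string& host_port = host_ports[task.first];
      if (!host_port.empty()) return false;  // same task listed twice
      // The local task uses the address the server actually bound to.
      host_port = (job.name == my_job && task.first == my_task)
                      ? my_host_port
                      : task.second;
    }
    (*out)[job.name] = std::move(host_ports);
  }
  return true;
}
```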

1.3.2 NewGrpcChannelCache

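NewGrpcChannelCache creates the GrpcChannelCache instance. Each job gets its own SparseGrpcChannelCache. If the spec contains no jobs, the function returns nullptr. If there is exactly one SparseGrpcChannelCache, it is returned directly; otherwise all of them are combined into a MultiGrpcChannelCache. The channel_func passed in comes from GetChannelCreationFunction, which is covered later.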

GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
                                      ChannelCreationFunction channel_func,
                                      const RPCOptions& options) {
  const int num_jobs = spec.host_ports_jobs().size();
  if (!num_jobs) {
    return nullptr;
  }
  std::vector<GrpcChannelCache*> caches;
  caches.reserve(num_jobs);
  for (auto& job : spec.host_ports_jobs()) {
    caches.push_back(
        new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func,
                                   options.num_channels_per_target()));
  }
  return caches.size() == 1 ? caches[0]
                            : new MultiGrpcChannelCache(
                                  caches, options.num_channels_per_target());
}

1.3.3 NewGrpcWorkerCacheWithLocalWorker

NewGrpcWorkerCacheWithLocalWorker creates a GrpcWorkerCache instance.

WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
    std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
    WorkerInterface* local_worker, const string& local_target) {
  return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
}

The local_worker argument is obtained from worker_impl(). It is created in GrpcServer::Init and is the local GrpcWorker.

GrpcWorker* worker_impl() const { return worker_impl_.get(); }

std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
                                          const ConfigProto& config) {
  return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
}

Status GrpcServer::Init(const GrpcServerOptions& opts) {
  // ... (omitted)
  worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                  : NewGrpcWorker(&worker_env_, config);
  // ... (omitted)
}

To recap the factory flow so far: the input is WorkerCacheFactoryOptions, each function above processes it in turn, and the final result is a GrpcWorkerCache.

Figure 1: Factory flow

1.4 WorkerCacheInterface

1.4.1 Interface

WorkerCacheInterface is an interface class. GrpcWorkerCache, shown in Figure 1, is one of its subclasses.

class WorkerCacheInterface {
 public:
  virtual ~WorkerCacheInterface() {}

  // Updates *workers with strings naming the remote worker tasks to
  // which open channels have been established.
  virtual void ListWorkers(std::vector<string>* workers) const = 0;
  virtual void ListWorkersInJob(const string& job_name,
                                std::vector<string>* workers) const = 0;

  // If "target" names a remote task for which an RPC channel exists
  // or can be constructed, returns a pointer to a WorkerInterface object
  // wrapping that channel. The returned value must be destroyed by
  // calling `this->ReleaseWorker(target, ret)`
  virtual WorkerInterface* GetOrCreateWorker(const string& target) = 0;

  // Release a worker previously returned by this->GetOrCreateWorker(target).
  //
  // TODO(jeff,sanjay): Consider moving target into WorkerInterface.
  // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
  //                    per-rpc-subsystem WorkerInterface creator.
  virtual void ReleaseWorker(const string& target, WorkerInterface* worker) {
    // Subclasses may override to reuse worker objects.
    delete worker;
  }

  // Set *locality with the DeviceLocality of the specified remote device
  // within its local environment.  Returns true if *locality
  // was set, using only locally cached data.  Returns false
  // if status data for that device was not available.  Never blocks.
  virtual bool GetDeviceLocalityNonBlocking(const string& device,
                                            DeviceLocality* locality) = 0;

  // Set *locality with the DeviceLocality of the specified remote device
  // within its local environment.  Callback gets Status::OK if *locality
  // was set.
  virtual void GetDeviceLocalityAsync(const string& device,
                                      DeviceLocality* locality,
                                      StatusCallback done) = 0;

  // TODO(b/189159585): Define a general client cache maker function to
  // construct client cache of different types sharing the same underling RPC
  // channels, to replace the eager and coordination cache function.
  // Build and return a EagerClientCache object wrapping that channel.
  virtual Status GetEagerClientCache(
      std::unique_ptr<eager::EagerClientCache>* eager_client_cache) = 0;

  // Build and return a CoordinationClientCache object wrapping that channel.
  virtual Status GetCoordinationClientCache(
      std::unique_ptr<CoordinationClientCache>* coordination_client_cache) = 0;

  // Start/stop logging activity.
  virtual void SetLogging(bool active) {}

  // Discard any saved log data.
  virtual void ClearLogs() {}

  // Return logs for the identified step in *ss.  Any returned data will no
  // longer be stored.
  virtual bool RetrieveLogs(int64_t step_id, StepStats* ss) { return false; }
};
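The comment on GetOrCreateWorker requires every returned worker to be handed back through ReleaseWorker. A scope guard is one way to make that pairing hard to forget. ScopedWorker and CountingCache below are hypothetical helpers, not part of TensorFlow, and the interface is cut down to the two methods involved.

```cpp
#include <cassert>
#include <string>
#include <utility>

struct WorkerInterface {
  virtual ~WorkerInterface() = default;
};

// Cut-down version of the interface: only the create/release pair.
class WorkerCacheInterface {
 public:
  virtual ~WorkerCacheInterface() = default;
  virtual WorkerInterface* GetOrCreateWorker(const std::string& target) = 0;
  virtual void ReleaseWorker(const std::string& target,
                             WorkerInterface* worker) {
    delete worker;  // default behavior; subclasses may reuse workers instead
  }
};

// Hypothetical RAII guard: returns the worker to the cache that produced it
// when the guard goes out of scope.
class ScopedWorker {
 public:
  ScopedWorker(WorkerCacheInterface* cache, std::string target)
      : cache_(cache),
        target_(std::move(target)),
        worker_(cache_->GetOrCreateWorker(target_)) {}
  ~ScopedWorker() {
    if (worker_ != nullptr) cache_->ReleaseWorker(target_, worker_);
  }
  ScopedWorker(const ScopedWorker&) = delete;
  ScopedWorker& operator=(const ScopedWorker&) = delete;
  WorkerInterface* get() const { return worker_; }

 private:
  WorkerCacheInterface* cache_;
  std::string target_;
  WorkerInterface* worker_;
};

// Toy cache that counts live workers; an empty target yields nullptr.
class CountingCache : public WorkerCacheInterface {
 public:
  int live = 0;
  WorkerInterface* GetOrCreateWorker(const std::string& target) override {
    if (target.empty()) return nullptr;
    ++live;
    return new WorkerInterface;
  }
  void ReleaseWorker(const std::string& target,
                     WorkerInterface* worker) override {
    --live;
    WorkerCacheInterface::ReleaseWorker(target, worker);
  }
};
```

Routing the release through the cache matters here. As the default ReleaseWorker shows, deletion is only the fallback, and a subclass may keep or reuse the object.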

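WorkerCachePartial in turn derives from WorkerCacheInterface. It implements the part of the interface that caches and returns remote device status (device locality).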

// Implements the part of the interface that caches and returns remote
// device status attributes.
class WorkerCachePartial : public WorkerCacheInterface {
 public:
  bool GetDeviceLocalityNonBlocking(const string& device,
                                    DeviceLocality* locality) override;

  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
                              StatusCallback) override;

  ~WorkerCachePartial() override {}

  // Clear all entries from the DeviceStatus cache.
  void FlushStatusCache();

 private:
  mutex mu_;

  // Initiate a GetStatusAsync to the remote task named by "task", and
  // update the cache with all the DeviceAttributes reported.
  Status RefreshDeviceStatus(const string& device_name);

  typedef std::unordered_map<string, DeviceAttributes> StatusMap;
  StatusMap device_status_cache_ TF_GUARDED_BY(mu_);
};

1.4.2 GrpcWorkerCache

GrpcWorkerCache in turn derives from WorkerCachePartial.

class GrpcWorkerCache : public WorkerCachePartial {
 public:
  explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
                           WorkerInterface* local_worker,
                           const string& local_target,
                           GrpcWorkerEnv* worker_env)
      : local_target_(local_target),
        local_worker_(local_worker),
        channel_cache_(channel_cache),
        worker_env_(worker_env),
        next_round_robin_assignment_(0) {}

  const string local_target_;
  WorkerInterface* const local_worker_;  // Not owned.
  std::shared_ptr<GrpcChannelCache> channel_cache_;
  WorkerCacheLogger logger_;
  GrpcWorkerEnv* worker_env_;  // Not owned

  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      TF_GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
};

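One of its main jobs is listing worker names. ListWorkers lists all workers in the cluster, and ListWorkersInJob lists the workers of a single job. Both delegate to channel_cache_.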

void ListWorkers(std::vector<string>* workers) const override {
  channel_cache_->ListWorkers(workers);
}

void ListWorkersInJob(const string& job_name,
                      std::vector<string>* workers) const override {
  channel_cache_->ListWorkersInJob(job_name, workers);
}

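GetOrCreateWorker creates a worker on top of the worker's RPC channel. If the target is the local task, it simply returns local_worker_, which is the local GrpcWorker set up earlier. If no channel can be found for the target, it returns nullptr.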

WorkerInterface* GetOrCreateWorker(const string& target) override {
  if (target == local_target_) {
    return local_worker_;
  } else {
    SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
    if (!channel) {
      return nullptr;
    }
    size_t index = AssignWorkerToThread(target);
    return NewGrpcRemoteWorker(
        channel, worker_env_->GetCompletionQueue(index),
        worker_env_->GetThreadPool(), &logger_, target);
  }
}
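GetOrCreateWorker calls AssignWorkerToThread, whose body is not shown here. The members target_assignments_ and next_round_robin_assignment_ suggest the following behavior: each new target gets the next completion-queue index in round-robin order, and a given target always keeps the same index. The sketch below implements that assumption. It also includes the local-target shortcut and a ReleaseWorker that does not delete the non-owned local worker. ToyWorkerCache and RemoteWorker are hypothetical, and num_completion_queues is assumed to be greater than zero.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

struct WorkerInterface {
  virtual ~WorkerInterface() = default;
};

// Stand-in for GrpcRemoteWorker: remembers its completion-queue index.
struct RemoteWorker : WorkerInterface {
  RemoteWorker(std::string t, std::size_t i) : target(std::move(t)), cq_index(i) {}
  std::string target;
  std::size_t cq_index;
};

class ToyWorkerCache {
 public:
  ToyWorkerCache(std::string local_target, WorkerInterface* local_worker,
                 std::size_t num_completion_queues)  // assumed > 0
      : local_target_(std::move(local_target)),
        local_worker_(local_worker),
        num_cqs_(num_completion_queues) {}

  WorkerInterface* GetOrCreateWorker(const std::string& target) {
    if (target == local_target_) return local_worker_;  // no RPC to ourselves
    // The real code first looks up a gRPC channel; assumed to exist here.
    return new RemoteWorker(target, AssignWorkerToThread(target));
  }

  void ReleaseWorker(const std::string& target, WorkerInterface* worker) {
    if (target == local_target_) return;  // the local worker is not owned
    delete worker;
  }

 private:
  // New targets are spread round-robin; known targets keep their index.
  std::size_t AssignWorkerToThread(const std::string& target) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = assignments_.find(target);
    if (it == assignments_.end()) {
      it = assignments_.emplace(target, next_++ % num_cqs_).first;
    }
    return it->second;
  }

  const std::string local_target_;
  WorkerInterface* const local_worker_;  // not owned
  const std::size_t num_cqs_;
  std::mutex mu_;
  std::unordered_map<std::string, std::size_t> assignments_;
  std::size_t next_ = 0;
};
```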

2. RPC Channels

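Workers communicate over RPC channels, so next we look at how those channels are created. Workers are cached, and RPC channels are cached in the same way. GrpcChannelCache is that cache: it looks up or creates the RPC channel to a remote Worker in the cluster.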

2.1 The GrpcChannelCache Interface

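GrpcChannelCache is an interface class. Among others, it declares:

  • ListWorkers: returns the names of the Workers in the cluster.
  • TranslateTask: converts a Worker name into its address, in host:port format.
  • FindWorkerChannel: looks up a grpc::Channel in the cache. On a miss, it creates one from the address and puts it into the cache.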
class GrpcChannelCache {
 public:
  virtual ~GrpcChannelCache() {}

  // Populates *workers with names of all workers which this object
  // was created to handle.  Worker names are in the format
  //  /job:<job identifier>/task:<task id>
  // e.g. /job:mnist/task:2
  virtual void ListWorkers(std::vector<string>* workers) = 0;
  virtual void ListWorkersInJob(const string& job_name,
                                std::vector<string>* workers) = 0;

  // If found, returns a gRPC channel that is connected to the remote
  // worker named by 'target'. 'target' is of the following
  // format: /job:<job identifier>/task:<task id>
  // E.g., /job:mnist/task:2
  virtual SharedGrpcChannelPtr FindWorkerChannel(const string& target) = 0;

  // Translates a string in the form `/job:X/task:Z` into a host_port.
  virtual string TranslateTask(const string& task) = 0;
};

2.2 Caching Mechanism

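CachingGrpcChannelCache is the caching class. It avoids the cost of creating a grpc::Channel on every call. As its definition shows, it is GenericCachingChannelCache instantiated with GrpcChannelCache as the template argument, which makes it a class derived from GrpcChannelCache.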

// GrpcChannelCache that caches results to FindWorkerChannel() calls.
using CachingGrpcChannelCache = GenericCachingChannelCache<GrpcChannelCache>;

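GenericCachingChannelCache caches the results of FindWorkerChannel() calls. It first looks up the grpc::Channel instances in the cache. On a miss, it calls FindChannelOnce to create them from the address and then inserts them into the cache. A failed creation (nullptr) is not cached.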

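GenericCachingChannelCache can use several channels to talk to the same target, which increases throughput. When a target has multiple channels, successive FindWorkerChannel calls choose among them in round-robin order.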

Note that, because of the typedef below, absl::flat_hash_map<string, ChannelState> channels_ is the cache of ::grpc::Channel objects. Each ChannelState holds the channels for one target.

typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;

The code:

template <typename ChannelCacheT>
class GenericCachingChannelCache : public ChannelCacheT {
 public:
  explicit GenericCachingChannelCache(int num_channels_per_target)
      : num_channels_per_target_(
            num_channels_per_target > 0 ? num_channels_per_target : 1) {}

  ~GenericCachingChannelCache() override {}

  SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
    {
      mutex_lock l(mu_);
      auto iter = channels_.find(target);
      if (iter != channels_.end()) {
        return GetNextChannelPtrAndUpdateState(iter->second);
      }
    }
    ChannelState new_chan_state;
    for (int indx = 0; indx < num_channels_per_target_; indx++) {
      auto ch = FindChannelOnce(target);
      if (!ch) return nullptr;
      new_chan_state.channels.push_back(ch);
    }
    new_chan_state.last_used = num_channels_per_target_ - 1;

    {
      mutex_lock l(mu_);
      typename absl::flat_hash_map<string, ChannelState>::iterator iter;
      bool was_inserted;
      std::tie(iter, was_inserted) = channels_.insert({target, new_chan_state});
      return GetNextChannelPtrAndUpdateState(iter->second);
    }
  }

 protected:
  // Find the ClientChannel for "target".  Only called when no channel was
  // found in the channels_ cache for "target".  A non nullptr result will be
  // cached in channels_.
  virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;

 private:
  struct ChannelState {
    std::vector<SharedGrpcChannelPtr> channels; 
    int last_used;
  };

  // Should be called with mu_ held.
  SharedGrpcChannelPtr GetNextChannelPtrAndUpdateState(
      ChannelState& chan_state) {
    // Following statement is marked as Crash OK as this is an invariant of
    // code flow in this class.
    CHECK_EQ(chan_state.channels.size(), num_channels_per_target_);  // Crash OK
    chan_state.last_used =
        (chan_state.last_used + 1) % num_channels_per_target_;
    return chan_state.channels[chan_state.last_used];
  }

  const int num_channels_per_target_;
  // TODO(zhifengc): Eviction when the map becomes too big.
  mutex mu_;
  absl::flat_hash_map<string, ChannelState> channels_ TF_GUARDED_BY(mu_);
};
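The caching and round-robin logic of GenericCachingChannelCache can be reproduced with standard-library types only. In this sketch a channel is a std::shared_ptr<std::string> rather than a gRPC channel. The creation step is an injected std::function instead of the virtual FindChannelOnce. RoundRobinChannelCache is a hypothetical name.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Stand-in for std::shared_ptr<::grpc::Channel>.
using ChannelPtr = std::shared_ptr<std::string>;

class RoundRobinChannelCache {
 public:
  RoundRobinChannelCache(int num_channels_per_target,
                         std::function<ChannelPtr(const std::string&)> create)
      : n_(num_channels_per_target > 0 ? num_channels_per_target : 1),
        create_(std::move(create)) {}

  ChannelPtr FindWorkerChannel(const std::string& target) {
    {
      std::lock_guard<std::mutex> l(mu_);
      auto it = channels_.find(target);
      if (it != channels_.end()) return Next(it->second);
    }
    // Build all channels outside the lock, since creation may be slow.
    State fresh;
    for (int i = 0; i < n_; ++i) {
      ChannelPtr ch = create_(target);
      if (!ch) return nullptr;  // failures are not cached
      fresh.channels.push_back(ch);
    }
    fresh.last_used = n_ - 1;  // so the first Next() returns channels[0]
    std::lock_guard<std::mutex> l(mu_);
    // If another thread inserted first, emplace keeps the existing entry.
    auto it = channels_.emplace(target, std::move(fresh)).first;
    return Next(it->second);
  }

 private:
  struct State {
    std::vector<ChannelPtr> channels;
    int last_used = 0;
  };

  // Caller must hold mu_.
  ChannelPtr Next(State& s) {
    s.last_used = (s.last_used + 1) % n_;
    return s.channels[s.last_used];
  }

  const int n_;
  std::function<ChannelPtr(const std::string&)> create_;
  std::mutex mu_;
  std::unordered_map<std::string, State> channels_;
};
```

As in the original, if two threads race on a new target, only the first inserted entry survives and the other thread's freshly built channels are dropped.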

2.3 Concrete Subclasses

Two classes derive from CachingGrpcChannelCache:

2.3.1 Leaf Node

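SparseGrpcChannelCache is the leaf node. Each job in the cluster has one SparseGrpcChannelCache. The grpc::Channel set inside it is the set of channels for that job's tasks, with each task mapped to its own grpc::Channel (or to num_channels_per_target channels when that option is greater than 1).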

The main members of SparseGrpcChannelCache are:

  • const string job_id_: the job this instance serves.
  • const std::map<int, string> host_ports_: the host:port of each task in this job, keyed by task index.
  • const ChannelCreationFunction channel_func_: the function that creates a grpc::Channel.

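The main methods of SparseGrpcChannelCache are:

  • ListWorkers: returns the names of this job's tasks.
  • TranslateTask: maps a task name to its address (host:port). For example, the address of /job:ps/replica:0/task:1 might be ps1:1111. Only replica 0 is accepted. A different job, any other replica, or an unknown task index yields an empty string.
  • FindChannelOnce: creates the grpc::Channel for a task name. It calls TranslateTask to get the task's host:port, then passes that address to channel_func_ to build the grpc::Channel.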
class SparseGrpcChannelCache : public CachingGrpcChannelCache {
 public:
  SparseGrpcChannelCache(const string& job_id,
                         const std::map<int, string>& host_ports,
                         ChannelCreationFunction channel_func,
                         int num_channels_per_target)
      : CachingGrpcChannelCache(num_channels_per_target),
        job_id_(job_id),
        host_ports_(host_ports),
        channel_func_(std::move(channel_func)) {
  }
  ~SparseGrpcChannelCache() override {}

  void ListWorkers(std::vector<string>* workers) override {
    workers->reserve(workers->size() + host_ports_.size());
    for (const auto& id_host_port : host_ports_) {
      workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
    }
  }

  void ListWorkersInJob(const string& job_name,
                        std::vector<string>* workers) override {
    if (job_name == job_id_) {
      ListWorkers(workers);
    }
  }

  string TranslateTask(const string& target) override {
    DeviceNameUtils::ParsedName parsed;
    if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
      return "";
    }

    if (!parsed.has_job || parsed.job != job_id_) {
      return "";
    }
    if (!parsed.has_replica || parsed.replica != 0) {
      return "";
    }
    int32_t task = parsed.has_task ? parsed.task : -1;
    auto iter = host_ports_.find(task);
    if (iter == host_ports_.end()) {
      return "";
    }
    return iter->second;
  }

 protected:
  SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
    const string host_port = TranslateTask(target);
    if (host_port.empty()) {
      return nullptr;
    }
    auto chan_ptr = channel_func_(host_port);
    return chan_ptr;
  }

 private:

  const string job_id_;
  const std::map<int, string> host_ports_;
  const ChannelCreationFunction channel_func_;
  TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
};
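TranslateTask accepts a target only if the target names this job, replica 0, and a known task index. The sketch below applies the same rules with a deliberately small parser. ParsedTask, ToInt, ParseTaskName, and this free-function TranslateTask are hypothetical. The parser handles only the /job:<j>/replica:<r>/task:<t> form, whereas DeviceNameUtils::ParseFullName accepts many more forms.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

struct ParsedTask {
  std::string job;
  bool has_replica = false;
  int replica = 0;
  bool has_task = false;
  int task = 0;
};

// Accepts 1 to 9 decimal digits (9 keeps std::stoi from overflowing).
bool ToInt(const std::string& s, int* v) {
  if (s.empty() || s.size() > 9) return false;
  for (char c : s) {
    if (c < '0' || c > '9') return false;
  }
  *v = std::stoi(s);
  return true;
}

// Parses "/key:value" segments where key is job, replica, or task.
bool ParseTaskName(const std::string& name, ParsedTask* out) {
  ParsedTask p;
  std::size_t pos = 0;
  while (pos < name.size()) {
    if (name[pos] != '/') return false;
    std::size_t next = name.find('/', pos + 1);
    std::size_t len = (next == std::string::npos) ? std::string::npos
                                                  : next - pos - 1;
    std::string part = name.substr(pos + 1, len);
    std::size_t colon = part.find(':');
    if (colon == std::string::npos) return false;
    std::string key = part.substr(0, colon);
    std::string value = part.substr(colon + 1);
    if (key == "job" && !value.empty()) {
      p.job = value;
    } else if (key == "replica" && ToInt(value, &p.replica)) {
      p.has_replica = true;
    } else if (key == "task" && ToInt(value, &p.task)) {
      p.has_task = true;
    } else {
      return false;  // unknown key or malformed value
    }
    pos = (next == std::string::npos) ? name.size() : next;
  }
  if (p.job.empty()) return false;
  *out = p;
  return true;
}

// Same acceptance rules as SparseGrpcChannelCache::TranslateTask.
std::string TranslateTask(const std::string& job_id,
                          const std::map<int, std::string>& host_ports,
                          const std::string& target) {
  ParsedTask parsed;
  if (!ParseTaskName(target, &parsed)) return "";
  if (parsed.job != job_id) return "";
  if (!parsed.has_replica || parsed.replica != 0) return "";
  int task = parsed.has_task ? parsed.task : -1;
  auto it = host_ports.find(task);
  if (it == host_ports.end()) return "";
  return it->second;
}
```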

2.3.2 Non-Leaf Node

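To speed up lookups across the SparseGrpcChannelCache instances, and to manage all Worker nodes of the cluster as one unit, TF combines the cluster's SparseGrpcChannelCache instances into a MultiGrpcChannelCache. MultiGrpcChannelCache records in target_caches_ which SparseGrpcChannelCache answered for each target, so later lookups for that target go straight to it.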

// A ChannelCache that is the union of multiple ChannelCaches.
// Takes ownership of the caches passed to the constructor.
class MultiGrpcChannelCache : public CachingGrpcChannelCache {
 public:
  explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches,
                                 int num_channels_per_target)
      : CachingGrpcChannelCache(num_channels_per_target), caches_(caches) {}

  ~MultiGrpcChannelCache() override {
    for (GrpcChannelCache* cache : caches_) {
      delete cache;
    }
  }

  void ListWorkers(std::vector<string>* workers) override {
    for (GrpcChannelCache* cache : caches_) {
      cache->ListWorkers(workers);
    }
  }

  void ListWorkersInJob(const string& job_name,
                        std::vector<string>* workers) override {
    for (GrpcChannelCache* cache : caches_) {
      cache->ListWorkersInJob(job_name, workers);
    }
  }

  string TranslateTask(const string& target) override {
    mutex_lock l(mu_);  // could use reader lock
    GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
    if (cache == nullptr) {
      for (GrpcChannelCache* c : caches_) {
        string r = c->TranslateTask(target);
        if (!r.empty()) {
          target_caches_.insert({target, c});
          cache = c;
          break;
        }
      }
    }
    CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
                 << target;
    return cache->TranslateTask(target);
  }

 protected:
  SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
    for (GrpcChannelCache* cache : caches_) {
      SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
      if (ch) {
        mutex_lock l(mu_);
        target_caches_.insert({target, cache});
        return ch;
      }
    }
    return nullptr;
  }

 private:
  // List of channels used by this MultiGrpcChannelCache.
  const std::vector<GrpcChannelCache*> caches_;

  mutex mu_;
  // Cache of channels keyed by the target they are handling.
  // The same GrpcChannelCache can appear multiple times in the cache.
  std::unordered_map<string, GrpcChannelCache*> target_caches_
      TF_GUARDED_BY(mu_);
};
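MultiGrpcChannelCache is a composite that remembers which child answered for each target. The sketch below shows that routing with a cut-down interface that has only TranslateTask. TaskTranslator, TableTranslator, and MultiTranslator are hypothetical. This sketch returns the child's answer directly instead of translating a second time, and it returns an empty string for unknown targets.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

class TaskTranslator {
 public:
  virtual ~TaskTranslator() = default;
  virtual std::string TranslateTask(const std::string& target) = 0;
};

// Toy leaf: a fixed target -> host:port table that counts lookups.
class TableTranslator : public TaskTranslator {
 public:
  explicit TableTranslator(std::map<std::string, std::string> table)
      : table_(std::move(table)) {}
  std::string TranslateTask(const std::string& target) override {
    ++lookups;
    auto it = table_.find(target);
    return it == table_.end() ? "" : it->second;
  }
  int lookups = 0;

 private:
  std::map<std::string, std::string> table_;
};

// Composite: asks children in order and memoizes the one that answered.
class MultiTranslator : public TaskTranslator {
 public:
  explicit MultiTranslator(std::vector<std::unique_ptr<TaskTranslator>> kids)
      : children_(std::move(kids)) {}
  std::string TranslateTask(const std::string& target) override {
    std::lock_guard<std::mutex> l(mu_);
    auto it = memo_.find(target);
    if (it != memo_.end()) return it->second->TranslateTask(target);
    for (auto& child : children_) {
      std::string r = child->TranslateTask(target);
      if (!r.empty()) {
        memo_.emplace(target, child.get());
        return r;
      }
    }
    return "";  // unknown target
  }

 private:
  std::vector<std::unique_ptr<TaskTranslator>> children_;
  std::mutex mu_;
  std::unordered_map<std::string, TaskTranslator*> memo_;
};
```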

The structure so far:

Figure 2: Logical relationships among the caches

2.4 Creating the GrpcChannelCache

When the GrpcChannelCache was created earlier, GetChannelCreationFunction was passed in without explanation. We go through it now.

  // Obtain the GrpcChannelCache
  std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
      channel_spec, GetChannelCreationFunction(), *options.rpc_options));

2.4.1 Goal and Usage

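First, the goal, which is also how the function is used. FindChannelOnce takes a target (a task name such as /job:ps/replica:0/task:1), translates it into a host:port string, and passes that string to channel_func_ to obtain a SharedGrpcChannelPtr. SharedGrpcChannelPtr is std::shared_ptr<grpc::Channel>.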

SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
  const string host_port = TranslateTask(target);
  if (host_port.empty()) {
    return nullptr;
  }
  auto chan_ptr = channel_func_(host_port);
  VLOG(5) << "Channel created for: job: " << job_id_
          << " host_port: " << host_port << " target : " << target
          << " Ptr: " << chan_ptr.get();
  return chan_ptr;
}

2.4.2 NewHostPortGrpcChannel

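Start with NewHostPortGrpcChannel, an existing TF API. It validates the host:port string, then calls ::grpc::CreateCustomChannel (a gRPC API) with a "dns:///" target and insecure credentials to obtain a grpc::Channel. It stores the result in *channel_pointer (a SharedGrpcChannelPtr). That result is exactly what we want. However, the calling convention does not match what the cache expects: it returns a Status and takes RPCOptions plus an out-parameter. So it needs an adapter.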

Status NewHostPortGrpcChannel(const string& target,
                              const RPCOptions* rpc_options,
                              SharedGrpcChannelPtr* channel_pointer) {
  // Minimally ensure that the target is valid
  TF_RETURN_IF_ERROR(ValidateHostPortPair(target));

  ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
  *channel_pointer = ::grpc::CreateCustomChannel(
      "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
  return Status::OK();
}

2.4.3 ConvertToChannelCreationFunction

ConvertToChannelCreationFunction wraps the given new_channel_func_ptr into a function that takes only const string& target and returns a SharedGrpcChannelPtr. It passes rpc_options as nullptr and turns a non-OK Status into nullptr.

ChannelCreationFunction ConvertToChannelCreationFunction(
    const std::function<Status(string, const RPCOptions*,
                               SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
  return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
    SharedGrpcChannelPtr channel_ptr;
    if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
            .ok()) {
      return channel_ptr;
    } else {
      return nullptr;
    }
  };
}
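The adapter in ConvertToChannelCreationFunction can be exercised without gRPC. Below, Status, RPCOptions, and ChannelPtr are minimal stand-ins. ToyNewHostPortChannel is a hypothetical creator. It mimics only the "dns:///" + target string that NewHostPortGrpcChannel builds, and no real channel is opened.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>

struct Status {  // hypothetical minimal status type
  bool ok_ = true;
  bool ok() const { return ok_; }
};
struct RPCOptions {};                              // placeholder
using ChannelPtr = std::shared_ptr<std::string>;   // stand-in channel handle
using ChannelCreationFunction = std::function<ChannelPtr(const std::string&)>;

// Adapts a Status-returning, out-parameter creator into a one-argument
// function that reports failure as nullptr.
ChannelCreationFunction ConvertToChannelCreationFunction(
    const std::function<Status(std::string, const RPCOptions*, ChannelPtr*)>&
        new_channel_func) {
  return [new_channel_func](const std::string& target) -> ChannelPtr {
    ChannelPtr ch;
    if (new_channel_func(target, /*rpc_options=*/nullptr, &ch).ok()) return ch;
    return nullptr;
  };
}

// Toy creator: rejects targets without ':' and otherwise returns the URI
// string that a real channel would be created for.
Status ToyNewHostPortChannel(std::string target, const RPCOptions*,
                             ChannelPtr* out) {
  if (target.find(':') == std::string::npos) return Status{false};
  *out = std::make_shared<std::string>("dns:///" + target);
  return Status{};
}
```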

2.4.4 GetChannelCreationFunction

GetChannelCreationFunction passes NewHostPortGrpcChannel to ConvertToChannelCreationFunction and returns the resulting function. That one-argument form is what the WorkerCache factory can use.

ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
  // We can do this because SparseGrpcChannelCache is robust to nullptr being
  // returned by the channel creation function
  return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
}

2.4.5 Analysis of the Call

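Back to the call site. channel_func_ is the function returned by GetChannelCreationFunction(). Calling it directly with a host:port string therefore yields a grpc::Channel, or nullptr on failure.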

SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
  const string host_port = TranslateTask(target);
  if (host_port.empty()) {
    return nullptr;
  }
  auto chan_ptr = channel_func_(host_port);
  return chan_ptr;
}

We can now extend the earlier flow with one more step in the middle: passing in a target yields a grpc::Channel.

Figure 3: How the conversion works

3. Where the Cache Sits in the System

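We have covered how the cache is initialized and used, but not where it sits in the system as a whole. We look at that now. The GrpcChannelCache inside GrpcWorkerCache points to the system's gRPC channel cache, which supplies cached gRPC channels. local_worker holds the local Worker.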

Figure 4: Position of the cache

When GrpcWorkerCache::GetOrCreateWorker is called with a local target, it returns local_worker, the local GrpcWorker set up earlier. For any other target, it creates a remote GrpcRemoteWorker on top of that Worker's RPC channel.

Figure 5: Creating a worker

WorkerCacheInterface (in practice, GrpcWorkerCache) appears throughout Master, Worker, MasterSession, and WorkerSession. Many classes keep a member pointer to a WorkerCacheInterface, so it is used very widely.

4. Finding the Device Set

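To create WorkerSessions, the MasterSession needs to know the set of devices on every remote Worker. So before creating the MasterSession, the Master visits all Workers and collects their device information. This step relies on GrpcWorkerCache, so we cover it here. The basic logic is:

  • Get the names of all Workers in the cluster via GrpcWorkerCache::ListWorkers.
  • For each worker_name, call GetOrCreateWorker to obtain the WorkerInterface object from worker_cache, creating it if it does not exist.
  • Build a GetStatusRequest and send it to that Worker via GetStatusAsync.
  • When the Worker returns its GetStatusResponse, the callback cb invokes the function object bound to WhenFound to hand back the Worker's device information. The received device information must first be post-processed to add worker_name.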

Figure 6: Fetching devices

4.1 DeviceFinder

4.1.1 Definition

DeviceFinder is a helper class that implements the algorithm for finding devices on remote Workers. Its member variables are:

class DeviceFinder {
  ~DeviceFinder() {
    for (Device* dev : found_) delete dev;
  }

  typedef DeviceFinder ME;
  const MasterEnv* env_;
  WorkerCacheInterface* worker_cache_;
  std::vector<DeviceNameUtils::ParsedName> filters_;

  mutex mu_;
  int num_pending_ TF_GUARDED_BY(mu_);
  condition_variable pending_zero_;
  std::vector<Device*> found_ TF_GUARDED_BY(mu_);
  // List of targets to be contacted by this DeviceFinder. The
  // respective `bool` in `seen_targets_` indicates whether we have
  // heard from this target or not.
  std::vector<string> targets_;
  std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_);
  Status status_;

  TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
};

4.1.2 Initialization

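The main logic parses the device filters first. It then obtains the list of Worker names to contact, using GrpcWorkerCache::ListWorkers, or ListWorkersInJob when every filter names a job. When filters are present, the local Worker is always kept even if it matches none of them.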

explicit DeviceFinder(
    const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
    WorkerCacheInterface* worker_cache)
    : env_(env), worker_cache_(worker_cache) {
  CHECK(worker_cache) << "Worker cache was null!";
  auto process_filter = [this](const string& filter) {
    DeviceNameUtils::ParsedName parsed;
    if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
      filters_.push_back(parsed);
    } else {
      LOG(FATAL) << "Skipping invalid filter: " << filter;
    }
  };
  for (const string& filter : device_filters) {
    process_filter(filter);
  }
  // Enumerates all known workers' target. A target name is a
  // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
  if (filters_.empty()) {
    // If no filters were specified, we list all known workers in
    // `worker_cache`.
    std::vector<string> workers;
    worker_cache->ListWorkers(&workers);
    std::swap(workers, targets_);
  } else {
    // When applying filters, we must include the local worker, even if it
    // does not match any of the filters.
    CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
    const string& local_device_name = env_->local_devices[0]->name();
    DeviceNameUtils::ParsedName local_parsed_name;
    CHECK(DeviceNameUtils::ParseFullName(local_device_name,
                                         &local_parsed_name));
    bool all_filters_have_job = true;
    std::unordered_set<string> filter_job_names({local_parsed_name.job});
    for (const DeviceNameUtils::ParsedName& filter : filters_) {
      all_filters_have_job = all_filters_have_job && filter.has_job;
      if (filter.has_job) {
        filter_job_names.insert(filter.job);
      }
    }

    std::vector<string> workers;
    if (all_filters_have_job) {
      // If all of the device filters have a job specified, then we only need
      // to list the workers in the jobs named in the filter, because a worker
      // in any other job would not match any filter.
      for (const string& job_name : filter_job_names) {
        VLOG(2) << "Selectively listing workers in job: " << job_name;
        std::vector<string> workers_in_job;
        worker_cache->ListWorkersInJob(job_name, &workers_in_job);
        workers.insert(workers.end(), workers_in_job.begin(),
                       workers_in_job.end());
      }
    } else {
      // If any of the device filters does not have a job specified, then we
      // must list the workers from all jobs.
      VLOG(2) << "Listing workers in all jobs because some device "
              << "filter has no job specified. Filters were:";
      if (device_filters.empty()) {
        VLOG(2) << "- <NO FILTERS>";
      } else {
        for (const string& filter : device_filters) {
          VLOG(2) << "- " << filter;
        }
      }
      worker_cache->ListWorkers(&workers);
    }
    for (const string& name : workers) {
      if (MatchFilters(name) ||
          DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
        targets_.push_back(name);
      }
    }
  }
  seen_targets_.assign(targets_.size(), false);
}
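The job-selection rule in this constructor can be isolated as a small function. With no filters, list every worker. If every filter names a job, list only those jobs plus the local job. If any filter lacks a job, fall back to listing all workers. JobsToList is a hypothetical C++17 helper that returns std::nullopt to mean "all jobs". The later per-worker MatchFilters / IsSameAddressSpace check is not modeled.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <string>
#include <vector>

// filter_jobs[i] is the job named by filter i, or nullopt if that filter
// names no job (e.g. "/device:GPU:0"). nullopt result means "list all jobs".
std::optional<std::set<std::string>> JobsToList(
    const std::vector<std::optional<std::string>>& filter_jobs,
    const std::string& local_job) {
  if (filter_jobs.empty()) return std::nullopt;  // no filters: everything
  std::set<std::string> jobs = {local_job};      // local job is always kept
  for (const auto& job : filter_jobs) {
    if (!job) return std::nullopt;  // a job-less filter may match any job
    jobs.insert(*job);
  }
  return jobs;
}
```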

4.1.3 GetRemoteDevices

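GetRemoteDevices fetches the remote devices as follows:

  • finder.Start() broadcasts a GetStatusRequest to all target Workers in the cluster.
  • finder.Wait() waits until every Worker's GetStatusResponse has been received.
  • finder.GetRemoteDevices collects the results and returns them to the caller.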
static Status GetRemoteDevices(
    const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
    WorkerCacheInterface* worker_cache,
    std::vector<std::unique_ptr<Device>>* out_remote) {
  DeviceFinder finder(device_filters, env, worker_cache);
  finder.Start();
  TF_RETURN_IF_ERROR(finder.Wait());
  finder.GetRemoteDevices(env->local_devices, out_remote);
  return Status::OK();
}
4.1.3.1 Start

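Start sets the counter num_pending_ to the number of Workers and notifies waiters immediately if that number is zero. It then iterates over the Workers and calls NewRemoteDevices for each one, with WhenFound bound as the completion callback.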

void Start() {
  {
    mutex_lock l(mu_);
    num_pending_ = targets_.size();
    if (num_pending_ == 0) {
      pending_zero_.notify_all();
    }
  }
  // Talk to all workers to get the list of available devices.
  using std::placeholders::_1;
  using std::placeholders::_2;
  for (size_t i = 0; i < targets_.size(); ++i) {
    // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
    // never be called.
    NewRemoteDevices(env_->env, worker_cache_, targets_[i],
                     std::bind(&ME::WhenFound, this, i, _1, _2));
  }
}
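Start, WhenFound, and Wait together form a countdown. A counter starts at the number of targets, each reply decrements it under the mutex, and the waiter sleeps on a condition variable until the counter reaches zero. PendingCounter below is a hypothetical, simplified version of that pattern. It has no timeout and does not record which targets replied (DeviceFinder keeps seen_targets_ for that).

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

class PendingCounter {
 public:
  explicit PendingCounter(std::size_t num_targets) : pending_(num_targets) {}

  // Called once per target, from whichever thread delivers the reply.
  void WhenFound(std::vector<std::string> devices) {
    std::lock_guard<std::mutex> l(mu_);
    for (auto& d : devices) found_.push_back(std::move(d));
    if (pending_ > 0 && --pending_ == 0) pending_zero_.notify_all();
  }

  // Blocks until every target has replied; returns immediately for zero.
  std::vector<std::string> Wait() {
    std::unique_lock<std::mutex> l(mu_);
    pending_zero_.wait(l, [this] { return pending_ == 0; });
    return found_;
  }

 private:
  std::mutex mu_;
  std::condition_variable pending_zero_;
  std::size_t pending_;
  std::vector<std::string> found_;
};
```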

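NewRemoteDevices works as follows:

  • Call GetOrCreateWorker with worker_name to obtain the WorkerInterface object from worker_cache, creating it if it does not exist. If none can be obtained, report NotFound through done.
  • Build a GetStatusRequest and send it to that Worker via GetStatusAsync.
  • When the Worker returns its GetStatusResponse, the callback cb invokes done (bound to WhenFound) with the Worker's device information. The received device information must first be processed to add worker_name. A cleanup object releases the worker, calls done, and deletes the call state on every exit path.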
void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
                      const string& worker_name, NewRemoteDevicesDone done) {
  WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
  if (wi == nullptr) {
    std::vector<Device*> empty;
    done(errors::NotFound("Device ", worker_name, " is not found."), &empty);
    return;
  }
  struct Call {
    GetStatusRequest req;    // request sent to the worker
    GetStatusResponse resp;  // response filled in by the worker
  };
  Call* call = new Call;
  // Callback invoked when GetStatusAsync completes.
  auto cb = [env, worker_cache, worker_name, done, wi,
             call](const Status& status) {
    Status s = status;
    std::vector<Device*> remote_devices;
    auto cleanup = gtl::MakeCleanup(
        [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
          worker_cache->ReleaseWorker(worker_name, wi);
          done(s, &remote_devices);
          delete call;
        });
    if (s.ok()) {
      DeviceNameUtils::ParsedName worker_name_parsed;
      if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
          !worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
          !worker_name_parsed.has_task) {
        s = errors::InvalidArgument("Could not parse worker name: ",
                                    worker_name);
        return;
      }
      remote_devices.reserve(call->resp.device_attributes_size());
      for (const DeviceAttributes& da : call->resp.device_attributes()) {
        DeviceNameUtils::ParsedName device_name_parsed;
        CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
            << "Device attribute name '" << da.name() << "' could not be "
            << "parsed. Device Attribute: " << da.DebugString();
        // Preserve the exact name, if possible.
        if (device_name_parsed.job == worker_name_parsed.job &&
            device_name_parsed.replica == worker_name_parsed.replica &&
            device_name_parsed.task == worker_name_parsed.task) {
          auto d = new RemoteDevice(env, da);
          remote_devices.push_back(d);
        } else {
          DeviceAttributes da_rewritten = da;
          da_rewritten.set_name(DeviceNameUtils::FullName(
              worker_name_parsed.job, worker_name_parsed.replica,
              worker_name_parsed.task, device_name_parsed.type,
              device_name_parsed.id));
          auto d = new RemoteDevice(env, da_rewritten);

          // Experimental: Skipping over adding any TPU-type devices that aren't
          // on the job called "worker" (but still adds the CPUs of other jobs).
          if (getenv("TPU_NO_POPULATE_DEVICE_LIST_FROM_CLUSTER_SPEC") !=
              nullptr) {
            if (worker_name_parsed.job == "worker" ||
                device_name_parsed.type.find("TPU") == std::string::npos) {
              remote_devices.push_back(d);
            }
          } else {
            remote_devices.push_back(d);
          }
        }
      }
    }
  };
  wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
                     /*fail_fast=*/false, cb);
}
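The core of the callback is the renaming rule. A device whose job/replica/task already match the worker keeps its exact name. Otherwise the name is rebuilt from the worker's job/replica/task plus the device's own type and id. Below is a minimal sketch of that rule. It assumes fully specified names of the form `/job:J/replica:R/task:T/device:TYPE:ID`. `Parsed`, `Parse`, and `Rewrite` are hypothetical helpers, far simpler than DeviceNameUtils, which also accepts partial and legacy forms. The TPU filtering branch is omitted.

```cpp
#include <cassert>
#include <optional>
#include <regex>
#include <string>

// Hypothetical, simplified parsed name (device part optional).
struct Parsed {
  std::string job, type;
  int replica = 0, task = 0, id = 0;
};

// Parses "/job:J/replica:R/task:T" with an optional "/device:TYPE:ID".
std::optional<Parsed> Parse(const std::string& name) {
  static const std::regex re(
      R"(^/job:([^/]+)/replica:(\d+)/task:(\d+)(?:/device:([^:/]+):(\d+))?$)");
  std::smatch m;
  if (!std::regex_match(name, m, re)) return std::nullopt;
  Parsed p;
  p.job = m[1].str();
  p.replica = std::stoi(m[2].str());
  p.task = std::stoi(m[3].str());
  if (m[4].matched) {
    p.type = m[4].str();
    p.id = std::stoi(m[5].str());
  }
  return p;
}

// Keeps the device name if it already lives on `worker`; otherwise rebuilds
// it from the worker's job/replica/task and the device's type/id.
// Returns "" on unparsable input (a hypothetical error signal; TensorFlow
// returns InvalidArgument for the worker name and CHECK-fails on the device).
std::string Rewrite(const std::string& worker, const std::string& device) {
  std::optional<Parsed> w = Parse(worker), d = Parse(device);
  if (!w || !d || d->type.empty()) return "";
  if (d->job == w->job && d->replica == w->replica && d->task == w->task) {
    return device;
  }
  return "/job:" + w->job + "/replica:" + std::to_string(w->replica) +
         "/task:" + std::to_string(w->task) + "/device:" + d->type + ":" +
         std::to_string(d->id);
}
```

For example, a device reported as `/job:localhost/replica:0/task:0/device:CPU:0` by the worker addressed as `/job:worker/replica:0/task:1` becomes `/job:worker/replica:0/task:1/device:CPU:0`. The master can then place ops on it under the name it uses for that worker.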
4.1.3.2 Wait

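In Wait, as long as the counter is nonzero, the thread keeps calling pending_zero_.wait_for and blocks for at most kLoggingPeriodMs (10 seconds) each time. Whenever it wakes up with the counter still nonzero, it logs every Worker that has not responded yet.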

Status Wait() {
  mutex_lock l(mu_);
  // TODO(mrry): Propagate a timeout here, since `num_pending_` may
  // never become zero.
  while (num_pending_ != 0) {
    pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
    if (num_pending_ != 0) {
      for (size_t i = 0; i < targets_.size(); ++i) {
        if (!seen_targets_[i]) {
          LOG(INFO)
              << "CreateSession still waiting for response from worker: "
              << targets_[i];
        }
      }
    }
  }
  return status_;
}
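Wait is a bounded wait inside a loop. Each `wait_for` returns either because WhenFound called `notify_all` or because the period elapsed. In the latter case it reports the targets whose seen_targets_ entry is still false. Below is a minimal sketch of that loop with standard-library types. It assumes a caller-supplied period (e.g. 10 ms instead of 10 s), and `PeriodicWaiter` and `MarkSeen` are hypothetical names. It returns the log lines instead of printing them so they can be inspected.

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

class PeriodicWaiter {
 public:
  explicit PeriodicWaiter(std::vector<std::string> targets)
      : targets_(targets),
        seen_(targets.size(), false),
        num_pending_(targets.size()) {}

  // Plays the role of WhenFound: mark the target, decrement, maybe notify.
  void MarkSeen(std::size_t i) {
    std::lock_guard<std::mutex> l(mu_);
    seen_[i] = true;
    if (--num_pending_ == 0) pending_zero_.notify_all();
  }

  // Same shape as DeviceFinder::Wait: wake at least once per period and
  // record which targets have not answered yet.
  std::vector<std::string> Wait(std::chrono::milliseconds period) {
    std::vector<std::string> log;
    std::unique_lock<std::mutex> l(mu_);
    while (num_pending_ != 0) {
      pending_zero_.wait_for(l, period);  // timeout, notify, or spurious wake
      if (num_pending_ != 0) {
        for (std::size_t i = 0; i < targets_.size(); ++i) {
          if (!seen_[i]) log.push_back("still waiting for " + targets_[i]);
        }
      }
    }
    return log;
  }

 private:
  std::vector<std::string> targets_;
  std::vector<bool> seen_;
  std::size_t num_pending_;
  std::mutex mu_;
  std::condition_variable pending_zero_;
};
```

Re-checking `num_pending_` in the `while` condition also makes the loop robust to spurious wakeups. As the TODO notes, nothing bounds the total wait, so a worker that never answers keeps this loop logging forever.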
4.1.3.3 Callback

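The callback bound in Start is shown below. It runs when a Worker's GetStatusResponse arrives or the request fails. WhenFound marks the target as seen, records the error or appends the devices to found_, and decrements the counter. When the counter reaches 0, it calls pending_zero_.notify_all(), which wakes the pending_zero_.wait_for call in Wait. GetRemoteDevices then uses finder.GetRemoteDevices to fetch the results and return them to the caller.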

void WhenFound(int target_index, const Status& s,
               std::vector<Device*>* devices) {
  mutex_lock l(mu_);
  seen_targets_[target_index] = true;
  if (!s.ok()) {
    LOG(ERROR) << "CreateSession failed because worker "
               << targets_[target_index] << " returned error: " << s;
    status_.Update(s);
  } else {
    found_.insert(found_.end(), devices->begin(), devices->end());
    devices->clear();
  }
  --num_pending_;
  if (num_pending_ == 0) {
    pending_zero_.notify_all();
  }
}
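WhenFound does not short-circuit on failure. Every reply still decrements the counter, and `status_.Update(s)` keeps the first error it sees, because tensorflow::Status::Update only overwrites a status that is currently OK. Wait therefore returns that first error once all workers have answered, while devices from healthy workers still accumulate in found_. Below is a minimal sketch of that aggregation. `MiniStatus` and `Aggregator` are hypothetical stand-ins, and the mutex that the real method holds is omitted.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for tensorflow::Status: empty message means OK.
struct MiniStatus {
  std::string msg;
  bool ok() const { return msg.empty(); }
  // Only an OK status is overwritten, so the first error wins.
  void Update(const MiniStatus& s) {
    if (ok()) *this = s;
  }
};

// Hypothetical aggregator; the real WhenFound does this under mu_.
struct Aggregator {
  MiniStatus status;
  std::vector<std::string> found;
  int num_pending;

  void WhenFound(const MiniStatus& s, std::vector<std::string>* devices) {
    if (!s.ok()) {
      status.Update(s);
    } else {
      found.insert(found.end(), devices->begin(), devices->end());
      devices->clear();  // ownership of the entries moves to `found`
    }
    --num_pending;  // counted for failures and successes alike
  }
};
```

Counting failures too matters here. If only successful replies decremented the counter, a single failing worker would leave Wait blocked forever instead of surfacing its error.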

4.2 Worker Interaction

NewRemoteDevices sends the GetStatusRequest to the located Worker by calling GetStatusAsync:

WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
                   /*fail_fast=*/false, cb);
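The first line of this snippet, `worker_cache->GetOrCreateWorker(worker_name)`, is the WorkerCache lookup this article is about. Below is a minimal sketch of the get-or-create contract as described in 4.1.3: return the existing object for a name, create it on first use, and return nullptr for an unknown target. `FakeWorker`, `WorkerCacheSketch`, and the `known` set are hypothetical. The real GrpcWorkerCache may cache at the gRPC channel level rather than holding one worker object per name, so this illustrates only the lookup contract.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <utility>

// Hypothetical worker handle.
struct FakeWorker {
  explicit FakeWorker(std::string t) : target(std::move(t)) {}
  std::string target;
};

// Hypothetical thread-safe get-or-create cache keyed by worker name.
class WorkerCacheSketch {
 public:
  explicit WorkerCacheSketch(std::set<std::string> known)
      : known_(std::move(known)) {}

  // Returns the cached worker for `target`, creating it on first use.
  // Unknown targets yield nullptr (NewRemoteDevices maps that to NotFound).
  FakeWorker* GetOrCreateWorker(const std::string& target) {
    std::lock_guard<std::mutex> l(mu_);
    if (known_.count(target) == 0) return nullptr;
    std::unique_ptr<FakeWorker>& slot = workers_[target];
    if (!slot) slot = std::make_unique<FakeWorker>(target);
    return slot.get();
  }

  std::size_t size() {
    std::lock_guard<std::mutex> l(mu_);
    return workers_.size();
  }

 private:
  std::set<std::string> known_;  // stands in for the cluster definition
  std::mutex mu_;
  std::map<std::string, std::unique_ptr<FakeWorker>> workers_;
};
```

The mutex matters because, as in Start, many lookups can be issued in quick succession and their callbacks (which call ReleaseWorker) may run on other threads.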

4.2.1 GrpcRemoteWorker

wi is the WorkerInterface that was found. Typically it is a GrpcRemoteWorker, a gRPC client that calls the corresponding interface of the remote WorkerService through a stub.

void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                    GetStatusResponse* response, bool fail_fast,
                    StatusCallback done) override {
  IssueRequest(request, response, getstatus_, std::move(done), call_opts,
               fail_fast);
}

4.2.2 GrpcWorkerService

On the remote Worker, incoming messages are received by GrpcWorkerService. A GetStatusRequest is dispatched to GetStatusHandler, which is generated by the HANDLE_CALL macro.

#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
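HANDLE_CALL relies on the preprocessor's `##` token pasting to stamp out one handler per RPC name. For example, GetStatus becomes GetStatusHandler, which takes a WorkerCall<GetStatusRequest, GetStatusResponse>*. The macro also uses `#method` stringizing for the log message. Below is a minimal sketch of the same technique. The request/response structs, `BackendSketch`, and `ServiceSketch` are hypothetical, and the handler runs synchronously instead of scheduling a closure on a thread pool.

```cpp
#include <cassert>
#include <string>

// Hypothetical message types, one pair per RPC.
struct GetStatusRequest {};
struct GetStatusResponse { int num_devices = 0; };
struct CleanupAllRequest {};
struct CleanupAllResponse { bool done = false; };

// Hypothetical backend with one method per RPC, like Worker.
struct BackendSketch {
  bool GetStatus(const GetStatusRequest*, GetStatusResponse* r) {
    r->num_devices = 2;
    return true;
  }
  bool CleanupAll(const CleanupAllRequest*, CleanupAllResponse* r) {
    r->done = true;
    return true;
  }
};

// method##Request pastes tokens into a type name; #method yields a string.
#define HANDLE_CALL_SKETCH(method)                            \
  std::string method##Handler(const method##Request& req,     \
                              method##Response* resp) {       \
    bool ok = backend_.method(&req, resp);                    \
    return std::string(#method) + (ok ? ": ok" : ": failed"); \
  }

class ServiceSketch {
 public:
  HANDLE_CALL_SKETCH(GetStatus)   // defines GetStatusHandler
  HANDLE_CALL_SKETCH(CleanupAll)  // defines CleanupAllHandler

 private:
  BackendSketch backend_;
};
```

The real macro's second parameter chooses the executor. Per the code above, true schedules the closure via `env->SchedClosure` and false via `compute_pool->Schedule`, so `HANDLE_CALL(GetStatus, false)` runs GetStatus on the compute pool.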

4.2.3 Worker

Finally we reach the Worker class. It simply asks DeviceMgr for the local device attributes and returns them to the remote caller in the GetStatusResponse message.

void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) {
  const DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(Status::OK());
}
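Worker::GetStatusAsync has the same signature as the GrpcRemoteWorker method shown earlier. Caller code such as NewRemoteDevices therefore sees only the interface and never needs to know whether the reply comes over gRPC or, as here, synchronously before the call returns. Below is a minimal sketch of that polymorphism. `WorkerInterfaceSketch`, `LocalWorkerSketch`, `QueryDevices`, and the string "attributes" are hypothetical simplifications.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

struct StatusResponse {
  std::vector<std::string> device_attributes;
};
using StatusCallback = std::function<void(bool ok)>;

// Hypothetical, trimmed-down analogue of WorkerInterface.
class WorkerInterfaceSketch {
 public:
  virtual ~WorkerInterfaceSketch() = default;
  virtual void GetStatusAsync(StatusResponse* response,
                              StatusCallback done) = 0;
};

// Analogue of Worker::GetStatusAsync: copies the local device list into the
// response and calls `done` before returning.
class LocalWorkerSketch : public WorkerInterfaceSketch {
 public:
  explicit LocalWorkerSketch(std::vector<std::string> devices)
      : devices_(std::move(devices)) {}

  void GetStatusAsync(StatusResponse* response, StatusCallback done) override {
    response->device_attributes.reserve(devices_.size());
    for (const auto& d : devices_) response->device_attributes.push_back(d);
    done(true);
  }

 private:
  std::vector<std::string> devices_;
};

// Caller that only sees the interface. Keeping `resp` on the stack is safe
// ONLY because this implementation completes synchronously.
std::vector<std::string> QueryDevices(WorkerInterfaceSketch* wi) {
  StatusResponse resp;
  std::vector<std::string> out;
  wi->GetStatusAsync(&resp, [&](bool ok) {
    if (ok) out = resp.device_attributes;
  });
  return out;
}
```

The comment in QueryDevices is the reason NewRemoteDevices allocates `Call` on the heap and deletes it in the cleanup. With a truly asynchronous GrpcRemoteWorker, stack-allocated request and response objects could be destroyed before the reply arrives.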

4.2.4 DeviceMgr

ListDeviceAttributes has two implementations that gather local device information. Implementation 1, StaticDeviceMgr:

void StaticDeviceMgr::ListDeviceAttributes(
    std::vector<DeviceAttributes>* devices) const {
  devices->reserve(devices_.size());
  for (const auto& dev : devices_) {
    devices->emplace_back(dev->attributes());
  }
}

Implementation 2, DynamicDeviceMgr:

void DynamicDeviceMgr::ListDeviceAttributes(
    std::vector<DeviceAttributes>* devices) const {
  tf_shared_lock l(devices_mu_);
  devices->reserve(dynamic_devices_.size());
  for (const auto& d : dynamic_devices_) {
    devices->emplace_back(d->attributes());
  }
}
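The visible difference between the two implementations is locking. StaticDeviceMgr reads devices_ without a lock, presumably safe because its device set does not change after construction. DynamicDeviceMgr takes a shared (reader) lock on devices_mu_ because it can add or remove devices at runtime. Below is a minimal sketch of that split with `std::shared_mutex`. The class names, `AddDevice`, and the use of strings for device attributes are hypothetical simplifications.

```cpp
#include <cassert>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <utility>
#include <vector>

// Device list fixed at construction: readers need no lock.
class StaticMgrSketch {
 public:
  explicit StaticMgrSketch(std::vector<std::string> devices)
      : devices_(std::move(devices)) {}

  void ListDeviceAttributes(std::vector<std::string>* out) const {
    out->reserve(devices_.size());
    for (const auto& d : devices_) out->emplace_back(d);
  }

 private:
  const std::vector<std::string> devices_;
};

// Device list can change: concurrent readers, exclusive writers.
class DynamicMgrSketch {
 public:
  void AddDevice(std::string d) {
    std::unique_lock<std::shared_mutex> l(mu_);  // exclusive (writer) lock
    devices_.push_back(std::move(d));
  }

  void ListDeviceAttributes(std::vector<std::string>* out) const {
    std::shared_lock<std::shared_mutex> l(mu_);  // shared, like tf_shared_lock
    out->reserve(devices_.size());
    for (const auto& d : devices_) out->emplace_back(d);
  }

 private:
  mutable std::shared_mutex mu_;
  std::vector<std::string> devices_;
};
```

A reader-writer lock fits here because GetStatus-style listing is frequent and read-only, while device changes are rare. Many ListDeviceAttributes calls can proceed in parallel and only block while a device is being added or removed.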

This completes our analysis of the cache and of device discovery. Next, we will look at how the actual business logic is handled.
