MXNet: Barrier

weixin_33807284發表於2019-02-15

1. KVStore裡的Barrier

在mxnet的分散式訓練裡,主要模式就是引數伺服器。每個worker或者agent就是一臺machine,server用於引數的更新。那麼,當我們期望在不同的worker之間進行同步的時候,就會需要到barrier這個方法。
當程式碼執行在worker的時候,我們可以通過呼叫kv._barrier()來進行同步。它的作用就是,會阻塞程式碼執行,直到每個worker都執行了kv._barrier()。然後接著執行。這樣就實現了同步。
那麼它是怎麼做到的呢?

通過原始碼,我們不難發現,python端的介面呼叫了c++端的方法:

void Barrier() override {
    ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup);
}

這個全域性的PostofficeBarrier方法的部分原始碼如下:

void Postoffice::Barrier(int customer_id, int node_group) {
  // 省略部分程式碼
  // 省略部分程式碼


  std::unique_lock<std::mutex> ulk(barrier_mu_);
  barrier_done_[0][customer_id] = false;
  Message req;
  req.meta.recver = kScheduler;
  req.meta.request = true;
  req.meta.control.cmd = Control::BARRIER;
  req.meta.app_id = 0;
  req.meta.customer_id = customer_id;
  req.meta.control.barrier_group = node_group;
  req.meta.timestamp = van_->GetTimestamp();
  CHECK_GT(van_->Send(req), 0);
  barrier_cond_.wait(ulk, [this, customer_id] {
      return barrier_done_[0][customer_id];
    });
}

可以看到該方法會首先對barrier_mu_上鎖,之後將對應的barrier_done_設定為false。然後將這次的barrier資訊傳送給scheduler。告訴scheduler需要進行一次barrier。然後就阻塞等待barrier_done_被設定為true,代表完成了barrier,也就是其他的worker也都進行了barrier。

那麼問題就變成了,每個worker都是怎麼直到其他worker也進行了barrier的?

首先我們要知道,在引數伺服器也就是PS中,每個程式都會建立kvstore。如果是worker,會在建構函式中執行如下程式碼:

if (IsWorkerNode()) {
      int new_customer_id = GetNewCustomerId();
      ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
      ps::StartAsync(new_customer_id, "mxnet\0");
      if (!ps::Postoffice::Get()->is_recovery()) {
        ps::Postoffice::Get()->Barrier(
          new_customer_id,
          ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
      }
    }

其中ps::StartAsync如下:

inline void StartAsync(int customer_id, const char* argv0 = nullptr) {
  Postoffice::Get()->Start(customer_id, argv0, false);
}

也就是說,worker在建立起ps_worker_後,開始執行postoffice,而postoffice的Start會進行一系列的操作,並呼叫van_->Start,接著vanStart會進行一系列的初始化後,開啟接受訊息的執行緒,也就是

receiver_thread_ = std::unique_ptr<std::thread>(
            new std::thread(&Van::Receiving, this));

receiving函式會使用ProcessBarrierCommand處理barrier訊號,該函式會++barrier_count_[group],也就是將對應group的barrier次數進行統計。當barrier_count_[group]等於這個group的個數的時候。它會傳送類似於ACK的返回資訊。

然後worker會呼叫Manage方法來處理該message。Manage發現是barrier的返回資訊,將barrier_done_設定為true,然後將等待的執行緒喚醒。也就是python端呼叫barrier後被阻塞的地方。

至此,就完成了一次worker之間的barrier。

相關文章