[原始碼解析] 機器學習引數伺服器ps-lite (1) ----- PostOffice

羅西的思考發表於2021-07-30

[原始碼解析] 機器學習引數伺服器ps-lite 之(1) ----- PostOffice

0x00 摘要

引數伺服器是機器學習訓練一種正規化,是為了解決分散式機器學習問題的一個程式設計框架,其主要包括伺服器端,客戶端和排程器,與其他正規化相比,引數伺服器把模型引數儲存和更新提升為主要元件,並且使用多種方法提高了處理能力。

本文是引數伺服器系列第一篇,介紹ps-lite的總體設計和基礎模組 Postoffice。
l

0x01 概要

1.1 引數伺服器是什麼

如果做一個類比,引數伺服器是機器學習領域的分散式記憶體資料庫,其作用是儲存模型和更新模型

我們來看看機器學習的幾個步驟,這些步驟不斷迴圈往復。

  1. 準備資料:訓練程式拿到權重 weight 和資料(data + label);
  2. 前向計算:訓練程式使用資料進行前向計算,得到 loss = f(weight, data & label);
  3. 反向求導:通過對 loss 反向求導,得到導數 grad = b(loss, weight, data & label);
  4. 更新權重:weight -= grad * lr;
  5. 來到1,進行下一次迭代;

如果使用引數伺服器訓練,我們可以把如上步驟對應如下:

  1. 引數下發:引數伺服器服務端 將 weight 發給 每個worker(或者worker自行拉取),worker就是引數伺服器Client端;
  2. 平行計算:每個worker 分別完成自己的計算(包括前向計算和反向求導);
  3. grad 收集:引數伺服器服務端 從每個 Worker 處得到 grad,完成歸併(或者worker自行推送);
  4. 更新權重:引數伺服器服務端 自行將 grad 應用到 weight 上;
  5. 來到1,進行下一次迭代;

具體如下圖:

          FP/BP    +--------+  Gather/Sum                                       FP/BP            +-------+    Gather/Sum
      +----------> | grad 1 +------+                                    +----------------------> |grad 2 +-----------+
      |            +--------+      |                                    |                        +-------+           |
+-----+----+                       v                     +--------------+-------------------+                        v
|          |                   +---+----------+  Update  |                                  |                 +------+-----+ Update   +------------------+
| weight 1 |                   | total grad 1 +--------->+weight 2 = weight 1 - total grad 1|                 |total grad 2+--------> |weight 2 = ...... |
|          |                   +---+----------+          |                                  |                 +------+-----+          +------------------+
+-----+----+                       ^                     +--------------+-------------------+                        ^
      |   FP/BP    +--------+      |                                    |       FP/BP            +-------+           |
      +----------> | grad 2 +------+                                    +----------------------> |grad 2 +-----------+
                   +--------+  Gather/Sum                                                        +-------+    Gather/Sum

手機如下:

因此我們可以推匯出引數伺服器之中各個模組的作用:

  • 伺服器端(Server ):存放機器學習模型引數,接收客戶端傳送的梯度,完成歸併,對本地模型引數進行更新。
  • 客戶端(Client 或者 Worker):
    • 從伺服器端獲取當前最新的引數;
    • 使用訓練資料和從最新引數計算得到預測值,根據損失函式來計算關於訓練引數的梯度;
    • 將梯度傳送給伺服器端;
  • 排程器(Scheduler):管理伺服器/客戶端節點,完成節點之間資料同步,節點新增/刪除等功能。

1.2 歷史溯源

引數伺服器屬於機器學習訓練的一個正規化,具體可以分為三代(目前各大公司應該有自己內部最新實現,可以算為第四代)。

在引數伺服器之前,大部分分散式機器學習演算法是通過定期同步來實現的,比如集合通訊的all-reduce,或者 map-reduce類系統的reduce步驟。但是定期同步有兩個問題:

  • 同步時期只能做同步,不能訓練。
  • straggler問題:由於一些軟硬體的原因,節點的計算能力往往不盡相同。對於迭代問題來說,每一輪結束時算得快的節點都需等待算得慢的節點算完,再進行下一輪迭代。這種等待在節點數增多時將變得尤為明顯,從而拖慢整體的效能。

因此,當async sgd出現之後,就有人提出了引數伺服器。

引數伺服器的概念最早來自於Alex Smola於2010年提出的並行LDA的框架。它通過採用一個分散式的Memcached作為存放共享引數的儲存,這樣就提供了有效的機制用於分散式系統中不同的Worker之間同步模型引數,而每個Worker只需要儲存他計算時所以來的一小部分引數即可,也避免了所有程式在一個時間點上都停下來同步。但是獨立的kv對帶來了很大的通訊開銷,而且服務端端難以程式設計。

第二代由Google的Jeff Dean進一步提出了第一代Google大腦的解決方案:DistBelief。DistBelief將巨大的深度學習模型分佈儲存在全域性的引數伺服器中,計算節點通過引數伺服器進行資訊傳遞,很好地解決了SGD和L-BFGS演算法的分散式訓練問題。

再後來就是李沐所在的DMLC組所設計的引數伺服器。根據論文中所寫,該parameter server屬於第三代引數伺服器,就是提供了更加通用的設計。架構上包括一個Server Group和若干個Worker Group。

1.3 論文架構

我們首先用沐神論文中的圖來看看系統架構。

解釋一下圖中整體架構中每個模組:

  • resource manager:資源分配及管理器。引數伺服器使用業界現有的資源管理系統,比如yarn,k8s。
  • training data:幾十上百億的訓練資料一般儲存在分散式檔案系統上(比如HDFS),resource manager會均勻的分配到每個worker上。
  • 引數伺服器的節點被劃分到一個 server group 和多個 worker group。
  • server group:一次訓練任務中申請的servers,用於模型引數的更新和pull應答。
    • server group 中的每個 server 只負責自己分到的部分全域性共享引數(server 共同維持一個全域性共享引數),一般優化器在此實現。
    • server 之間相互通訊以便進行引數的備份/遷移。
    • server group 有一個 server manager node,負責維護 server 後設資料的一致性,例如節點狀態,引數的分配情況。一般不會有什麼邏輯,只有當有server node加入或退出的時候,為了維持一致性雜湊而做一些調整。
  • worker group:一次訓練任務中申請的workers,用於前向過程和梯度計算。
    • 每個 worker group 執行一個計算任務,worker group 中的 每個worker 使用部分資料進行訓練。
    • 分成多個group,這樣就可以支援多工的平行計算。
    • 每個 worker group 有一個 task scheduler,負責向 worker 分配任務,並監控他們的執行情況,當有 worker 進入或者退出時,task scheduler 重新分配未完成的任務。
    • worker 之間沒有通訊,只和對應的 server 通訊進行引數更新。

在分散式計算梯度時,系統的資料流如下:

圖中每個步驟的作用為:

  1. worker 節點 基於該 batch 內的樣本計算模型權重的梯度;
  2. worker將梯度以key-value的形式推送給server;
  3. server按指定的優化器對模型權重進行梯度更新;
  4. worker從server中拉取最新的模型權重;

上面兩個圖的依據是其原始程式碼。ps-lite 是後來的精簡版程式碼,所以有些功能在 ps-lite 之中沒有提供。

1.4 ps-lite發展歷程

從網上找到了一些 ps-lite發展歷程,可以看到其演進的思路。

第一代是parameter,針對特定演算法(如邏輯迴歸和LDA)進行了設計和優化,以滿足規模龐大的工業機器學習任務(數百億個示例和10-100TB資料大小的功能)。

後來嘗試為機器學習演算法構建一個開源通用框架。 該專案位於dmlc / parameter_server。

鑑於其他專案的需求不斷增長,建立了ps-lite,它提供了一個乾淨的資料通訊API和一個輕量級的實現。 該實現基於dmlc / parameter_server,但為不同的專案重構了作業啟動器,檔案IO和機器學習演算法程式碼,如dmlc-core和wormhole

根據在開發dmlc / mxnet期間學到的經驗,從v1進一步重構了API和實現。 主要變化包括:

  • 庫依賴性較少;
  • 更靈活的使用者定義回撥,便於其他語言繫結;
  • 讓使用者(如mxnet的依賴引擎)管理資料一致性;

1.5 ps-lite 系統總體

ps-lite 其實是Paramter Server的實現的一個框架,其中引數處理具體相關策略需使用者自己實現

Parameter Server包含三種角色:Worker,Server,Scheduler。具體關係如下圖:

具體角色功能為:

  • worker(工作節點):若干個,執行data pipeline、前向和梯度計算,以key-value的形式將模型權重梯度push到server節點以及從server節點拉取模型最新權重;
  • server(服務節點):若干個,負責對worker的push和pull請求做response,儲存,維護和更新模型權重以供各個worker使用(每個server僅維護模型的一部分);
  • scheduler(控制節點):系統內只有一個。負責所有節點的心跳監測、節點id分配和worker&server間的通訊建立,它還可用於將控制訊號傳送到其他節點並收集其進度。

其中引入scheduler的好處如下:

  • 引入一個 scheduler 模組,則會形成一個比較經典的三角色分散式系統架構;worker 和 server 的角色和職責不變,而 scheduler 模組則有比較多的選擇:
    • 只承擔和下層資源排程系統般若(類似 yarn、mesos)的互動;
    • 額外增加對 worker、server 心跳監控、流程控制的功能;
  • 引入 scheduler 模組的另一個好處是給實現模型並行留出了空間;
  • scheduler 模組不僅有利於實現模型並行訓練正規化,還有其他好處:比如通過針對特定模型引數相關性的理解,對引數訓練過程進行細粒度的排程,可以進一步加快模型收斂速度,甚至有機會提升模型指標。

熟悉分散式系統的同學可能會擔心 scheduler 模組的單點問題,這個通過 raft、zab 等 paxos 協議可以得到比較好的解決。

1.6 基礎模組

ps-lite系統中的一些基礎模組如下:

  • Environment:一個單例模式的環境變數類,它通過一個 std::unordered_map<std::string, std::string> kvs 維護了一組 kvs 藉以儲存所有環境變數名以及值;

  • PostOffice:一個單例模式的全域性管理類,一個 node 在生命期內具有一個PostOffice,依賴它的類成員對Node進行管理;

  • Van:通訊模組,負責與其他節點的網路通訊和Message的實際收發工作。PostOffice持有一個Van成員;

  • SimpleApp:KVServer和KVWorker的父類,它提供了簡單的Request, Wait, Response,Process功能;KVServer和KVWorker分別根據自己的使命重寫了這些功能;

  • Customer:每個SimpleApp物件持有一個Customer類的成員,且Customer需要在PostOffice進行註冊,該類主要負責:

    • 跟蹤由SimpleApp傳送出去的訊息的回覆情況;
    • 維護一個Node的訊息佇列,為Node接收訊息;
  • Node :資訊類,儲存了本節點的對應資訊,每個 Node 可以使用 hostname + port 來唯一標識。

0x02 系統啟動

2.1 如何啟動

從原始碼中的例子可以看出,使用ps-lite 提供的指令碼 local.sh 可以啟動整個系統,這裡 test_connection 為編譯好的可執行程式。

./local.sh 2 3 ./test_connection

2.2 啟動指令碼

具體 local.sh 程式碼如下。注意,在shell指令碼中,有三個shift,這就讓指令碼中始終使用$1。

針對我們的例子,指令碼引數對應了就是

  • DMLC_NUM_SERVER 為 2;
  • DMLC_NUM_WORKER 為 3;
  • bin 是 ./test_connection;

可以從指令碼中看到,本指令碼做了兩件事:

  • 每次執行應用程式之前,都會依據本次執行的角色來對環境變數進行各種設定,除了DMLC_ROLE設定得不同外,其他變數在每個節點上都相同。
  • 在本地執行多個不同角色。這樣 ps-lite 就用多個不同的程式(程式)共同合作完成工作。
    • 首先啟動Scheduler節點。這是要固定好Server和Worker數量,Scheduler節點管理所有節點的地址。
    • 啟動Worker或Server節點。每個節點要知道Scheduler節點的IP、port。啟動時連線Scheduler節點,繫結本地埠,並向Scheduler節點註冊自己資訊(報告自己的IP,port)。
    • Scheduler等待所有Worker節點都註冊後,給其分配id,並把節點資訊傳送出去(例如Worker節點要知道Server節點IP和埠,Server節點要知道Worker節點的IP和埠)。此時Scheduler節點已經準備好。
    • Worker或Server接收到Scheduler傳送的資訊後,建立對應節點的連線。此時Worker或Server已經準備好,會正式啟動。

具體如下:

#!/bin/bash
# set -x
if [ $# -lt 3 ]; then
    echo "usage: $0 num_servers num_workers bin [args..]"
    exit -1;
fi

# 對環境變數進行各種配置,此後不同節點都會從這些環境變數中獲取資訊
export DMLC_NUM_SERVER=$1
shift
export DMLC_NUM_WORKER=$1
shift
bin=$1
shift
arg="$@"

# start the scheduler
export DMLC_PS_ROOT_URI='127.0.0.1'
export DMLC_PS_ROOT_PORT=8000
export DMLC_ROLE='scheduler'
${bin} ${arg} &


# start servers
export DMLC_ROLE='server'
for ((i=0; i<${DMLC_NUM_SERVER}; ++i)); do
    export HEAPPROFILE=./S${i}
    ${bin} ${arg} &
done

# start workers
export DMLC_ROLE='worker'
for ((i=0; i<${DMLC_NUM_WORKER}; ++i)); do
    export HEAPPROFILE=./W${i}
    ${bin} ${arg} &
done

wait

2.3 示例程式

我們依然使用官方例子看看。

ps-lite 使用的是 C++語言,其中 worker, server, scheduler 都使用同一套程式碼。這會讓習慣於Java,python的同學非常不適應,大家需要適應一個階段。

針對這個示例程式,起初會讓人疑惑,為什麼每次程式執行,程式碼中都會啟動 scheduler,worker,server?其實,從下面註釋就能看出來,具體執行是依據環境變數來決定。如果環境變數設定了本次角色是 server,則不會啟動 scheduler 和 worker。

#include <cmath>
#include "ps/ps.h"

using namespace ps;

void StartServer() {
  if (!IsServer()) {
    return;
  }
  auto server = new KVServer<float>(0);
  server->set_request_handle(KVServerDefaultHandle<float>()); //註冊functor
  RegisterExitCallback([server](){ delete server; });
}

void RunWorker() {
  if (!IsWorker()) return;
  KVWorker<float> kv(0, 0);

  // init
  int num = 10000;
  std::vector<Key> keys(num);
  std::vector<float> vals(num);

  int rank = MyRank();
  srand(rank + 7);
  for (int i = 0; i < num; ++i) {
    keys[i] = kMaxKey / num * i + rank;
    vals[i] = (rand() % 1000);
  }

  // push
  int repeat = 50;
  std::vector<int> ts;
  for (int i = 0; i < repeat; ++i) {
    ts.push_back(kv.Push(keys, vals)); //kv.Push()返回的是該請求的timestamp

    // to avoid too frequency push, which leads huge memory usage
    if (i > 10) kv.Wait(ts[ts.size()-10]);
  }
  for (int t : ts) kv.Wait(t);

  // pull
  std::vector<float> rets;
  kv.Wait(kv.Pull(keys, &rets));

  // pushpull
  std::vector<float> outs;
  for (int i = 0; i < repeat; ++i) {
    // PushPull on the same keys should be called serially
    kv.Wait(kv.PushPull(keys, vals, &outs));
  }

  float res = 0;
  float res2 = 0;
  for (int i = 0; i < num; ++i) {
    res += std::fabs(rets[i] - vals[i] * repeat);
    res2 += std::fabs(outs[i] - vals[i] * 2 * repeat);
  }
  CHECK_LT(res / repeat, 1e-5);
  CHECK_LT(res2 / (2 * repeat), 1e-5);
  LL << "error: " << res / repeat << ", " << res2 / (2 * repeat);
}

int main(int argc, char *argv[]) {
  // start system
  Start(0); // Postoffice::start(),每個node都會呼叫到這裡,但是在 Start 函式之中,會依據本次設定的角色來不同處理,只有角色為 scheduler 才會啟動 Scheduler。
  // setup server nodes
  StartServer(); // Server會在其中做有效執行,其他節點不會有效執行。
  // run worker nodes
  RunWorker(); // Worker 會在其中做有效執行,其他節點不會有效執行。
  // stop system
  Finalize(0, true); //結束。每個節點都需要執行這個函式。
  return 0;
}

其中KVServerDefaultHandle是functor,用與處理server收到的來自worker的請求,具體如下:

/**
 * \brief an example handle adding pushed kv into store
 */
template <typename Val>
struct KVServerDefaultHandle { //functor,用與處理server收到的來自worker的請求
    // req_meta 是儲存該請求的一些元資訊,比如請求來自於哪個節點,傳送給哪個節點等等
    // req_data 是傳送過來的資料
    // server 是指向當前server物件的指標  
  void operator()(
      const KVMeta& req_meta, const KVPairs<Val>& req_data, KVServer<Val>* server) {
    size_t n = req_data.keys.size();
    KVPairs<Val> res;
    if (!req_meta.pull) { //收到的是pull請求
      CHECK_EQ(n, req_data.vals.size());
    } else { //收到的是push請求
      res.keys = req_data.keys; res.vals.resize(n);
    }
    for (size_t i = 0; i < n; ++i) {
      Key key = req_data.keys[i];
      if (req_meta.push) { //push請求
        store[key] += req_data.vals[i]; //此處的操作是將相同key的value相加
      }
      if (req_meta.pull) {  //pull請求
        res.vals[i] = store[key];
      }
    }
    server->Response(req_meta, res);
  }
  std::unordered_map<Key, Val> store;
};

0x03 Postoffice

Postoffice 是一個單例模式的全域性管理類,其維護了系統的一個全域性資訊,具有如下特點:

  • 三種Node角色都依賴 Postoffice 進行管理,每一個 node 在生命期內具有一個單例 PostOffice。
  • 如我們之前所說,ps-lite的特點是 worker, server, scheduler 都使用同一套程式碼,Postoffice也是如此,所以我們最好分開描述。
  • 在 Scheduler側,顧名思義,Postoffice 是郵局,可以認為是一個地址簿,一個調控中心,其記錄了系統(由scheduler,server, worker 集體構成的這個系統)中所有節點的資訊。具體功能如下:
    • 維護了一個Van物件,負責整個網路的拉起、通訊、命令管理如增加節點、移除節點、恢復節點等等;
    • 負責整個叢集基本資訊的管理,比如worker、server數的獲取,管理所有節點的地址,server 端 feature分佈的獲取,worker/server Rank與node id的互轉,節點角色身份等等;
    • 負責 Barrier 功能;
  • 在 Server / Worker 端,負責:
    • 配置當前node的一些資訊,例如當前node是哪種型別(server,worker),nodeid是啥,以及worker/server 的rank 到 node id的轉換。
    • 路由功能:負責 key 與 server 的對應關係。
    • Barrier 功能;

請注意:這些程式碼都是在 Postoffice 類內,沒有按照角色分開成多個模組。

3.1 定義

類 UML 圖如下:

下面我們只給出關鍵變數和成員函式說明,因為每個節點都包含一個 PostOffice,所以 PostOffice 的資料結構中包括了各種節點所需要的變數,會顯得比較繁雜。

主要變數作用如下:

  • van_ :底層通訊物件;
  • customers_ :本節點目前有哪些 customer;
  • node_ids_ :node id 對映表;
  • server_key_ranges_ :Server key 區間範圍物件
  • is_worker_, is_server_, is_scheduler_ :標註了本節點型別;
  • heartbeats_ :節點心跳物件;
  • barrier_done_ : Barrier 同步變數;

主要函式作用如下:

  • InitEnvironment :初始化環境變數,建立 van 物件;
  • Start :建立通訊初始化;
  • Finalize :節點阻塞退出;
  • Manage :退出 barrier 阻塞狀態;
  • Barrier :進入 barrier 阻塞狀態;
  • UpdateHeartbeat :
  • GetDeadNodes :根據 heartbeats_ 獲取已經 dead 的節點;

具體如下:

class Postoffice {
  /**
   * \brief start the system
   *
   * This function will block until every nodes are started.
   * \param argv0 the program name, used for logging.
   * \param do_barrier whether to block until every nodes are started.
   */
  void Start(int customer_id, const char* argv0, const bool do_barrier);
  /**
   * \brief terminate the system
   *
   * All nodes should call this function before existing.
   * \param do_barrier whether to do block until every node is finalized, default true.
   */
  void Finalize(const int customer_id, const bool do_barrier = true);
  /**
   * \brief barrier
   * \param node_id the barrier group id
   */
  void Barrier(int customer_id, int node_group);
  /**
   * \brief process a control message, called by van
   * \param the received message
   */
  void Manage(const Message& recv);
  /**
   * \brief update the heartbeat record map
   * \param node_id the \ref Node id
   * \param t the last received heartbeat time
   */
  void UpdateHeartbeat(int node_id, time_t t) {
    std::lock_guard<std::mutex> lk(heartbeat_mu_);
    heartbeats_[node_id] = t;
  }
  /**
   * \brief get node ids that haven't reported heartbeats for over t seconds
   * \param t timeout in sec
   */
  std::vector<int> GetDeadNodes(int t = 60);  
 private:  
 void InitEnvironment();  
  Van* van_;
  mutable std::mutex mu_;
  // app_id -> (customer_id -> customer pointer)
  std::unordered_map<int, std::unordered_map<int, Customer*>> customers_;
  std::unordered_map<int, std::vector<int>> node_ids_;
  std::mutex server_key_ranges_mu_;
  std::vector<Range> server_key_ranges_;
  bool is_worker_, is_server_, is_scheduler_;
  int num_servers_, num_workers_;
  std::unordered_map<int, std::unordered_map<int, bool> > barrier_done_;
  int verbose_;
  std::mutex barrier_mu_;
  std::condition_variable barrier_cond_;
  std::mutex heartbeat_mu_;
  std::mutex start_mu_;
  int init_stage_ = 0;
  std::unordered_map<int, time_t> heartbeats_;
  Callback exit_callback_;
  /** \brief Holding a shared_ptr to prevent it from being destructed too early */
  std::shared_ptr<Environment> env_ref_;
  time_t start_time_;
  DISALLOW_COPY_AND_ASSIGN(Postoffice);
}; 

3.2 ID 對映功能

首先我們介紹下 node id 對映功能,就是如何在邏輯節點和物理節點之間做對映,如何把物理節點劃分成各個邏輯組,如何用簡便的方法做到給組內物理節點統一發訊息

  • 1,2,4分別標識Scheduler, ServerGroup, WorkerGroup。
  • SingleWorker:rank * 2 + 9;SingleServer:rank * 2 + 8。
  • 任意一組節點都可以用單個id標識,等於所有id之和。

3.2.1 概念

  • Rank 是一個邏輯概念,是每一個節點(scheduler,work,server)內部的唯一邏輯標示
  • Node id 是物理節點的唯一標識,可以和一個 host + port 的二元組唯一對應
  • Node Group 是一個邏輯概念,每一個 group 可以包含多個 node id。ps-lite 一共有三組 group : scheduler 組,server 組,worker 組。
  • Node group id 是 是節點組的唯一標示。
    • ps-lite 使用 1,2,4 這三個數字分別標識 Scheduler,ServerGroup,WorkerGroup。每一個數字都代表著一組節點,等於所有該型別節點 id 之和。比如 2 就代表server 組,就是所有 server node 的組合。
    • 為什麼選擇這三個數字?因為在二進位制下這三個數值分別是 "001, 010, 100",這樣如果想給多個 group 發訊息,直接把 幾個 node group id 做 或操作 就行。
    • 即 1-7 內任意一個數字都代表的是Scheduler / ServerGroup / WorkerGroup的某一種組合。
      • 如果想把某一個請求傳送給所有的 worker node,把請求目標節點 id 設定為 4 即可。
      • 假設某一個 worker 希望向所有的 server 節點 和 scheduler 節點同時傳送請求,則只要把請求目標節點的 id 設定為 3 即可,因為 3 = 2 + 1 = kServerGroup + kScheduler。
      • 如果想給所有節點傳送訊息,則設定為 7 即可。

3.2.2 邏輯組的實現

三個邏輯組的定義如下:

/** \brief node ID for the scheduler */
static const int kScheduler = 1;
/**
 * \brief the server node group ID
 *
 * group id can be combined:
 * - kServerGroup + kScheduler means all server nodes and the scheuduler
 * - kServerGroup + kWorkerGroup means all server and worker nodes
 */
static const int kServerGroup = 2;
/** \brief the worker node group ID */
static const int kWorkerGroup = 4;

3.2.3 Rank vs node id

node id 是物理節點的唯一標示,rank 是每一個邏輯概念(scheduler,work,server)內部的唯一標示。這兩個標示由一個演算法來確定。

如下面程式碼所示,如果配置了 3 個worker,則 worker 的 rank 從 0 ~ 2,那麼這幾個 worker 實際對應的 物理 node ID 就會使用 WorkerRankToID 來計算出來。

    for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }

具體計算規則如下:

  /**
   * \brief convert from a worker rank into a node id
   * \param rank the worker rank
   */
  static inline int WorkerRankToID(int rank) {
    return rank * 2 + 9;
  }
  /**
   * \brief convert from a server rank into a node id
   * \param rank the server rank
   */
  static inline int ServerRankToID(int rank) {
    return rank * 2 + 8;
  }
  /**
   * \brief convert from a node id into a server or worker rank
   * \param id the node id
   */
  static inline int IDtoRank(int id) {
#ifdef _MSC_VER
#undef max
#endif
    return std::max((id - 8) / 2, 0);
  }

這樣我們可以知道,1-7 的id表示的是node group,單個節點的id 就從 8 開始。

而且這個演算法保證server id為偶數,node id為奇數。

  • SingleWorker:rank * 2 + 9;
  • SingleServer:rank * 2 + 8;

3.2.4 Group vs node

因為有時請求要傳送給多個節點,所以ps-lite用了一個 map 來儲存每個 node group / single node 對應的實際的node節點集合,即 確定每個id值對應的節點id集。

std::unordered_map<int, std::vector<int>> node_ids_ 

如何使用這個node_ids_?我們還是需要看之前的程式碼:

    for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }

我們回憶一下之前的節點資訊:

  • 1 ~ 7 的 id 表示的是 node group;
  • 後續的 id(8,9,10,11 ...)表示單個的 node。其中雙數 8,10,12... 表示 worker 0, worker 1, worker 2,... 即(2n + 8),9,11,13,...,表示 server 0, server 1,server 2,...,即(2n + 9);

所以,為了實現 “設定 1-7 內任意一個數字 可以傳送給其對應的 所有node” 這個功能,對於每一個新節點,需要將其對應多個id(node,node group)上,這些id組就是本節點可以與之通訊的節點。例如對於 worker 2 來說,其 node id 是 2 * 2 + 8 = 12,所以需要將它與

  • 12(本身)
  • 4(kWorkerGroup)li
  • 4+1(kWorkerGroup + kScheduler)
  • 4+2(kWorkerGroup + kServerGroup)
  • 4+1+2,(kWorkerGroup + kServerGroup + kScheduler )

這 5 個id 相對應,即需要在 node_ids_ 這個對映表中對應的 4, 4 + 1, 4 + 2, 4 +1 + 2, 12 這五個 item 之中新增。就是上面程式碼中的內部 for 迴圈條件。即,node_ids_ [4], node_ids_ [5],node_ids_ [6],node_ids_ [7] ,node_ids_ [12] 之中,都需要把 12 新增到 vector 最後。

3.3 參數列示

workers 跟 servers 之間通過 pushpull 來通訊。worker 通過 push 將計算好的梯度傳送到server,然後通過 pull 從server更新引數。

3.3.1 KV格式

parameter server 中,引數都是可以被表示成(key, value)的集合,比如一個最小化損失函式的問題,key就是feature ID,而value就是它的權值。對於稀疏引數來說,value不存在的key,就可以認為value是0。

把參數列示成 k-v, 形式更自然,易於理解和程式設計實現。

3.3.2 key-values

分散式演算法有兩個額外成本:資料通訊成本,負載均衡不理想和機器效能差異導致的同步成本。

對於高維機器學習訓練來說,因為高頻特徵更新極為頻繁,所會導致網路壓力極大。如果每一個引數都設一個key並且按key更新,那麼會使得通訊變得更加頻繁低效,為了抹平這個問題,就需要有折衷和平衡,即,
利用機器學習演算法的特性,給每個key對應的value賦予一個向量或者矩陣,這樣就可以一次性傳遞多個引數,權衡了融合與同步的成本。

做這樣的操作的前提是假設引數是有順序的。缺點是在對於稀疏模型來說,總會在向量或者矩陣裡會有引數為0,這在單個引數狀態下是不用存的,所以,造成了資料的冗餘。

但這樣做有兩點好處:

  • 降低網路通訊
  • 使得向量層面的操作變得可行,從而很多線性庫的優化特性可以利用的上,比如BLAS、LAPACK、ATLAS等。

3.3.3 Range 操作

為了提高計算效能和頻寬效率,引數伺服器也會採用批次更新的辦法,來減輕高頻 key 的壓力。比如把minibatch之中高頻key合併成一個minibatch進行更新。

ps-lite 允許使用者使用 Range PushRange Pull 操作。

3.4 路由功能(keyslice)

路由功能指的就是:Worker 在做 Push/Pull 時候,如何知道把訊息傳送給哪些 Servers

我們知道,ps-lite 是多 Server 架構,一個很重要的問題是如何分佈多個引數。比如給定一個引數的鍵,如何確定其儲存在哪一臺 Server 上。所以必然有一個路由邏輯用來確立 key與server的對應關係。

PS Lite 將路由邏輯放置在 Worker 端,採用範圍劃分的策略,即每一個 Server 有自己固定負責的鍵的範圍。這個範圍是在 Worker 啟動的時候確定的。細節如下:

  • 根據編譯 PS Lite 時是否設定的巨集 USE_KEY32 來決定引數的鍵的資料型別,要麼是 32 位無符號整數,要麼是 64 位的。
  • 根據鍵的資料型別,確定其值域的上界。例如 uint32_t 的上界是 4294967295。
  • 根據鍵域的上界和啟動時獲取的 Server 數量(即環境變數 DMLC_NUM_SERVER 的值)來劃分範圍。
  • 每個server維護的key範圍按 uint32_t / uint64_t 從小到大等距分割槽間。給定上界 MAX 和 Server 數量 N,第 i 個 Server 負責的範圍是 [MAX/N*i, MAX/N*(i+1))
  • 對key的hash值構造有一定的要求以避免server間的key傾斜(如32位、16位、8位、4位、2位高低位對調)。
  • Worker push和pull的key按升序排列進行slice以實現zero copy。

需要注意的是,在不能剛好整除的情況下,鍵域上界的一小段被丟棄了。

具體實現如下:

首先,ps-lite的key只支援int型別。

#if USE_KEY32
/*! \brief Use unsigned 32-bit int as the key type */
using Key = uint32_t;
#else
/*! \brief Use unsigned 64-bit int as the key type */
using Key = uint64_t;
#endif
/*! \brief The maximal allowed key value */
static const Key kMaxKey = std::numeric_limits<Key>::max();

其次,將int範圍均分即可

const std::vector<Range>& Postoffice::GetServerKeyRanges() {
  if (server_key_ranges_.empty()) {
    for (int i = 0; i < num_servers_; ++i) {
      server_key_ranges_.push_back(Range(
          kMaxKey / num_servers_ * i,
          kMaxKey / num_servers_ * (i+1)));
    }
  }
  return server_key_ranges_;
}

3.5 初始化環境

從之前分析中我們可以知道,ps-lite 是通過環境變數來控制具體節點。

具體某個節點屬於哪一種取決於啟動節點之前設定了哪些環境變數以及其數值。

環境變數包括:節點角色,worker&server個數、ip、port等。

InitEnvironment 函式就是建立了 Van,得到了 worker 和 server 的數量,得到了本節點的型別。

void Postoffice::InitEnvironment() {
  const char* val = NULL;
  std::string van_type = GetEnv("DMLC_PS_VAN_TYPE", "zmq");
  van_ = Van::Create(van_type);
  val = CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_WORKER"));
  num_workers_ = atoi(val);
  val =  CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_SERVER"));
  num_servers_ = atoi(val);
  val = CHECK_NOTNULL(Environment::Get()->find("DMLC_ROLE"));
  std::string role(val);
  is_worker_ = role == "worker";
  is_server_ = role == "server";
  is_scheduler_ = role == "scheduler";
  verbose_ = GetEnv("PS_VERBOSE", 0);
}

3.6 啟動

主要就是:

  • 呼叫 InitEnvironment() 來初始化環境,建立 VAN 物件;
  • node_ids_初始化。根據worker和server節點個數,確定每個id值對應的節點id集。具體邏輯我們前面有分析。
  • 啟動 van,這裡會進行各種互動(有一個 ADD_NODE 同步等待,與後面的 barrier 等待不同);
  • 如果是第一次呼叫PostOffice::Start,初始化start_time_成員;
  • 如果設定了需要 barrier,則呼叫 Barrier 來進行 等待/處理 最終系統統一啟動。即 所有Node準備並向Scheduler傳送要求同步的Message,進行第一次同步;

具體程式碼如下:

void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) {
  start_mu_.lock();
  if (init_stage_ == 0) {
    InitEnvironment();

    // init node info.
    // 對於所有的worker,進行node設定
    for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }
		// 對於所有的server,進行node設定
    for (int i = 0; i < num_servers_; ++i) {
      int id = ServerRankToID(i);
      for (int g : {id, kServerGroup, kWorkerGroup + kServerGroup,
                    kServerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }
		// 設定scheduler的node
    for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
                  kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
      node_ids_[g].push_back(kScheduler);
    }
    init_stage_++;
  }
  start_mu_.unlock();

  // start van
  van_->Start(customer_id);

  start_mu_.lock();
  if (init_stage_ == 1) {
    // record start time
    start_time_ = time(NULL);
    init_stage_++;
  }
  start_mu_.unlock();
  // do a barrier here
  if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
}

3.7 Barrier

3.7.1 同步

總的來講,schedular節點通過計數的方式實現各個節點的同步。具體來說就是:

  • 每個節點在自己指定的命令執行完後會向schedular節點傳送一個Control::BARRIER命令的請求並自己阻塞直到收到schedular對應的返回後才解除阻塞;
  • schedular節點收到請求後則會在本地計數,看收到的請求數是否和barrier_group的數量是否相等,相等則表示每個機器都執行完指定的命令了,此時schedular節點會向barrier_group的每個機器傳送一個返回的資訊,並解除其阻塞。

3.7.2 初始化

ps-lite 使用 Barrier 來控制系統的初始化,就是大家都準備好了再一起前進。這是一個可選項。具體如下:

  • Scheduler等待所有的worker和server傳送BARRIER資訊;
  • 在完成ADD_NODE後,各個節點會進入指定 group 的Barrier阻塞同步機制(傳送 BARRIER 給 Scheduler),以保證上述過程每個節點都已經完成;
  • 所有節點(worker和server,包括scheduler) 等待scheduler收到所有節點 BARRIER 資訊後的應答;
  • 最終所有節點收到scheduler 應答的Barrier message後退出阻塞狀態;
3.7.2.1 等待 BARRIER 訊息

Node會呼叫 Barrier 函式 告知Scheduler,隨即自己進入等待狀態。

注意,呼叫時候是

if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);  

這就是說,等待所有的 group,即 scheduler 節點也要給自己傳送訊息。

void Postoffice::Barrier(int customer_id, int node_group) {
  if (GetNodeIDs(node_group).size() <= 1) return;
  auto role = van_->my_node().role;
  if (role == Node::SCHEDULER) {
    CHECK(node_group & kScheduler);
  } else if (role == Node::WORKER) {
    CHECK(node_group & kWorkerGroup);
  } else if (role == Node::SERVER) {
    CHECK(node_group & kServerGroup);
  }

  std::unique_lock<std::mutex> ulk(barrier_mu_);
  barrier_done_[0][customer_id] = false;
  Message req;
  req.meta.recver = kScheduler;
  req.meta.request = true;
  req.meta.control.cmd = Control::BARRIER;
  req.meta.app_id = 0;
  req.meta.customer_id = customer_id;
  req.meta.control.barrier_group = node_group; // 記錄了等待哪些
  req.meta.timestamp = van_->GetTimestamp();
  van_->Send(req); // 給 scheduler 發給 BARRIER
  barrier_cond_.wait(ulk, [this, customer_id] { // 然後等待
      return barrier_done_[0][customer_id];
    });
}

3.7.2.2 處理 BARRIER 訊息

處理等待的動作在 Van 類之中,我們提前放出來。

具體ProcessBarrierCommand邏輯如下:

  • 如果 msg->meta.request 為true,說明是 scheduler 收到訊息進行處理。
    • Scheduler會對Barrier請求進行增加計數。
    • 當 Scheduler 收到最後一個請求時(計數等於此group節點總數),則將計數清零,傳送結束Barrier的命令。這時候 meta.request 設定為 false;
    • 向此group所有節點傳送request==falseBARRIER訊息。
  • 如果 msg->meta.request 為 false,說明是收到訊息這個 respones,可以解除barrier了,於是進行處理,呼叫 Manage 函式 。
    • Manage 函式 將app_id對應的所有costomer的barrier_done_置為true,然後通知所有等待條件變數barrier_cond_.notify_all()
void Van::ProcessBarrierCommand(Message* msg) {
  auto& ctrl = msg->meta.control;
  if (msg->meta.request) {  // scheduler收到了訊息,因為 Postoffice::Barrier函式 會在傳送時候做設定為true。
    if (barrier_count_.empty()) {
      barrier_count_.resize(8, 0);
    }
    int group = ctrl.barrier_group;
    ++barrier_count_[group]; // Scheduler會對Barrier請求進行計數
    if (barrier_count_[group] ==
        static_cast<int>(Postoffice::Get()->GetNodeIDs(group).size())) { // 如果相等,說明已經收到了最後一個請求,所以傳送解除 barrier 訊息。
      barrier_count_[group] = 0;
      Message res;
      res.meta.request = false; // 回覆時候,這裡就是false
      res.meta.app_id = msg->meta.app_id;
      res.meta.customer_id = msg->meta.customer_id;
      res.meta.control.cmd = Control::BARRIER;
      for (int r : Postoffice::Get()->GetNodeIDs(group)) {
        int recver_id = r;
        if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {
          res.meta.recver = recver_id;
          res.meta.timestamp = timestamp_++;
          Send(res);
        }
      }
    }
  } else { // 說明這裡收到了 barrier respones,可以解除 barrier了。具體見上面的設定為false處。
    Postoffice::Get()->Manage(*msg);
  }
}


Manage 函式就是解除了 barrier。

void Postoffice::Manage(const Message& recv) {
  CHECK(!recv.meta.control.empty());
  const auto& ctrl = recv.meta.control;
  if (ctrl.cmd == Control::BARRIER && !recv.meta.request) {
    barrier_mu_.lock();
    auto size = barrier_done_[recv.meta.app_id].size();
    for (size_t customer_id = 0; customer_id < size; customer_id++) {
      barrier_done_[recv.meta.app_id][customer_id] = true;
    }
    barrier_mu_.unlock();
    barrier_cond_.notify_all(); // 這裡解除了barrier
  }
}

具體示意如下:

                                                    +
    Scheduler                                       |                  Worker
        +                                           |                     +
        |                                           |                     |
        |                                           |                     |
        +--------------------------------+          |                     +-----------------+
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        |                                v          |                     |                 v
        |                         receiver_thread_  |                     |           receiver_thread_
        |                                +          |                     |                 |
        |                                |          |                     |                 |
        v              BARRIER           |          |   BARRIER           v                 |
Postoffice::Barrier +----------------->  | <---------------------+ Postoffice::Barrier      |
        +                                |          |                     +                 |
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        |                                v          |                     |                 |
        v                                           |                     v                 |
 barrier_cond_.wait          ProcessBarrierCommand  |               barrier_cond_.wait      |
        |                                +          |                     |                 |
        |                                |          |                     |                 |
        |                  All Nodes OK  |          |                     |                 |
        |                                |          |                     |                 |
        |                 +--------------+          |   BARRIER           |                 |
        |                 |              +---------------------------------------------->   |
        |                 |  BARRIER     |          |                     |                 |
        |                 +------------> |          |                     |                 |
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        +<-------------------------------<          |                     | <---------------+
        |          barrier_cond_.notify_all         |                     |    barrier_cond_.notify_all
        v                                           |                     v
                                                    +


手機如下:

至此,Postoffice的分析我們初步完成,其餘功能我們將會結合 Van 和 Customer 在後續文章中分析。

0xEE 個人資訊

★★★★★★關於生活和技術的思考★★★★★★

微信公眾賬號:羅西的思考

如果您想及時得到個人撰寫文章的訊息推送,或者想看看個人推薦的技術資料,敬請關注。

在這裡插入圖片描述

0xFF 參考

MXNet設計和實現簡介

史上最全面的ps-lite理解

ps-lite 深度原始碼解讀

ps-lite原始碼剖析

基於Parameter Server的可擴充套件分散式機器學習架構

ps-lite程式碼解析

ps-lite程式碼筆記

分散式TensorFlow入門教程

分散式機器學習(上)-平行計算與機器學習

分散式機器學習(中)-平行計算與機器學習

分散式機器學習(下)-聯邦學習

ps-lite 原始碼分析

Talk - Scaling Distributed Machine Learning with System and Algorithm Co-design 筆記

相關文章