[原始碼解析] PyTorch分散式(6) -------- DistributedDataParallel -- 初始化&store

羅西的思考發表於2021-11-18

[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store

0x00 摘要

本文是 PyTorch 分散式系列的第六篇, 介紹 DistributedDataParallel 所依賴的初始化方法和Store這兩個概念。

本系列其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀

[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎

[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

[原始碼解析] PyTorch 分散式(1)------歷史和概述

[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)

[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)

[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念

[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用

0x01 回顧

1.1 基本概念

關於分散式通訊,PyTorch 提供的幾個概念是:程式組,後端,初始化,Store。

  • 程式組 :DDP是真正的分散式訓練,可以使用多臺機器來組成一次並行運算的任務。為了能夠讓 DDP 的各個worker之間通訊,PyTorch 設定了程式組這個概念。
  • 後端 :後端這個概念是一個邏輯上的概念。本質上後端是一種IPC通訊機制。
  • 初始化 : 雖然有了後端和程式組的概念,但是如何讓 worker 在建立程式組之前發現彼此? 這就需要一種初始化方法來告訴大家傳遞一個資訊:如何聯絡到其它機器上的程式。
  • Store : 可以認為是分散式鍵值儲存,利用這個儲存就可以在組中的程式之間共享資訊以及初始化分散式包 (通過顯式建立儲存來作為init_method的替代)。

1.2 初始化程式組

在呼叫任何 DDP 其他方法之前,需要使用torch.distributed.init_process_group()進行初始化。該方法會初始化預設分散式程式組和分散式包。此方法會阻塞,直到所有程式都加入,函式定義如下:

init_process_group ( backend , 
                       init_method = None , 
                       timeout = default_pg_timeout , 
                       world_size =- 1 , 
                       rank =- 1 , 
                       store = None , 
                       group_name = '' , 
                       pg_options = None )

初始化程式組有兩種主要方法:

  1. 明確指定 store,rank 和 world_size。
  2. 指定 init_method(一個 URL 字串),它指示在哪裡/如何發現對等點。

如果兩者都沒有指定,init_method則假定為“env://”。因此大家可以看到,store 和 init_method 是互斥的

init_process_group 的引數具體如下:

  • 後端 – 要使用的後端。有效值包括mpigloo,和nccl。該欄位應作為小寫字串(例如"gloo")給出,也可以通過Backend屬性(例如Backend.GLOO)訪問 。如果在nccl後端每臺機器上使用多個程式,則每個程式必須對其使用的每個 GPU 具有獨佔訪問許可權,因為在程式之間共享 GPU 可能會導致死鎖。
  • init_method – 指定如何初始化程式組的 URL。如果未指定init_methodstore指定,則預設為“env://” 。與 store互斥。
  • world_size – 參與作業的程式數。如果store指定,則 world_size 為必需。
  • rank – 當前程式的等級(它應該是一個介於 0 和world_size-1之間的數字)。如果store指定,則 rank 為必需。
  • store – 所有 worker 都可以訪問的鍵/值儲存,用於交換連線/地址資訊。與init_method 互斥。
  • timeout – 針對程式組執行的操作超時。預設值等於 30 分鐘。這適用於gloo後端。對於nccl,這僅在環境變數NCCL_BLOCKING_WAITNCCL_ASYNC_ERROR_HANDLING設定為 1 時 適用。
  • group_name – 組名。
  • pg_options ( Process Group Options , optional ) – 程式組選項,指定在構建特定程式組期間需要傳入哪些附加選項。

0x02 初始化

2.1 初始化方法

目前DDP模組支援三種初始化方式:

  • Environment variable initialization
  • Shared file-system initialization:init_method='file:///mnt/nfs/sharedfile'
  • TCP initialization :init_method='tcp://10.1.1.20:23456'

環境變數

此方法將從環境變數中讀取配置,是允許完全自定義獲取資訊的方式。通過在所有機器上設定以下四個環境變數,所有程式都可以正常連線到master(就是 rank 0 程式)以獲取其他程式的資訊,並最終與它們握手。

  • MASTER_PORT:rank 0 程式的機器上的埠。
  • MASTER_ADDR:rank 0 程式的機器上的 IP 地址。
  • WORLD_SIZE: 程式總數,因此master知道要等待多少worker。
  • RANK: 每個程式的rank,所以程式會知道自己是否是master。

共享檔案系統

共享檔案系統要求所有程式都可以訪問共享檔案系統,並將通過共享檔案協調它們。這意味著每個程式都將開啟檔案,寫入其資訊,並等待每個程式都這樣做。之後,所有所需的資訊都將可供所有流程使用。為了避免競爭條件,檔案系統必須通過fcntl支援鎖定 。

dist.init_process_group(
    init_method='file:///mnt/nfs/sharedfile',
    rank=args.rank,
    world_size=4)

TCP

TCP 初始化方式是通過提供rank 0程式的IP和埠來實現的,在這裡,所有worker都可以連線到等級為 0 的程式並交換有關如何相互聯絡的資訊。

dist.init_process_group(
    init_method='tcp://10.1.1.20:23456',
    rank=args.rank,
    world_size=4)

2.2 init_method VS store

我們很好奇,為什麼要有 init_method 和 store 這兩個引數?

通過看 init_process_group 程式碼我們可以發現以下規律。

  • 當 MPI 時候, init_method 沒有用處。

  • 在非 MPI 後端時候,如果沒有 store 引數,則使用 init_method 構建一個store。

所以最終還是落到了 store 之上,store才是其作用的實體

        if store is None:
            rendezvous_iterator = rendezvous(
                init_method, rank, world_size, timeout=timeout
            )
            store, rank, world_size = next(rendezvous_iterator)
            store.set_timeout(timeout)

init_process_group 程式碼如下:

def init_process_group(backend,
                       init_method=None,
                       timeout=default_pg_timeout,
                       world_size=-1,
                       rank=-1,
                       store=None,
                       group_name='',
                       pg_options=None):

    global _pg_group_ranks
    global _backend
    global _default_pg_init_method

    if store is not None:
        assert world_size > 0, 'world_size must be positive if using store'
        assert rank >= 0, 'rank must be non-negative if using store'
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend)

    if backend == Backend.MPI:
          default_pg = _new_process_group_helper(
            -1,
            -1,
            [],
            Backend.MPI,
            None,
            group_name=group_name,
            timeout=timeout)
        _update_default_pg(default_pg)
    else:
        # backward compatible API
        if store is None:
            # 如果沒有store,還是要用init_method構建一個store。
            rendezvous_iterator = rendezvous(
                init_method, rank, world_size, timeout=timeout
            )
            store, rank, world_size = next(rendezvous_iterator)
            store.set_timeout(timeout)

        default_pg = _new_process_group_helper(
            world_size,
            rank,
            [],
            backend,
            store,
            pg_options=pg_options,
            group_name=group_name,
            timeout=timeout)
        _update_default_pg(default_pg)

    _pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())}  # type: ignore[attr-defined, index]
    _backend = _pg_map[GroupMember.WORLD][0]  # type: ignore[index]
    _default_pg_init_method = init_method

    # 省略

2.3 rendezvous

上面程式碼之中提到了 rendezvous,我們就來看看這個概念。

在我們可以執行集合演算法之前,參與的程式需要找到彼此並交換資訊才能夠進行通訊。我們稱這個過程為rendezvous。rendezvous過程的結果是一個三元組,其中包含一個共享鍵/值儲存(store),程式的等級(rank)和參與程式的總數。如果內建的rendezvous方法都不適用於您的執行環境,那麼您可以選擇註冊自己的rendezvous處理程式。在呼叫rendezvous函式時,選擇一個唯一的名稱並使用URL方案來標識它。

rendezvous 方法就是依據引數,選擇不同的handler來處理。

def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):

    # Append node-specific arguments.
    result = urlparse(url)
    if rank != -1 or world_size != -1:
        query_dict: Dict[str, Union[int, str]] = dict(
            # mypy doesn't allow dict() to accept List of values (#257)
            pair.split("=") for pair in filter(None, result.query.split("&"))  # type: ignore[arg-type, misc]
        )
        if rank != -1:
            query_dict["rank"] = rank
        if world_size != -1:
            query_dict["world_size"] = world_size

        result = result._replace(
            query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()]))
        )
        url = urlunparse(result)

    return _rendezvous_handlers[result.scheme](url, **kwargs)

handler 如下,你會發現,其實 handler 就是對應了初始化的三種方法

register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)

2.4 小結

從目前分析結果來看,我們得到了如下結論:

  • init_method 最終還是落到了 store 之上,store才是起作用的實體。
  • 參與的程式需要找到彼此並交換資訊才能夠進行通訊。這個過程被稱為rendezvous。

0x03 Store

我們給出一個正式的概念。Store 是分散式包(distributed package)所提供的分散式鍵值儲存,所有的 workers 都會訪問這個儲存以共享資訊以及初始化分散式包 。使用者可以通過顯式建立儲存來作為init_method的替代。目前有 3 種鍵值儲存:TCPStoreFileStore,和HashStore

我們接著上節繼續看 handler 概念。

3.1 _rendezvous_handlers

在 PyTorch 定義了一個全域性變數 _rendezvous_handlers,用來儲存如何返回 store 的方法,可以認為是工廠方法。

_rendezvous_handlers = {}

具體註冊方式是:

register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)

註冊程式碼如下,就是往全域性變數之中插入handler。

def register_rendezvous_handler(scheme, handler):
    """Registers a new rendezvous handler.
    Args:
        scheme (str): URL scheme to identify your rendezvous handler.
        handler (function): Handler that is invoked when the
            `rendezvous()` function is called with a URL that uses
            the corresponding scheme. It must be a generator function
            that yields the triplet.
    """
    global _rendezvous_handlers
    if scheme in _rendezvous_handlers:
        raise RuntimeError(
            "Rendezvous handler for {}:// already registered".format(scheme)
        )
    _rendezvous_handlers[scheme] = handler

3.2 handlers

如果仔細看 handlers 的程式碼,就會發現其就是返回了不同的 store,比如 _tcp_rendezvous_handler具體就是使用各種資訊建立 TCPStore,然後返回。

以下程式碼均刪除非關鍵程式碼。

3.2.1 _file_rendezvous_handler

這裡返回了FileStore。

def _file_rendezvous_handler(url: str, **kwargs):

    result = urlparse(url)
    path = result.path
    query: Dict[str, str]
    # mypy doesn't allow dict() to accept List of values (#257)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    store = FileStore(path, world_size)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using file:// method")

3.2.2 _tcp_rendezvous_handler

這裡返回了 TCPStore。

def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
    result = urlparse(url)
    query: Dict[str, Union[int, str]]
    # mypy doesn't allow dict() to accept List of values (#257)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    start_daemon = rank == 0
    assert result.hostname is not None
    store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using tcp:// method")

3.2.3 _env_rendezvous_handler

居然也返回了 TCPStore,但是其會從環境變數中提取需要的資訊。

def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):

    result = urlparse(url)
    query: Dict[str, Union[int, str]]
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) 
    rank: Optional[Union[str, int]]
    world_size: Optional[Union[str, int]]
    master_port: Optional[Union[str, int]]

    if "rank" in query:
        rank = int(query["rank"])
    else:
        rank = int(_get_env_or_raise("RANK"))

    if "world_size" in query:
        world_size = int(query["world_size"])
    else:
        world_size = int(_get_env_or_raise("WORLD_SIZE"))

    master_addr = _get_env_or_raise("MASTER_ADDR")
    master_port = int(_get_env_or_raise("MASTER_PORT"))

    use_torchelastic_store = os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None)

    if use_torchelastic_store == str(True):
        worker_process_prefix = "/worker"
        # When TORCHELASTIC_USE_AGENT_STORE is set up, the worker process is assumed
        # to be invoked by the torchelastic agent. Torchelastic agent creates a tcp daemon thread
        # on the GROUP_RANK=0, as a result all user worker processes should create store with: daemon=False
        tcp_store = TCPStore(master_addr, master_port, world_size, False, timeout)
        yield (PrefixStore(worker_process_prefix, tcp_store), rank, world_size)
    else:
        # Start the TCP store daemon on the rank 0
        start_daemon = rank == 0
        store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
        yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using env:// method")

3.3 使用

3.3.1 使用 handler

如何使用 handler?在 init_process_group 之中有:

rendezvous_iterator = rendezvous(
    init_method, rank, world_size, timeout=timeout
)
store, rank, world_size = next(rendezvous_iterator)

rendezvous 具體就是依據 init_method 來選擇一個 _rendezvous_handler,然後 _rendezvous_handler 返回了 store。

def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
    # Append node-specific arguments.
    result = urlparse(url)
    if rank != -1 or world_size != -1:
        query_dict: Dict[str, Union[int, str]] = dict(
            # mypy doesn't allow dict() to accept List of values (#257)
            pair.split("=") for pair in filter(None, result.query.split("&"))  # type: ignore[arg-type, misc]
        )
        if rank != -1:
            query_dict["rank"] = rank
        if world_size != -1:
            query_dict["world_size"] = world_size

        result = result._replace(
            query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()]))
        )
        url = urlunparse(result)

    return _rendezvous_handlers[result.scheme](url, **kwargs)

3.3.2 使用 Store

我們繼續看如何使用 store。在 init_process_group 程式碼之中,接下來就使用了 store 來初始化程式組。

default_pg = _new_process_group_helper(
    world_size,
    rank,
    [],
    backend,
    store,
    pg_options=pg_options,
    group_name=group_name,
    timeout=timeout)
_update_default_pg(default_pg)
3.3.2.1 _new_process_group_helper

為了接著看 _new_process_group_helper,我們首先看看幾個全域性變數。以下幾個變數 ProcessGroup 資訊做了全域性儲存,比如 _pg_map[pg] = (Backend.NCCL, store)。

# Cached process groups
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
# For MPI pg, it is a map from ProcessGroup to (Backend, None)
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
# Process group's names, map from ProcessGroup to str
_pg_names: Dict[ProcessGroup, str] = {}
# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}

_new_process_group_helper 之中得到了 store 引數之後,據此生成了一個 prefix_store,然後再根據這個 pre_store 來生成了 ProcessGroupGloo。_new_process_group_helper 程式碼具體如下:

def _new_process_group_helper(world_size,
                              rank,
                              group_ranks,
                              backend,
                              store,
                              pg_options=None,
                              group_name=None,
                              timeout=default_pg_timeout):
    """
    Create a new distributed process group.

    This function must be called by ALL processes in the global group, even if
    the calling process is not part of the newly created group. In that case,
    this function returns GroupMember.NON_GROUP_MEMBER.

    This function is called with ``group_ranks == []`` for the default group.
    """
    global _pg_map
    global _group_count
    global _pg_names

    if not group_name:
        group_name = str(_group_count)
        _group_count += 1

    # The list of group ranks is empty if we're creating the default group.
    is_default_group = (len(group_ranks) == 0)

    backend = Backend(backend)
    pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL]
    if backend == Backend.MPI: # 沒有使用store
        pg = ProcessGroupMPI.create(group_ranks)
        if not pg:
            return GroupMember.NON_GROUP_MEMBER
        _pg_map[pg] = (Backend.MPI, None)
        _pg_names[pg] = group_name
    else:
      	# 這裡會使用store
      
        # If this is a subgroup (which means group_ranks is specified),
        # we check if the current process is a member of the new group.
        if not is_default_group:
            global_rank = _get_default_group().rank()
            if global_rank not in group_ranks:
                return GroupMember.NON_GROUP_MEMBER

        # Use the group name as prefix in the default store, such that
        # a single store can be reused by multiple groups.
        
        prefix_store = PrefixStore(group_name, store) # 構建了 PrefixStore

        if backend == Backend.GLOO:
            pg = ProcessGroupGloo(
                prefix_store, # 使用PrefixStore構建程式組
                rank,
                world_size,
                timeout=timeout)
            _pg_map[pg] = (Backend.GLOO, store)
            _pg_names[pg] = group_name
        elif backend == Backend.NCCL:
            if pg_options is not None:
                assert isinstance(pg_options, ProcessGroupNCCL.Options), \
                    "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
            else:
                # default pg_options for NCCL
                pg_options = ProcessGroupNCCL.Options()
                pg_options.is_high_priority_stream = False
                pg_options._timeout = timeout

            pg = ProcessGroupNCCL(
                prefix_store, # 使用PrefixStore構建程式組
                rank,
                world_size,
                pg_options)
            _pg_map[pg] = (Backend.NCCL, store)
            _pg_names[pg] = group_name
        else:
            pg = getattr(Backend, backend.upper())(
                prefix_store,
                rank,
                world_size,
                timeout)
            _pg_map[pg] = (backend, store)
            _pg_names[pg] = group_name

    return pg
3.3.2.2 ProcessGroupGloo

在 ProcessGroupGloo 之中有具體使用,比如在PrefixStore之上生成了一個GlooStore,利用 PrefixStore 建立網路等等。

ProcessGroupGloo::ProcessGroupGloo(
    const c10::intrusive_ptr<Store>& store,
    int rank,
    int size,
    c10::intrusive_ptr<Options> options)
    : ProcessGroup(rank, size),
      store_(new GlooStore(store)), // 在PrefixStore之上生成了一個GlooStore
      options_(options),
      stop_(false),
      collectiveCounter_(0) {
  auto& devices = options->devices;

  contexts_.reserve(options->devices.size());
  for (size_t i = 0; i < options->devices.size(); i++) {
    auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
    // 又生成了一個PrefixStore
    auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_);
    context->setTimeout(options->timeout);
    // 利用 PrefixStore 建立網路
    context->connectFullMesh(store, options->devices[i]);
    contexts_.push_back(std::move(context));
  }

  // Every worker thread stores the AsyncWork object it's currently
  // working on in the workInProgress_ vector. It must have size equal
  // to the number of workers such that they can simply index into it
  // using the worker index they are started with.
  workInProgress_.resize(options->threads);

  threads_.resize(options->threads);
  for (size_t i = 0; i < threads_.size(); i++) {
    threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i);
  }
}

在下面程式碼之中,也有對store_的使用,比如等待,存取。

void ProcessGroupGloo::setSequenceNumberForGroup() {
  if (rank_ == 0) {
    // Create and broadcast sequence number
    auto seq = 1 + rand();
    sequenceNum_ = c10d::SequenceNum(seq);
    std::vector<char> values = c10d::toVec<char>(seq, kBytes);
    store_->set(kSeqNumStoreKey, values); // 存value
  } else {
    // Read rank 0's sequence number from store.
    sequenceNum_ = c10d::SequenceNum();
    store_->wait({kSeqNumStoreKey}, options_->timeout); // 等待
    std::vector<char> values = store_->get(kSeqNumStoreKey); // 取value
    uint64_t num = c10d::fromVec<char>(values);
    sequenceNum_->set(num);
  }
}  

3.4 小結

從目前分析結果來看,我們擴充結論如下:

  • init_method 最終還是落到了 store 之上,store才是起作用的實體。
  • 參與的程式需要找到彼此並交換資訊才能夠進行通訊。這個過程被稱為rendezvous。
  • rendezvous 其實就是返回了某一種store 以供後續通訊使用。
  • 在程式組之中,會使用 store 來構建通訊,等待,存取等。

我們接下來選擇 TCPStore進行相信分析。

0x04 TCPStore

TCPStore 是基於 TCP 的分散式鍵值儲存實現。伺服器儲存/儲存資料,而儲存客戶端可以通過 TCP 連線到伺服器儲存並執行諸如set()插入鍵值對、get()檢索鍵值對等操作。系統中應該有一個初始化完畢的TCPStore儲存伺服器,因為儲存客戶端將等待這個儲存服務以建立連線。

TCPStore 的引數如下:

  • host_name ( str ) – 主機名或 IP 地址。儲存伺服器在其上執行。
  • port ( int ) – 儲存伺服器在這個埠上偵聽傳入請求。
  • world_size ( int , optional ) – 使用者總數。
    • world_size = 客戶端數 + 1,1 代表伺服器。
    • 預設值為 -1(負值表示不固定的使用者數)。
  • is_master ( bool , optional ) – 初始化儲存伺服器時為真,初始化儲存客戶端時為假。預設值為假。
  • timeout ( timedelta , optional ) – store在初始化期間,以及get()和 wait()方法使用的超時時間。預設為 timedelta(seconds=300)。
  • wait_for_worker ( bool , optional ) – 是否等待所有worker與儲存伺服器連線。這僅在 world_size 為固定值時適用。預設值為真。

使用例子如下:

import torch.distributed as dist
from datetime import timedelta
# Run on process 1 (server)
server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
# Run on process 2 (client)
client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
# Use any of the store methods from either the client or server after initialization
server_store.set("first_key", "first_value")
client_store.get("first_key")

或者

    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> # This will throw an exception after 10 seconds
    >>> store.wait(["bad_key"], timedelta(seconds=10))

從例子上看,就是簡單的 server,client 或者說 master, worker 的關係,我們接下來仔細分析。

4.1 TCPStore in python

在 Python 世界之中,就是簡單的設定了 host 和 port。

class TCPStore(Store):
    def __init__(self, host_name, port, world_size=-1, is_master=False, timeout=None, *args, **kwargs): # real signature unknown; NOTE: unreliably restored from __doc__ 
        pass

    host = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default
    """Gets the hostname on which the store listens for requests."""

    port = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default
    """Gets the port number on which the store listens for requests."""

我們需要深入到 C++ 世界看看。

4.2 TCPStore in CPP

4.2.1 API介面

首先,C++之中的 TCPStore 可以認為是一個API介面,其定義如下:

class TCPStore : public Store {
 public:
  explicit TCPStore(
      const std::string& masterAddr,
      PortType masterPort,
      c10::optional<int> numWorkers = c10::nullopt_t(-1),
      bool isServer = false,
      const std::chrono::milliseconds& timeout = kDefaultTimeout,
      bool waitWorkers = true);

  virtual ~TCPStore();

  void set(const std::string& key, const std::vector<uint8_t>& value) override;
  std::vector<uint8_t> compareSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& desiredValue) override;
  std::vector<uint8_t> get(const std::string& key) override;
  int64_t add(const std::string& key, int64_t value) override;
  bool deleteKey(const std::string& key) override;

  // NOTE: calling other TCPStore APIs inside the callback is NOT threadsafe
  // watchKey() is a blocking operation. It will register the socket on
  // TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will
  // return once it has verified the callback is registered on both background
  // threads. Only one thread can call watchKey() at a time.
  void watchKey(const std::string& key, WatchKeyCallback callback) override;
  bool check(const std::vector<std::string>& keys) override;
  int64_t getNumKeys() override;
  void wait(const std::vector<std::string>& keys) override;
  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override;
  // Waits for all workers to join.
  void waitForWorkers();
  // Returns the hostname used by the TCPStore.
  const std::string& getHost() const noexcept;
  // Returns the port used by the TCPStore.
  PortType getPort() const noexcept;

 private:
  int64_t addHelper_(const std::string& key, int64_t value);
  std::vector<uint8_t> getHelper_(const std::string& key);
  void waitHelper_(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout);

  std::mutex watchKeyMutex_;
  bool isServer_;
  int storeSocket_ = -1; // 
  int listenSocket_ = -1; // 
  int masterListenSocket_ = -1; // master 在這裡監聽

  std::string tcpStoreAddr_;
  PortType tcpStorePort_;

  c10::optional<int> numWorkers_;
  const std::string initKey_;
  const std::string regularPrefix_;

  std::unique_ptr<TCPStoreMasterDaemon> tcpStoreMasterDaemon_ = nullptr;
  std::unique_ptr<TCPStoreWorkerDaemon> tcpStoreWorkerDaemon_ = nullptr;
};

4.2.2 socket用處

其成員變數之中最主要的是三個socket,或者說他們是 store 的精華(難點)所在。

  int storeSocket_ = -1; // 
  int listenSocket_ = -1; // 
  int masterListenSocket_ = -1; // master 在這裡監聽
4.2.2.1 業務分工

具體解釋如下(後面還會結合程式碼繼續分析):

  • masterListenSocket_ 是 listen 在 masterPort 之上。
    • tcpStoreMasterDaemon_本身是一個master,就是為整個 TCPStore提供服務的 server。
    • tcpStoreMasterDaemon_ 使用 tcputil::addPollfd(fds, storeListenSocket_, POLLIN) 來監聽 masterListenSocket_
    • key-value 就是std::unordered_map<std::string, std::vector<uint8_t>> tcpStore。
  • storeSocket_ 在 tcpStoreWorkerDaemon_ 之上,其連線到 masterListenSocket_ : masterPort 之上。
    • storeSocket_ 的作用是封裝面對 master port 的操作,使用者只管 set,get 等操作,不用知道 master port。
    • set(key, data) 的作用就是通過 storeSocket_ 向master 傳送一個設定key : value 的請求。
    • tcpStoreMasterDaemon_ 監聽到socket變化,就開始相應。
    • tcpStoreMasterDaemon_ 內部把 key : value 新增到 std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_ 之上。
  • listenSocket_ 在 tcpStoreWorkerDaemon_ 之上,也連線到 masterListenSocket_: masterPort 之上。下面有一個解耦,如註釋所述,It will register the socket on TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon
    • listenSocket_ 封裝了對 watchKey 的處理。Store Client 使用watchKey(const std::string& key, WatchKeyCallback callback) 請求註冊,即:
      • Worker 請求註冊。使用 tcpStoreWorkerDaemon_->setCallback(regKey, callback) 來為 tcpStoreWorkerDaemon_std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_ 之上新增一個 callback。
      • Worker 傳送請求。通過 listenSocket_ 給 master 發訊息 (key, WATCH_KEY),告訴master,如果 key 的 value 有變化,就呼叫這個 callback。
    • Master 執行註冊。Master 接到 WATCH_KEY 訊息之後進行註冊,呼叫 watchHandler,使用 watchedSockets_[key].push_back(socket) 來配置,告訴自己,如果這個 key 有變化,就給這個 socket 發訊息。
    • Master通知Worker。在 TCPStoreMasterDaemon::setHandler 之中,如果設定了新 value 之後,呼叫 sendKeyUpdatesToClients,其會遍歷 watchedSockets_[key],如果有 socket,就給 socket 傳送訊息變化通知。
    • Worker執行callback。所以如果 key 有變化,就在 tcpStoreWorkerDaemon_ 之中呼叫了這個 callback。
4.2.2.2 Set 例子

我們首先看看 Set 的例子如下,就是 Worker 通過 socket 來在 Master 之上設定 value。

                                                                          +
+----------------------------------------------------------------------+  |  +----------------------------------------------+
| TCPStore                                                      Master |  |  | TCPStore                              Worker |
|                                                                      |  |  |                                              |
|                                                                      |  |  |                                              |
|                                                                      |  |  |                                              |
|   +------------------------------------------------------------+     |  |  |                                              |
|   | TcpStoreMasterDaemon_                            MasterPort|     |  |  |                                              |
|   |                                                            |     |  |  |                                              |
|   |    TCPStore.masterListenSocket_                            |     |  |  |      +---------------------------------+     |
|   |                                                            |     |  |  |      | set(key, value)                 |     |
|   |                                                            |     |  |  |      |                                 |     |
|   |    tcpStore_[key] = value  <------------------------------------------------+ |    storeSocket_                 |     |
|   |                                                            |     |  |  |      |                                 |     |
|   |                                                            |     |  |  |      +---------------------------------+     |
|   |                                                            |     |  |  |                                              |
|   +------------------------------------------------------------+     |  |  |                                              |
|                                                                      |  |  |                                              |
+----------------------------------------------------------------------+  |  +----------------------------------------------+
                                                                          +

手機如下:

4.2.2.3 Set 和 watchKey 結合

Set 和 watchKey 結合起來的示意圖如下(worker請求註冊,具體執行回撥;master執行註冊,通知worker執行回撥):

  1. Worker 請求註冊。Store Client 使用watchKey(const std::string& key, WatchKeyCallback callback) 就是使用 tcpStoreWorkerDaemon_->setCallback(regKey, callback) 來為 tcpStoreWorkerDaemon_std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_ 之上新增一個callback。
  2. Worker 傳送請求。Worker 通過 listenSocket_ 給 master 發訊息 (key, WATCH_KEY),告訴master,如果 key 的 value 有變化,就呼叫這個 callback。
  3. Master 執行註冊。Master 接到 WATCH_KEY 訊息之後,呼叫 watchHandler,使用 watchedSockets_[key].push_back(socket) 來配置,告訴自己,如果這個 key 有變化,就給這個 socket 發訊息。
  4. 下面我們假設 Store Client(這裡假設是同一個worker設定,實際上可能是不同worker)設定了一個 value。
  5. Master通知Worker。Master 在 TCPStoreMasterDaemon::setHandler 之中,如果設定了新 value 之後,呼叫 sendKeyUpdatesToClients,其會遍歷 watchedSockets_[key],如果有 socket,就給 socket 傳送訊息變化通知。
  6. Worker執行callback。如果 key 有變化,就在 tcpStoreWorkerDaemon_ 之中呼叫了這個 callback。
+----------------------------------------------------------------------+  +  +------------------------------------------------------------------------+
| TCPStore                                                      Master |  |  | TCPStore                                                        Worker |
|                                                                      |  |  |                                                                        |
|   +------------------------------------------------------------+     |  |  |                                                                        |
|   | TcpStoreMasterDaemon_                            MasterPort|     |  |  |      +---------------------------------+                               |
|   |                                                            |     |  |  |      |                                 |                               |
|   |                                                  2         |     |  |  |      | watchKey(key, callback) +----------------------+                |
|   |           TCPStore.masterListenSocket_   <----------------------------------+ |                                 |              |                |
|   |                       +                                    |     |  |  |      |    listenSocket_                |              |                |
|   |                       | 3                                  |     |  |  |      |                                 |            1 |                |
|   |                       v                                    |     |  |  |      |                                 |              |                |
|   |           watchedSockets_[key] = socket                    |     |  |  |      +---------------------------------+              |                |
|   |                                                            |     |  |  |                                                       |                |
|   |  +-------------------------------------------------+       |     |  |  |                                                       |                |
|   |  |                                                 |       |     |  |  |                                                       |                |
|   |  |    setHandler                                   |       |     |  |  |   +----------------------------------------------------------------+   |
|   |  |                                                 |       |     |  |  |   | TCPStoreWorkerDaemon                              |            |   |
|   |  |                                                 |       |     |  |  |   |                                                   v            |   |
|   |  |       tcpStore_[key] = newData                  |       |     |  |  |   |   unordered_map<string, WatchKeyCallback> keyToCallbacks_      |   |
|   |  |                   +                             |       |     |  |  |   |                                                                |   |
|   |  |                   |                             |       |     |  |  |   |   TCPStore.listenSocket_                                       |   |
|   |  |                   |                             |       |     |  |  |   |                                                                |   |
|   |  |                   v                             |       |     |  |  |   |  +----------------------------------------------------------+  |   |
|   |  |       sendKeyUpdatesToClients                   |       |     |  |  |   |  | run                                                      |  |   |
|   |  |                   +                             |       |  5  |  |  |   |  |                                                          |  |   |
|   |  |                   |                             |  +---------------------->+                                        6                 |  |   |
|   |  |                   |                             |  |    |     |  |  |   |  |       callbackHandler +-----> keyToCallbacks_(callback)  |  |   |
|   |  |                   v                             |  |    |     |  |  |   |  |                                                          |  |   |
|   |  |                                                 |  |    |     |  |  |   |  +----------------------------------------------------------+  |   |
|   |  |    for (int socket : watchedSockets_[key]){     |  |    |     |  |  |   +----------------------------------------------------------------+   |
|   |  |       tcputil::sendString(socket, key, true) +-----+    |     |  |  |                                                                        |
|   |  |    }                                            |       |     |  |  |                                                                        |
|   |  |                                                 |       |     |  |  |       +------------------------+                                       |
|   |  |                                                 |       |  4  |  |  |       | set(key, newData)      |                                       |
|   |  |                                                 | <-----------------------+ |                        |                                       |
|   |  +-------------------------------------------------+       |     |  |  |       |                        |                                       |
|   |                                                            |     |  |  |       +------------------------+                                       |
|   +------------------------------------------------------------+     |  |  |                                                                        |
|                                                                      |  |  |                                                                        |
+----------------------------------------------------------------------+  +  +------------------------------------------------------------------------+

手機如下:

4.2.3 功能函式

TCPStore 提供了若干功能函式。

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  std::string regKey = regularPrefix_ + key;
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
  tcputil::sendString(storeSocket_, regKey, true);
  tcputil::sendVector<uint8_t>(storeSocket_, data);
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
  std::string regKey = regularPrefix_ + key;
  return getHelper_(regKey);
}

int64_t TCPStore::add(const std::string& key, int64_t value) {
  std::string regKey = regularPrefix_ + key;
  return addHelper_(regKey, value);
}

int64_t TCPStore::addHelper_(const std::string& key, int64_t value) {
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
  tcputil::sendString(storeSocket_, key, true);
  tcputil::sendValue<int64_t>(storeSocket_, value);
  return tcputil::recvValue<int64_t>(storeSocket_);
}

這些功能函式是呼叫如下基礎函式來傳送接收。

// this is only for convenience when sending rvalues
template <typename T>
void sendValue(int socket, const T& value, bool moreData = false) {
  sendBytes<T>(socket, &value, 1, moreData);
}

template <typename T>
T recvValue(int socket) {
  T value;
  recvBytes<T>(socket, &value, 1);
  return value;
}

4.2.4 構建函式

我們從構建函式可以看到:

  • 對於儲存伺服器角色,主要就是啟動了 tcpStoreMasterDaemon_,注意在啟動了 daemon 之後,server 就進入了等待worker狀態,不會啟動接下來程式碼中的 tcpStoreWorkerDaemon_
  • 對於儲存客戶端,則啟動了 tcpStoreWorkerDaemon_。
// TCPStore class methods
TCPStore::TCPStore(
    const std::string& masterAddr,
    PortType masterPort,
    c10::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : Store(timeout),
      isServer_(isServer),
      tcpStoreAddr_(masterAddr),
      tcpStorePort_(masterPort),
      numWorkers_(numWorkers),
      initKey_("init/"),
      regularPrefix_("/") {
  tcputil::socketInitialize();
  if (isServer_) { // 如果設定了是server,就在masterPort上監聽
    // Opening up the listening socket
    std::tie(masterListenSocket_, tcpStorePort_) = tcputil::listen(masterPort);
  }
  try {
    if (isServer_) { // 如果設定了是server,就啟動 tcpStoreMasterDaemon_
      // Now start the daemon
      tcpStoreMasterDaemon_ =
          std::make_unique<TCPStoreMasterDaemon>(masterListenSocket_);
    }
    // Connect to the daemon
    // worker 會與 master port 建立聯絡
    storeSocket_ = tcputil::connect(
        tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);
    if (numWorkers.value_or(-1) >= 0 && waitWorkers) {
      waitForWorkers(); // server 等待 worker
    }

    // socket to handle requests from server,因為 master 也會給 worker 發訊息
    listenSocket_ = tcputil::connect(
        tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);
    // 啟動 worker daemon
    tcpStoreWorkerDaemon_ =
        std::make_unique<TCPStoreWorkerDaemon>(listenSocket_);
  } catch (const std::exception&) {
    if (isServer_) {
      tcpStoreMasterDaemon_ = nullptr;
      tcputil::closeSocket(masterListenSocket_);
    }
    tcpStoreWorkerDaemon_ = nullptr;
    if (listenSocket_ != -1) {
      tcputil::closeSocket(listenSocket_);
    }
    if (storeSocket_ != -1) {
      tcputil::closeSocket(storeSocket_);
    }
    throw;
  }
}

server 會使用如下函式來等待 worker.

void TCPStore::waitForWorkers() {
  addHelper_(initKey_, 1);
  // Let server block until all workers have completed, this ensures that
  // the server daemon thread is always running until the very end
  if (isServer_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      std::vector<uint8_t> value = getHelper_(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= numWorkers_.value_or(-1)) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        break;
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}

4.2.5 TCPStoreWorkerDaemon

這個 daemon 程式只是用來處理 watchKey。

// Separate thread that is launched on all instances (including master)
// Right now only handles callbacks registered from watchKey()
class TCPStoreWorkerDaemon : public BackgroundThread {
 public:
  explicit TCPStoreWorkerDaemon(int listenSocket);
  // Set the callback to run key change
  void setCallback(std::string key, WatchKeyCallback cb);
  void waitForCallbackRegistration() {
    // Block until callback has been registered successfully
    std::unique_lock<std::mutex> callbackRegistrationLock(
        callbackRegistrationMutex_);
    callbackRegisteredCV_.wait(
        callbackRegistrationLock, [&] { return callbackRegisteredData_; });

    // Reset payload for next callback
    callbackRegisteredData_ = false;
  }
  void setCallbackRegistered() {
    callbackRegisteredData_ = true;
    callbackRegisteredCV_.notify_one();
  }

 private:
  void run();
  void callbackHandler(int socket);
  // List of callbacks map each watched key
  std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_;
  std::mutex keyToCallbacksMutex_;
  std::mutex callbackRegistrationMutex_;
  std::condition_variable callbackRegisteredCV_;
  bool callbackRegisteredData_ = false;
};


其構建函式只是建立一個執行緒。

// TCPStoreListener class methods
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(int listenSocket)
    : BackgroundThread(listenSocket) {
  daemonThread_ = std::thread(&TCPStoreWorkerDaemon::run, this);
}
4.2.5.1 watchKey

Client Store 使用watchKey(const std::string& key, WatchKeyCallback callback) 的作用是往master註冊監聽key:

  • Worker 請求註冊。使用 tcpStoreWorkerDaemon_->setCallback(regKey, callback) 來為 tcpStoreWorkerDaemon_std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_ 之上新增一個 callback。
  • Worker 傳送請求。通過 listenSocket_ 給 master 發訊息 (key, WATCH_KEY),告訴master,如果 key 的 value 有變化,就呼叫這個 callback。
  • 然後使用 waitForCallbackRegistration 等待註冊完成。
void TCPStore::watchKey(const std::string& key, WatchKeyCallback callback) {
  // Only allow one thread to perform watchKey() at a time
  const std::lock_guard<std::mutex> watchKeyLock(watchKeyMutex_);

  // Register callback with TCPStoreMasterDaemon to call TCPStoreWorkerDaemon on
  // key change
  std::string regKey = regularPrefix_ + key;
  tcpStoreWorkerDaemon_->setCallback(regKey, callback);
  tcputil::sendValue<QueryType>(listenSocket_, QueryType::WATCH_KEY);
  tcputil::sendString(listenSocket_, regKey);

  // Block until callback has been registered successfully
  tcpStoreWorkerDaemon_->waitForCallbackRegistration();
}
4.2.5.2 執行

其執行分為 windows 和 其他系統,但是主要就是收到了業務key,然後進行相關業務處理。

  • Master 執行註冊。Master 接到 WATCH_KEY 訊息之後,呼叫 watchHandler,使用 watchedSockets_[key].push_back(socket) 來配置,告訴自己,如果這個 key 有變化,就給這個 socket 發訊息。
  • Master通知Worker。在 TCPStoreMasterDaemon::setHandler 之中,如果設定了新 value 之後,呼叫 sendKeyUpdatesToClients,其會遍歷 watchedSockets_[key],如果有 socket,就給 socket 傳送訊息變化通知。
  • Worker執行callback。所以如果 key 有變化,就在 tcpStoreWorkerDaemon_ 之中呼叫了這個 callback。
#ifdef _WIN32 
void TCPStoreWorkerDaemon::run() { // 這裡是windows系統
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);

  while (true) {
    // Check control and exit early if triggered
    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rvPoll = WaitForSingleObject(ghStopEvent_, 0);
      if (rvPoll != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[0].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      auto rvData = WaitForSingleObject(ghStopEvent_, 0);
      if (rvData != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[0].fd); // 業務處理
  }
}
#else
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP);
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);

  while (true) {
    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // Check control and exit early if triggered
    // The pipe receives an event which tells us to shutdown the listener thread
    if (fds[0].revents != 0) {
      // Will be POLLUP when the pipe is closed
      if (fds[0].revents ^ POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[0].revents));
      }
      break;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[1].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[1].fd); // 業務處理
  }
}
#endif

4.2.6 TCPStoreMasterDaemon

這裡的 std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_; 是真實的 kv。

所以,TCPStoreMasterDaemon 就是負責對 kv 的操作,比如存取。

// Separate thread that is only launched on master
class TCPStoreMasterDaemon : public BackgroundThread {
 public:
  explicit TCPStoreMasterDaemon(int storeListenSocket);

 private:
  void run();
  void queryFds(std::vector<struct pollfd>& fds);
  void query(int socket);

  // The master runs on a single thread so only
  // one handler can be executed at a time
  void setHandler(int socket);
  void compareSetHandler(int socket);
  void addHandler(int socket);
  void getHandler(int socket) const;
  void checkHandler(int socket) const;
  void getNumKeysHandler(int socket) const;
  void deleteHandler(int socket);
  void waitHandler(int socket);
  void watchHandler(int socket);

  bool checkKeys(const std::vector<std::string>& keys) const;
  // Helper function to alerts waiting workers, used in setHandler, getHandler
  void wakeupWaitingClients(const std::string& key);
  // Helper function used when the key is changed
  // used in setHandler, addHandler, getHandler, deleteHandler
  void sendKeyUpdatesToClients(
      const std::string& key,
      const enum WatchResponseType& type,
      std::vector<uint8_t>& oldData,
      std::vector<uint8_t>& newData);
  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
  // From key -> the list of sockets waiting on the key
  std::unordered_map<std::string, std::vector<int>> waitingSockets_;
  // From socket -> number of keys awaited
  std::unordered_map<int, size_t> keysAwaited_;
  // From key -> the list of sockets watching the key
  std::unordered_map<std::string, std::vector<int>> watchedSockets_;
};
4.2.6.1 執行

TCPStoreMasterDaemon 就是等待在 socket 之上,即 masterListenSocket_ 是 listen 在 masterPort 之上。

  • tcpStoreMasterDaemon_ 使用 tcputil::addPollfd(fds, storeListenSocket_, POLLIN) 來監聽 masterListenSocket_
  • tcpStoreMasterDaemon_本身成為一個master,就是為整個 TCPStore提供服務的 server。
  • key-value 就是std::unordered_map<std::string, std::vector<uint8_t>> tcpStore。
#ifdef _WIN32
void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (size_t i = 0; i < sockets_.size(); i++) {
      fds[i].revents = 0;
    }

    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rv = WaitForSingleObject(ghStopEvent_, 0);
      if (rv != WAIT_TIMEOUT) {
        finished = true;
        break;
      }
      continue;
    }

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) { // 收到了訊息
      if (!(fds[0].revents & POLLIN)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      tcputil::addPollfd(fds, sockFd, POLLIN);
    }
    queryFds(fds); // 業務處理
  }
}
#else

void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);
  // Push the read end of the pipe to signal the stopping of the daemon run
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (size_t i = 0; i < sockets_.size(); i++) {
      fds[i].revents = 0;
    }

    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      if (fds[0].revents ^ POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      tcputil::addPollfd(fds, sockFd, POLLIN);
    }

    // The pipe receives an event which tells us to shutdown the daemon
    if (fds[1].revents != 0) { // 收到了訊息
      // Will be POLLUP when the pipe is closed
      if (fds[1].revents ^ POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[1].revents));
      }
      finished = true;
      break;
    }
    queryFds(fds); // 業務處理
  }
}
#endif
4.2.6.2 呼叫業務

queryFds 會根據 socket 監聽結果而呼叫不同業務。

void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
  // Skipping the fds[0] and fds[1],
  // fds[0] is master's listening socket
  // fds[1] is control pipe's reading fd, it is not for Windows platform
  for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
    if (fds[fdIdx].revents == 0) {
      continue;
    }

    // Now query the socket that has the event
    try {
      query(fds[fdIdx].fd); // 處理業務
    } catch (...) {
      tcputil::closeSocket(fds[fdIdx].fd);

      // Remove all the tracking state of the close FD
      for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
        for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
          if (*vecIt == fds[fdIdx].fd) {
            vecIt = it->second.erase(vecIt);
          } else {
            ++vecIt;
          }
        }
        if (it->second.size() == 0) {
          it = waitingSockets_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
        if (it->first == fds[fdIdx].fd) {
          it = keysAwaited_.erase(it);
        } else {
          ++it;
        }
      }
      fds.erase(fds.begin() + fdIdx);
      sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
      --fdIdx;
      continue;
    }
  }
}

4.2.6.4 處理業務

從 socket 之中讀取訊息,依據訊息內容來進行相關業務處理。

// query communicates with the worker. The format
// of the query is as follows:
// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, in the case of wait
// type of query | number of args | size of arg1 | arg1 | ...
void TCPStoreMasterDaemon::query(int socket) {
  QueryType qt;
  tcputil::recvBytes<QueryType>(socket, &qt, 1);
  if (qt == QueryType::SET) {
    setHandler(socket);

  } else if (qt == QueryType::COMPARE_SET) {
    compareSetHandler(socket);

  } else if (qt == QueryType::ADD) {
    addHandler(socket);

  } else if (qt == QueryType::GET) {
    getHandler(socket);

  } else if (qt == QueryType::CHECK) {
    checkHandler(socket);

  } else if (qt == QueryType::WAIT) {
    waitHandler(socket);

  } else if (qt == QueryType::GETNUMKEYS) {
    getNumKeysHandler(socket);

  } else if (qt == QueryType::DELETE_KEY) {
    deleteHandler(socket);

  } else if (qt == QueryType::WATCH_KEY) {
    watchHandler(socket);

  } else {
    throw std::runtime_error("Unexpected query type");
  }
}

新增

此處是處理新增 value 的業務。

void TCPStoreMasterDaemon::setHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> oldData;
  bool newKey = true;
  auto it = tcpStore_.find(key);
  if (it != tcpStore_.end()) {
    oldData = it->second;
    newKey = false;
  }
  tcpStore_[key] = newData;
  // On "set", wake up all clients that have been waiting
  wakeupWaitingClients(key);
  // Send key update to all watching clients
  newKey ? sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_CREATED, oldData, newData)
         : sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_UPDATED, oldData, newData);
}
獲取

出處處理獲取 value 的業務。

void TCPStoreMasterDaemon::getHandler(int socket) const {
  std::string key = tcputil::recvString(socket);
  auto data = tcpStore_.at(key);
  tcputil::sendVector<uint8_t>(socket, data);
}

watchKey

此處新增了想要監控的 key。

對於WATCH_KEY,給對應的key新增了一個socket,作為以後傳送通知的物件。

void TCPStoreMasterDaemon::watchHandler(int socket) {
  std::string key = tcputil::recvString(socket);

  // Record the socket to respond to when the key is updated
  watchedSockets_[key].push_back(socket);

  // Send update to TCPStoreWorkerDaemon on client
  tcputil::sendValue<WatchResponseType>(
      socket, WatchResponseType::KEY_CALLBACK_REGISTERED);
}
通知

如果key 有變化,就通知客戶端。

void TCPStoreMasterDaemon::sendKeyUpdatesToClients(
    const std::string& key,
    const enum WatchResponseType& type,
    std::vector<uint8_t>& oldData,
    std::vector<uint8_t>& newData) {
  for (int socket : watchedSockets_[key]) {
    tcputil::sendValue<WatchResponseType>(socket, type);
    tcputil::sendString(socket, key, true);
    tcputil::sendVector<uint8_t>(socket, oldData);
    tcputil::sendVector<uint8_t>(socket, newData);
  }
}

4.2.7 總結

我們總結圖例如下:

  • Master 之中使用MasterPort 進行監聽請求。
  • 關於存取value。
    • Worker 之中,storeSocket_ 被用來儲存/獲取value,對應下圖 數字 1。
    • 在 Master 之中對應了 tcpStore_。
  • 關於監控。
    • Worker 之中,listenSocket_ 被用來通知 Master 我需要監聽這個 key,對應下圖 數字 2。同時 worker 內部給這個 key 設定了 callback,對應了下圖 數字 3。
    • 監聽在 Master 之中對應了 watchedSockets_[key] = socket_
    • Master 之中,如果設定 value 時候,發現是一個被監控的 key,就通知 watchedSockets_[key],對應了下圖 數字 4。
    • Worker 之中會進行相關業務呼叫。
                                                                          +
+----------------------------------------------------------------------+  |  +------------------------------------------------------------------------+
| TCPStore                                                      Master |  |  | TCPStore                                                        Worker |
|                                                                      |  |  |                                                                        |
|   storeSocket_                                                       |  |  |                                                                        |
|                                                                      |  |  |                                                                        |
|   +------------------------------------------------------------+     |  |  |                                                                        |
|   | TcpStoreMasterDaemon_                            MasterPort|     |  |  |  1   +---------------------------------+                               |
|   |                                                            | <--------------+ | set(key, value)                 |                               |
|   |   unordered_map<string, vector<uint8_t> > tcpStore_+---+   |     |  |  |      |                                 |                               |
|   |                                                        |   |     |  |  |      |    storeSocket_                 |                               |
|   |   TCPStore.masterListenSocket_                         |   |     |  |  |      |                                 |                               |
|   |                                                        |   |     |  |  |      +---------------------------------+                               |
|   |   +-----------------------------------------------+    |   |     |  |  |                                                                        |
|   |   |  run                                          |    |   |     |  |  |  2   +---------------------------------+                               |
|   |   |                                               |    |   | <--------------+ |                                 |                               |
|   |   |    queryFds     query                         |    |   |     |  |  |      | watchKey(key, callback) +-------------------------------+       |
|   |   |                                               |    |   |     |  |  |      |                                 |        3              |       |
|   |   |    setHandler   getHandler                    |    |   |     |  |  |      |    listenSocket_                |                       |       |
|   |   |                                               |    |   |     |  |  |      |                                 |                       |       |
|   |   +-----------------------------------------------+    |   |     |  |  |      |                                 |                       |       |
|   |                                                        |   |     |  |  |      +---------------------------------+                       |       |
|   +------------------------------------------------------------+     |  |  |                                                                |       |
|                                                            |         |  |  |                                                                |       |
|                                                            |         |  |  |                                                                |       |
|                                                            |         |  |  |   +----------------------------------------------------------------+   |
|                                                            |         |  |  |   | TCPStoreWorkerDaemon                                       |   |   |
|                                                            |         |  |  |   |                                                            |   |   |
|                                                            |         |  |  |   |   unordered_map<string, WatchKeyCallback> keyToCallbacks_  |   |   |
|                                                            |         |  |  |   |                                                            |   |   |
|                                                            |         |  |  |   |   TCPStore.listenSocket_                              +----+   |   |
|                                                            |         |  |  |   |                                                       |        |   |
|                                                            |         |  |  |   |  +----------------------------------------------------------+  |   |
|                                                            |         |  |  |   |  | run                                                |     |  |   |
|                                                            |     4   |  |  |   |  |                                                    |     |  |   |
|                                                            +--------------------->+                                                    v     |  |   |
|                                                                      |  |  |   |  |       callbackHandler +-----> keyToCallbacks_(callback)  |  |   |
|                                                                      |  |  |   |  |                                                          |  |   |
|                                                                      |  |  |   |  +----------------------------------------------------------+  |   |
|                                                                      |  |  |   +----------------------------------------------------------------+   |
+----------------------------------------------------------------------+  +  +------------------------------------------------------------------------+

手機如下:

至此,我們梳理了初始化方法和Store這兩個概念,最終其實是Store這個概念在初始化過程中起了作用。我們也通過TCPStore 的分析知道了一個Store應該具備的功能,比如設定KV,監控某個key的變等等,正是這些功能才可以讓若干程式彼此知道對方的存在。

下一篇我們介紹程式組的概念,敬請期待。

相關文章