[原始碼解析] 深度學習流水線並行 PipeDream(5)--- 通訊模組

羅西的思考發表於2021-09-13

[原始碼解析] 深度學習流水線並行 PipeDream(5)--- 通訊模組

0x00 摘要

在前文中,我們介紹了PipeDream的總體架構,Profile階段,計算分割槽階段,模型轉換階段和執行時引擎,本文我們介紹PipeDream 的通訊模組,通訊模組是引擎的基礎,同時也是PyTorch DDP,P2P 如何使用的一個萬花筒和完美示例。

流水線並行其他文章連結如下:

[原始碼解析] 深度學習流水線並行Gpipe(1)---流水線基本實現

[原始碼解析] 深度學習流水線並行GPipe (2) ----- 梯度累積

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

[原始碼解析] 深度學習流水線並行之PipeDream(1)--- Profile階段

[原始碼解析] 深度學習流水線並行 PipeDream(2)--- 計算分割槽

[原始碼解析] 深度學習流水線並行 PipeDream(3)--- 轉換模型

[原始碼解析] 深度學習流水線並行 PipeDream(4)--- 執行時引擎

0x01 前言

通訊模組程式碼位於:runtime/communication.py。我們首先思考一下,通訊模組需要哪些功能?

  • 階段(Stage)之間的通訊,如果階段在不同機器上如何處理?在同一個機器上如何處理?

  • 因為是非同步通訊為主,不同節點的效能可能不同,是否需要一個快取機制來協調不同節點,類似背壓功能?

  • 深度學習引數眾多,涉及的張量和梯度眾多,層數眾多,每層的資料並行數目也不同,所以前向傳播和反向傳播如何保證按照確定次序執行?

  • 因為節點上需要進行前向,後向傳播,所以需要建立多個執行緒進行分別傳輸。

因此我們下面分析時候,就結合這些問題進行思考。

0x02 類定義

CommunicationHandler 負責在階段(Stage)之間的通訊。

  • 如果階段位於不同機器上,就使用 PyTorch p2p 的 send/recv。
  • 如果階段位於同一個機器上,則使用 PyTorch p2p 的 broadcast。

下面程式碼中,主要就是初始化各種成員變數,我們目前最熟悉的是和DDP相關的,比如init_process_group。

class CommunicationHandler(object):
    """ Handles communication between stages.

    For stages on different machines, use send/recv.
    For stages on same machine, use broadcast.
    """
    def __init__(self, master_addr, master_port, rank,
                 local_rank, num_ranks_in_server,
                 world_size, fp16, backend):
        """ Set up process groups.

        Note: To turn off broadcasting, set num_ranks_in_server = 1.
        """
        self.rank = rank
        self.local_rank = local_rank
        self.backend = backend
        self.num_ranks_in_server = num_ranks_in_server
        self.world_size = world_size
        self.fp16 = fp16
        assert num_ranks_in_server > 0

        # Initialize the distributed environment.
        # 以下是為了 DDP
        os.environ['MASTER_ADDR'] = master_addr
        os.environ['MASTER_PORT'] = str(master_port)
        dist.init_process_group(backend, rank=rank, world_size=world_size)
        assert dist.get_world_size() == self.world_size

        # Stores list of ranks of GPUs on the same server.
        self.ranks_in_server = []

        if num_ranks_in_server == 1:
            return

        # Stores information about tensors sent directly GPU-to-GPU.
        self.connection_list = []

        # Stores process groups (for broadcast() connections).
        self.process_groups = {}

        # Populate ranks_in_server.
        rank_of_first_gpu_in_server = rank - rank % num_ranks_in_server
        for connected_rank in range(
            rank_of_first_gpu_in_server,
            rank_of_first_gpu_in_server + num_ranks_in_server):
            if connected_rank == rank:
                continue
            self.ranks_in_server.append(connected_rank)
        assert len(self.ranks_in_server) == num_ranks_in_server - 1, \
            self.ranks_in_server

0x03 構建

3.1 初始化

前面章節中提到,當生成了CommunicationHandler之後,會呼叫initialize進行初始化。

        if self.comm_handler is not None:
            self.comm_handler.initialize(
                self.receive_ranks,
                self.send_ranks,
                self.tensor_tags,
                self.target_tensor_names,
                self.training_tensor_dtypes,
                self.rank_in_stage,
                self.num_ranks_in_stage,
                self.ranks_in_previous_stage,
                self.ranks_in_next_stage)

在初始化程式碼之中,完成如下操作,主要是:

  • 構建通訊需要的queue。
  • 構建傳送訊息的次序。
  • 構建程式組。
    def initialize(self, receive_ranks, send_ranks,
                   tensor_tags, target_tensor_names,
                   training_tensor_dtypes,
                   rank_in_stage,
                   num_ranks_in_stage,
                   ranks_in_previous_stage,
                   ranks_in_next_stage):
        """
        Initialize state needed for CommunicationHandler.
        """
        self.receive_ranks = receive_ranks
        self.send_ranks = send_ranks
        self.tensor_tags = tensor_tags
        self.target_tensor_names = target_tensor_names
        self.training_tensor_dtypes = training_tensor_dtypes
        self.rank_in_stage = rank_in_stage
        self.num_ranks_in_stage = num_ranks_in_stage
        self.ranks_in_previous_stage = ranks_in_previous_stage
        self.num_ranks_in_previous_stage = len(ranks_in_previous_stage)
        self.ranks_in_next_stage = ranks_in_next_stage
        self.num_ranks_in_next_stage = len(ranks_in_next_stage)

        self.setup_queues() # 構建通訊需要的queue
        self.setup_messaging_schedule() # 構建傳送訊息的次序
        self.create_process_groups() # 構建程式組

我們具體分析如下。

3.2 建立queue

Queue 的作用是作為 send,receive 的基礎,系統通過index找到哪一個queue,然後進行相應操作。

initialize 函式傳入了兩個ranks列表。

  • receive_ranks 就是本節點的輸入rank。
  • send_ranks 就是本節點的輸出rank。

ranks 列表舉例如下:

receive_ranks = {dict: 3}  # 這裡就是每個tensor對應的接收目標rank
 'out8' = {list: 1} [2] # out8 是tensor name, {list: 1} [2] 是 out8 對應的 ranks
 'out9' = {list: 1} [2] # 就是這幾個張量都要從 rank 2 接收
 'out10' = {list: 1} [2]
 __len__ = {int} 3

setup_queues 相應一共建立了4個queue列表:

  • forward_receive_queues :前向傳播過程中,接受張量的queue。對應了 receive_ranks
  • backward_send_queues : 後向傳播過程中,傳送張量的queue。對應了 receive_ranks。因為前向傳播中接受的物件,就是後向傳播中傳送的目標。
  • forward_send_queues : 前向傳播過程中,傳送張量的queue。對應了 send_ranks
  • backward_receive_queues :後向傳播過程中,接受張量的queue。對應了 send_ranks。因為前向傳播中傳送的目標就是後向傳播中接受的物件。

大致邏輯如下:

forward_receive_queues <-----> receive_ranks <------->  backward_send_queues
forward_send_queues  <------>  send_ranks    <------->  backward_receive_queues

以 forward_receive_queues 為例。

  • forward_receive_queues 這個列表包括多個queue。
  • receive_ranks 列表中包括多個 rank,每個rank在通訊過程之中,對應了一個張量,可以認為 receive_ranks 包括多個張量,由一個張量名字來對應。張量名字類似於:target_tensor_names = {"target", "target_length"}。
  • forward_receive_queues 列表之中,每一個queue對應了receive_ranks 之中的一個 張量。
  • 每個張量,對應一個唯一的tag,PipeDream的目的是讓每一個tag都有自己的process group,因為任何一個stage都有可能並行。
  • 針對這個張量和這個唯一的tag,註冊 [tag, rank] 到 connection_list。

具體如下:

    def setup_queues(self):
        """
        Setup queues for communication between main compute thread
        and helper communication threads. One queue per tensor
        in forward / backward direction.
        """
        self.forward_receive_queues = {}
        self.backward_receive_queues = {}
        self.forward_send_queues = {}
        self.backward_send_queues = {}
        self.num_forward_threads = 0
        self.num_backward_threads = 0

        self.target_receive_rank_counts = {}
        self.target_send_rank_counts = {}
        # Setup queues for each tensor to be received and sent.
        for input_name in self.receive_ranks: # 遍歷張量
            # 與 input_name 張量對應的queue,input_name 是張量名字
            self.forward_receive_queues[input_name] = []
            self.backward_send_queues[input_name] = []
            # 遍歷該張量對應的每個 ranks
            for i in range(len(self.receive_ranks[input_name])):
                self.forward_receive_queues[input_name].append(
                    threadsafe_queue.Queue())
                self.backward_send_queues[input_name].append(
                    threadsafe_queue.Queue())
                # 得到 rank
                target_receive_rank = self.receive_ranks[input_name][i]
                # 針對 rank,註冊張量
                self.register_tensor(
                    connected_rank=target_receive_rank,
                    tag=self.tensor_tags[input_name])
                if target_receive_rank not in self.target_receive_rank_counts:
                    self.target_receive_rank_counts[target_receive_rank] = 0
                self.target_receive_rank_counts[target_receive_rank] += 1
                self.num_forward_threads += 1
                self.num_backward_threads += 1
                
        for output_name in self.send_ranks: # 遍歷張量
            # 與 output_name 張量對應的queue
            self.backward_receive_queues[output_name] = []
            self.forward_send_queues[output_name] = []
            # 遍歷該張量對應的每個 ranks
            for i in range(len(self.send_ranks[output_name])):
                self.backward_receive_queues[output_name].append(
                    threadsafe_queue.Queue())
                self.forward_send_queues[output_name].append(
                    threadsafe_queue.Queue())
                # 得到 rank
                target_send_rank = self.send_ranks[output_name][i]
                # 針對 rank,註冊張量
                self.register_tensor(
                    connected_rank=target_send_rank,
                    tag=self.tensor_tags[output_name])
                if target_send_rank not in self.target_send_rank_counts:
                    self.target_send_rank_counts[target_send_rank] = 0
                self.target_send_rank_counts[target_send_rank] += 1
                self.num_forward_threads += 1
                self.num_backward_threads += 1

        # 單獨處理目標tensor
        for target_tensor_name in self.target_tensor_names:
            # Queues for target in forward pass.
            self.forward_receive_queues[target_tensor_name] = []
            self.forward_send_queues[target_tensor_name] = []

            if self.num_ranks_in_previous_stage > 0:
                self.receive_ranks[target_tensor_name] = self.ranks_in_previous_stage
                for i in range(len(self.receive_ranks[target_tensor_name])):
                    # 針對 rank,註冊張量
                    self.register_tensor(
                        connected_rank=self.receive_ranks[target_tensor_name][i],
                        tag=self.tensor_tags[target_tensor_name])
                    self.forward_receive_queues[target_tensor_name].append(
                        threadsafe_queue.Queue())
                    self.num_forward_threads += 1

            if self.num_ranks_in_next_stage > 0:
                self.send_ranks[target_tensor_name] = self.ranks_in_next_stage
                for i in range(len(self.send_ranks[target_tensor_name])):
                    self.register_tensor(
                        connected_rank=self.send_ranks[target_tensor_name][i],
                        tag=self.tensor_tags[target_tensor_name])
                    self.forward_send_queues[target_tensor_name].append(
                        threadsafe_queue.Queue())
                    self.num_forward_threads += 1

        print ("Send ranks: ", self.send_ranks)
        print ("Receive ranks: ", self.receive_ranks)

        # Queues for ack for forward pass-only runs as a clocking mechanism.
        # 單獨處理 ack 情況
        self.num_ack_threads = 0
        if "ack" in self.tensor_tags:
            self.backward_receive_queues["ack"] = []
            self.backward_send_queues["ack"] = []
            for i in range(self.num_ranks_in_previous_stage):
                # 針對 rank,註冊張量
                self.register_tensor(
                    connected_rank=self.ranks_in_previous_stage[i],
                    tag=self.tensor_tags["ack"])
                self.backward_send_queues["ack"].append(
                    threadsafe_queue.Queue())
                self.num_ack_threads += 1
            for i in range(self.num_ranks_in_next_stage):
                # 針對 rank,註冊張量
                self.register_tensor(
                    connected_rank=self.ranks_in_next_stage[i],
                    tag=self.tensor_tags["ack"])
                self.backward_receive_queues["ack"].append(
                    threadsafe_queue.Queue())
                self.num_ack_threads += 1

注意,每個張量有唯一一個tag,針對這個張量和這個唯一的tag,註冊 [tag, rank] 到 connection_list。

    def register_tensor(self, connected_rank, tag):
        """
        Builds connections list of tensors that are communicated GPU to GPU.

        For tensors that are sent GPU-to-GPU (intra-server for GLOO backend),
        make a list of destination/source ranks and the corresponding tag.
        This information is then used to crate process groups.
        """
        if not self.is_gpu_to_gpu_comm(connected_rank=connected_rank):
            return
        connection_info = [tag, connected_rank]
        self.connection_list.append(connection_info)

於是,此時邏輯如下,我們僅僅以部分 ranks,queue等舉例,forward_receive_queues 之中的這幾個queue 就是用來作為對應張量的buffer。

+------------------------+         'out8' = {list: 1} [2]
|                        |
|     receive_ranks +----------->  'out9' = {list: 1} [2]
|                        |
+------------------------+         'out10' = {list: 1} [2]



+--------------------------+
|                          |         'out8' : Queue
| forward_receive_queues+-------->
|                          |         'out9' : Queue
+--------------------------+
                                     'out10' : Queue




+--------------------------+       'out8' : rank 2
|                          |
|    connection_list  +--------->  'out9' : rank 2
|                          |
+--------------------------+       'out10' : rank 2

3.3 前向後向順序

接下來建立訊息傳遞的前後向順序,其目的是為了讓每個 worker 記錄如何處理由前向層/後向層傳來的rank。

3.3.1 建立順序

setup_messaging_schedule 方法就是用來建立:

  • 前向傳播時接受的順序。
  • 後向傳播時傳送的順序。

這裡的重點是:如果前一層數目比本層數目多,就把 i對應的前一層ranki + (本層rank數目) * n 對應的前一層rank 都加入到本層 i 的計劃(self.message_schedule)。n 等於 num_ranks_in_stage。

最終把順序放入 self.messaging_schedule 成員變數。假如本stage是擁有 3 個rank,則 self.messaging_schedule 就是這三個rank 分別的 message_schedule,每個 message_schedule 裡面都是對應的上一層 某些 ranks。

再細化一下:

  • self.messaging_schedule 是一個列表。
  • self.messaging_schedule 其中每一個item又是一個列表。self.messaging_schedule[ i ] 就表示比如 本層 第 i 個 rank 對應的 schedule(message_schedule)。
  • schedule(message_schedule)是上一層 或者 下一層 的某些ranks。
  • message_schedule包括的ranks是本stage所包括ranks的一個index。因為是內部使用,所以不需要是真正的 rank 數值,只要能和內部的queue等其他內部資料結構對映上即可。

程式碼如下:

    def setup_messaging_schedule(self):
        """ Order in which to receive forward and send backwards.

        Separate indexes of ranks in previous stage based on their
        corresponding offset in this stage. Then each worker will go
        in increasing order within a subset, and process subsets in
        a decreasing order.

        This is done so that messages are processed in the order
        that they are sent. Backwards send is done so that that it
        matches up with forward receive.
        """
        self.messaging_schedule = []
        for i in range(self.num_ranks_in_stage): # 本stage的並行數目
            idx = i
            message_schedule = []
            while idx < self.num_ranks_in_previous_stage: # 上一個stage的並行數目
                message_schedule.append(idx)
                # 如果前一層比本層多,就把 i, i + (本層rank) * n 對應的前一層rank都加入到本層 i 的計劃
                idx += self.num_ranks_in_stage
            if len(message_schedule) > 0:
                self.messaging_schedule.append(message_schedule)

        self.fwd_messaging_scheduling_row = self.rank_in_stage # 自己的rank index
        self.fwd_messaging_scheduling_col = 0 # receive forward
        self.bwd_messaging_scheduling_row = self.rank_in_stage # 自己的rank index
        self.bwd_messaging_scheduling_col = 0 # send backwards

        # For cases where previous stage has less workers than current stage.
        while self.fwd_messaging_scheduling_row >= \
            len(self.messaging_schedule):
            self.fwd_messaging_scheduling_row -= 1
            self.bwd_messaging_scheduling_row -= 1

具體邏輯如下:

+-------------------+                 +--------------------------------------------------+
| Stage 0           |                 | Stage 1                                          |
|                   |                 |                                                  |
|                   |                 |                                                  |
|                   |                 |     +----------------------------------------+   |
|                   |   send_ranks    |     | messaging_schedule                     |   |
|  ranks:           |                 |     |                                        |   |
|                   +---------------> |     |                                        |   |
|  [0,1,2,3,4,5,    |                 |     |   message_schedule +---> [0,1,2,9]     |   |
|  6,7,8,9,10,11,12]|                 |     |                                        |   |
|                   |                 |     |   message_schedule +---> [3,4,5,6,10]  |   |
|                   |                 |     |                                        |   |
|                   |                 |     |   message_schedule +---> [6,7,8,11]    |   |
|                   |                 |     |                                        |   |
|                   |                 |     +----------------------------------------+   |
|                   |                 |                                                  |
+-------------------+                 +--------------------------------------------------+

3.3.2 獲取訊息序列

get_messaging_index 方法是用來獲取本次傳遞的物件,就是應該和哪個rank進行互動。

    def get_messaging_index(self, sending):
        if sending:
            connection_rank = self.messaging_schedule[
                self.bwd_messaging_scheduling_row][
                    self.bwd_messaging_scheduling_col]
        else:
            connection_rank = self.messaging_schedule[
                self.fwd_messaging_scheduling_row][
                    self.fwd_messaging_scheduling_col]

        return connection_rank

哪裡用到了 get_messaging_index?原來是send, recv 函式,就是和前一層打交道時候會用到。

比如:

    def recv(self, tensor_name, forward_minibatch_id,
             backward_minibatch_id, backward=False):
        if backward:
            index = (backward_minibatch_id + self.rank_in_stage) % \
                len(self.backward_receive_queues[tensor_name])
            tensor = self.backward_receive_queues[tensor_name][
                index].remove()
            return tensor
        else:
            # 這裡會使用到,獲取與哪一個rank進行互動
            index = self.get_messaging_index(sending=False)
            # 然後得到使用哪個張量,從queue之中提取對應的最新張量
            tensor = self.forward_receive_queues[tensor_name][
                index].remove()
            if tensor.dtype == torch.float32:
                tensor = tensor.requires_grad_()
            return tensor

3.3.3 增加訊息序列

increment_messaging_index 方法用來增加訊息序列,就是得到下一次應該使用哪個訊息。

其中,兩個引數需要說明:

  • bwd_messaging_scheduling_col 表示上游具體哪一個 rank index。

  • bwd_messaging_scheduling_row 表示自己的 rank index。

方法如下:

    def increment_messaging_index(self, sending):
        if sending:
            self.bwd_messaging_scheduling_col += 1 # send backwards 對應的下一個 rank
            if self.bwd_messaging_scheduling_col == len(
                    self.messaging_schedule[
                        self.bwd_messaging_scheduling_row]):
                self.bwd_messaging_scheduling_col = 0
                self.bwd_messaging_scheduling_row -= 1 # 自己的rank index
                if self.bwd_messaging_scheduling_row == -1:
                    self.bwd_messaging_scheduling_row = \ # 重置回self.messaging_schedule,繼續新的一輪本地 rank通訊
                        len(self.messaging_schedule) - 1
        else:
            self.fwd_messaging_scheduling_col += 1 # receive forward 對應的下一個 rank
            if self.fwd_messaging_scheduling_col == len(
                    self.messaging_schedule[
                        self.fwd_messaging_scheduling_row]): 
                self.fwd_messaging_scheduling_col = 0
                self.fwd_messaging_scheduling_row -= 1 # 自己的rank index
                if self.fwd_messaging_scheduling_row == -1:
                    self.fwd_messaging_scheduling_row = \ # 重置回self.messaging_schedule,繼續新的一輪本地 rank通訊
                        len(self.messaging_schedule) - 1

哪裡會用到?在以下幾個函式中會用到:

    def receive_tensors_forward(self):
        if self.loader_iter is not None:
			# ......
        else:
            # Receive all required tensors from upstream machines.
			# ......
            # Used to track where to receive forward from.
            self.comm_handler.increment_messaging_index(
                sending=False)

    def send_tensors_backward(self):
        # Send all required gradients upstream.

        if self.num_ranks_in_previous_stage > 0:
            # Used to track where to send tensors in the
            # backward pass.
            self.comm_handler.increment_messaging_index(
                sending=True)    
            
    def run_ack(self):
        if self.stage > 0:
            self.comm_handler.send(
                "ack",
                torch.zeros(self.tensor_shapes["ack"],
                            dtype=torch.int64).cuda(),
                forward_minibatch_id=self.forward_minibatch_id,
                backward_minibatch_id=self.backward_minibatch_id,
                backward=True)

            # Used to track where to receive forward from.
            self.comm_handler.increment_messaging_index(sending=True)        

3.4 建立程式組

目的是:針對每個張量,設定兩個程式組,一個用於前向,一個用於後向。每一個張量有一個自己的tag。每一個tag都有自己的兩個process group,因為任何一個stage都有可能並行。

3.4.1 設計

首先,我們看看註釋,學習一下為何這麼設計。

create_process_groups 方法在所有rank之中以同樣順序建立程式組。為了以同樣順序建立程式組,每個worker都會收集其他所有workers的connection_list(GPU to GPU)。為了做到這一點,每個worker收集所有其他workers的連線列表connection_list(L)的最大大小。然後每個worker建立一個大小為Lx2的張量,其中每行表示一個連線,並根據“它本身連線列表大小”來填充此張量。擁有最大連線列表的worker將填充整個張量。

構建此列表後,將執行all_gather操作,然後每個worker都擁有一個相同的 NxLx2 輸出,其中N是worker 數量(world_size),輸出的每個index代表一個worker的連線列表。對於 i=self.rank,輸出將與本worker的本地連線列表相同。

每個worker以相同的順序在連線列表上進行迭代,檢查是否已建立每個連線(每個連線都將在輸出中出現兩次),如果連線不存在,則對於前向和後向都建立一個新的程式組。既然在程式組中rank永遠是一致的,所以小rank排在前面,大的rank排在後面。

3.4.2 程式碼

回到程式碼上,我們仔細分析下。

+--------------------------+       'out8' : rank 2
|                          |
|    connection_list  +--------->  'out9' : rank 2
|                          |
+--------------------------+       'out10' : rank 2

這裡就用到了 connection_list。具體邏輯是:

  • 找到 workers 之中最大的 connection_list
  • 獲取到 connection_list 的大小,即 connection_list_size
  • 用集合通訊來對 connection_list_size 進行聚合,最後得到的gathered_connection_list_sizes就是所有節點上的 connection_list_size 集合
  • 得到connection_list的最大數值
  • 利用最大數值來構建張量列表 connection_list_tensor
  • 把張量移動到GPU之上
  • 用集合通訊來對 connection_list_tensor進行聚合,得到aggregated_connection_list
  • 在每個worker之上,利用 dist.new_group 建立同樣的程式組
  • 遍歷aggregated_connection_list中的每一個connection
    • 得到張量對應的tag
    • 針對每個張量,設定兩個程式組,一個前向,一個後向

因此,目的就是在每個 worker 之中建立同樣的程式組,針對每個張量,設定兩個程式組,一個前向,一個後向。

具體程式碼如下:

    def create_process_groups(self):
        """ Create process groups in the same order across all ranks.

        To create process groups in the same order, each worker collects
        the connection_list of all other workers. To do this, every worker
        gathers the largest size of all other worker's connection_lists (L).
        Then every worker creates a tensor of size Lx2, where each row
        represents a connection, and fills up this tensor depending on how
        large its own connection list is. The worker(s) w/ the largest
        connection list will fill up the entire tensor.

        After constructing this list, an all_gather is performed, after which
        each worker has an identical NxLx2 output, where N is the number of
        workers (world_size), and each index of output represents a worker's
        connection list. For i=self.rank, the output will be identical to the
        workers local connection list.

        Each worker then iterates in the same order over the connections list,
        checking if each connection has been created yet (every connection will
        appear twice in the output), and creating a new process group if one
        doesn't exist for that connection, for both the forward and backward
        direction. Since ranks within process groups must always be identical,
        the smaller rank always goes first, followed by the larger rank.
        """
        if self.num_ranks_in_server == 1:
            return

        print("Setting up process groups for broadcasts...")

        # Figure out the size of the largest connection list that any worker
        # has (L).
        # 找到最大的 connection_list
        # 獲取到 connection_list 的大小,即 connection_list_size
        connection_list_size = torch.tensor(
            len(self.connection_list), dtype=torch.int)
        if self.backend == NCCL:
            connection_list_size = connection_list_size.cuda()
        gathered_connection_list_sizes = [
            torch.ones_like(connection_list_size)
            for _ in range(self.world_size)]
        
        # 用集合通訊來對 connection_list_size 進行聚合,最後得到的gathered_connection_list_sizes就是所有節點上的 connection_list_size 集合
        dist.all_gather(gathered_connection_list_sizes,
                        connection_list_size)
        # 得到最大數值
        max_connection_list_size = max(
            gathered_connection_list_sizes)

        if max_connection_list_size == 0:
            return 

        # 利用最大數值來構建張量列表 connection_list_tensor
        # Build tensor to send local connection list to all other workers.
        connection_list_tensor = torch.ones([max_connection_list_size, 2],
                                            dtype=torch.int) * -1
        # 把張量移動到GPU之上
        if self.backend == NCCL:
            connection_list_tensor = connection_list_tensor.cuda()
        if len(self.connection_list) > 0:
            connection_list_tensor[0:len(self.connection_list)] = \
                torch.IntTensor(self.connection_list)

        # 用集合通訊來對 connection_list_tensor進行聚合       
        # Gather connection lists of all workers.
        aggregated_connection_list = [
            torch.ones_like(connection_list_tensor)
            for _ in range(self.world_size)]
        dist.all_gather(aggregated_connection_list,
                        connection_list_tensor)

        # 在每個worker之上,利用 dist.new_group 建立同樣的程式組
        # Construct identical process groups on each worker.
        local_rank_connections = 0
        for src_rank in range(len(aggregated_connection_list)):
            for connection in aggregated_connection_list[src_rank]:
                # 得到張量對應的tag
                tag = int(connection[0])
                dst_rank = int(connection[1])

                if tag == -1:
                    assert dst_rank == -1
                    continue

                min_rank = min(src_rank, dst_rank)
                max_rank = max(src_rank, dst_rank)
                assert min_rank != max_rank

                if min_rank not in self.process_groups:
                    self.process_groups[min_rank] = {}

                if max_rank not in self.process_groups[min_rank]:
                    self.process_groups[min_rank][max_rank] = {}

                if tag not in self.process_groups[min_rank][max_rank]:
                    # 用到了pytorch p2p 的api
                    sub_process_group_fwd = dist.new_group(
                        ranks=[min_rank, max_rank])
                    sub_process_group_bwd = dist.new_group(
                        ranks=[min_rank, max_rank])

                    # 針對每個張量,設定程式組
                    self.process_groups[min_rank][max_rank][tag] = {
                        'forward': sub_process_group_fwd,
                        'backward': sub_process_group_bwd
                    }

                    if min_rank == self.rank or max_rank == self.rank:
                        local_rank_connections += 1
        assert local_rank_connections == len(self.connection_list)

具體 如何使用程式組?在 recv_helper_thread_args 等函式會使用,比如:

    def recv_helper_thread_args(self, tensor_name, index, dtype,
                                backward, num_iterations):
        if backward:
            src_rank = self.send_ranks[tensor_name][index]
        else:
            src_rank = self.receive_ranks[tensor_name][index]

        sub_process_group = None
        # 獲取張量 tensor_name 對應的 tag
        tag = self.tensor_tags[tensor_name]
        if self.is_gpu_to_gpu_comm(connected_rank=src_rank) and tensor_name != "ack":
            min_rank = min(self.rank, src_rank)
            max_rank = max(self.rank, src_rank)
            
            if src_rank > self.rank:
                # 獲取 tag 對應的程式組,呼叫者後續會使用
                sub_process_group = \
                    self.process_groups[min_rank][max_rank][tag]['backward']
            else:
                # 獲取 tag 對應的程式組,呼叫者後續會使用
                sub_process_group = \
                    self.process_groups[min_rank][max_rank][tag]['forward']
            assert sub_process_group

        if backward:
            queue = self.backward_receive_queues[tensor_name][index]
        else:
            queue = self.forward_receive_queues[tensor_name][index]
        tensor_shape = self.tensor_shapes[tensor_name]

        return (queue, self.counter, self.local_rank, tensor_name,
                src_rank, tag, tensor_shape, dtype, sub_process_group,
                num_iterations)

3.5 啟動助手執行緒

使用 start_helper_threads 來進行啟動助手執行緒。這些助手執行緒是為了 P2P 使用。

首先,ranks舉例,可以看出來,key 是張量名字,value 是ranks列表。

receive_ranks = {dict: 3}  # 這裡就是每個tensor對應的接收目標rank
 'out8' = {list: 1} [2]
 'out9' = {list: 1} [2]
 'out10' = {list: 1} [2]
 __len__ = {int} 3

3.5.1 建立執行緒

回憶一下之前建立的 4 個queues:

  • forward_receive_queues :前向傳播過程中,接受張量的queue。對應了 receive_ranks
  • backward_send_queues : 後向傳播過程中,傳送張量的queue。對應了 receive_ranks。因為前向傳播中接受的物件,就是後向傳播中傳送的目標。
  • forward_send_queues : 前向傳播過程中,傳送張量的queue。對應了 send_ranks
  • backward_receive_queues :後向傳播過程中,接受張量的queue。對應了 send_ranks。因為前向傳播中傳送的目標就是後向傳播中接受的物件。

這 4 個queue 其實就對應了 4 個不同的助手執行緒。

思路是:

  • 針對接受ranks進行處理,即遍歷 receive_ranks 中的張量
    • 遍歷張量對應的ranks,對於每一個rank
      • 需要後向處理,所以建立後向傳送執行緒
      • 建立接受助手執行緒
  • 針對傳送ranks進行處理,即遍歷 send_ranks 中的張量
    • 遍歷張量對應的ranks,對於每一個rank
      • 需要後向處理,所以建立後向接受執行緒
      • 建立傳送助手執行緒
  • 針對target進行處理
  • 如果只有前向,則需要補齊ack

具體程式碼是:

    def start_helper_threads(self, num_iterations, forward_only):
        """
        Start helper communication threads, one for each queue.
        """
        if forward_only:
            self.set_counter(self.num_forward_threads +
                             self.num_ack_threads)
            # For validation, receive acks in backward pass from next stage, send
            # acks in backward pass to next stage.
            self.receive_ranks["ack"] = self.ranks_in_previous_stage
            self.send_ranks["ack"] = self.ranks_in_next_stage
        else:
            self.set_counter(self.num_forward_threads +
                             self.num_backward_threads)
            if "ack" in self.receive_ranks:
                del self.receive_ranks["ack"]
            if "ack" in self.send_ranks:
                del self.send_ranks["ack"]

        (num_iterations_for_forward_threads,
         num_iterations_for_backward_threads) = \
            self.num_iterations_for_helper_threads(
                num_iterations=num_iterations)
        dtype = torch.float16 if self.fp16 else torch.float32

        # Setup queues for each tensor to be received and sent.
        # 針對接受rank進行處理
        for input_name in self.receive_ranks:
            if input_name in self.target_tensor_names or input_name == "ack":
                continue

            # 遍歷張量對應的ranks
            for i in range(len(self.receive_ranks[input_name])):
                if not forward_only:
                    # 需要後向處理,所以建立後向傳送執行緒
                    self.start_helper_thread(
                        self.send_helper_thread_args,
                        send_helper_thread,
                        [input_name, i, True],
                        num_iterations_for_backward_threads)
                # 建立接受助手執行緒    
                self.start_helper_thread(
                    self.recv_helper_thread_args,
                    recv_helper_thread,
                    [input_name,
                     i,
                     self.training_tensor_dtypes[input_name],
                     False],
                    num_iterations_for_backward_threads)
             
        # 針對傳送ranks進行處理
        for output_name in self.send_ranks:
            if output_name in self.target_tensor_names or output_name == "ack":
                continue

            # 遍歷張量對應的ranks
            for i in range(len(self.send_ranks[output_name])):
                if not forward_only:
                    # 需要後向處理,所以建立後向接受執行緒
                    self.start_helper_thread(
                        self.recv_helper_thread_args,
                        recv_helper_thread,
                        [output_name, i,
                         self.training_tensor_dtypes[output_name],
                         True],
                        num_iterations_for_forward_threads)
                # 傳送助手執行緒
                self.start_helper_thread(
                    self.send_helper_thread_args,
                    send_helper_thread,
                    [output_name, i, False],
                    num_iterations_for_forward_threads)

        # 針對target進行處理
        for target_tensor_name in self.target_tensor_names:
            if self.num_ranks_in_previous_stage > 0:
                for i in range(len(self.receive_ranks[target_tensor_name])):
                    self.start_helper_thread(
                        self.recv_helper_thread_args,
                        recv_helper_thread,
                        [target_tensor_name, i, torch.int64,
                         False],
                        num_iterations_for_backward_threads)

            if self.num_ranks_in_next_stage > 0:
                for i in range(len(self.send_ranks[target_tensor_name])):
                    self.start_helper_thread(
                        self.send_helper_thread_args,
                        send_helper_thread,
                        [target_tensor_name, i, False],
                        num_iterations_for_forward_threads)

        # Start helper threads for ack for forward pass-only run as a clocking
        # mechanism.
        # 如果只有前向,則需要補齊ack
        if forward_only:
            # 有前向就補齊 ack
            if "ack" in self.receive_ranks:
                for i in range(len(self.receive_ranks["ack"])):
                    self.start_helper_thread(self.send_helper_thread_args,
                                             send_helper_thread,
                                             ["ack", i, True],
                                             num_iterations_for_backward_threads)
            if "ack" in self.send_ranks:
                for i in range(len(self.send_ranks["ack"])):
                    self.start_helper_thread(self.recv_helper_thread_args,
                                             recv_helper_thread,
                                             ["ack", i, torch.int64, True],
                                             num_iterations_for_forward_threads)


具體執行緒建立函式為:

    def start_helper_thread(self, args_func, func, args_func_args, num_iterations):
        """
        Start passed-in func on a helper thread.
        """
        args_func_args += [num_iterations]
        args = args_func(*args_func_args) # 需要注意的是使用函式來獲取對應的引數
        helper_thread = threading.Thread(target=func, # 用執行緒主函式來執行執行緒
                                         args=args)
        helper_thread.start()

3.5.2 執行緒主函式

recv_helper_thread 和 send_helper_thread 分別是 接受助手執行緒 和 傳送助手執行緒。分別呼叫 _recv 和 _send 來完成具體業務工作。

需要注意的是使用函式來獲取對應的引數。就是使用 recv_helper_thread_args 和 send_helper_thread_args 來獲取引數。

def recv_helper_thread(queue, counter, local_rank, tensor_name,
                       src_rank, tag, tensor_shape, dtype,
                       sub_process_group, num_iterations):
    torch.cuda.set_device(local_rank)
    # This method is to be executed from a helper daemon thread.
    for i in range(num_iterations):
        tensor = _recv(
            tensor_name, src_rank, tensor_shape=tensor_shape,
            dtype=dtype, tag=tag,
            sub_process_group=sub_process_group)
        queue.add(tensor)
    counter.decrement()

def send_helper_thread(queue, counter, local_rank, tensor_name,
                       src_rank, dst_rank, tag,
                       sub_process_group, num_iterations):
    torch.cuda.set_device(local_rank)
    # This method is to be executed from a helper daemon thread.
    for i in range(num_iterations):
        tensor = queue.remove()
        _send(tensor, tensor_name, src_rank, dst_rank,
              tag=tag,
              sub_process_group=sub_process_group)
    counter.decrement()

3.5.3 構建引數

回憶一下,在 create_process_groups 方法中,有如下程式碼,這裡就給每一個 tag 設定了 程式組,在助手執行緒之中,就要利用這些程式組來完成邏輯:

if tag not in self.process_groups[min_rank][max_rank]:
	sub_process_group_fwd = dist.new_group(ranks=[min_rank, max_rank])
    sub_process_group_bwd = dist.new_group(ranks=[min_rank, max_rank])

	self.process_groups[min_rank][max_rank][tag] = {
    	'forward': sub_process_group_fwd,
        'backward': sub_process_group_bwd
	}

使用如下函式來完成對執行緒主函式引數的獲取。基本邏輯就是:

  • 利用張量名字,獲取到對應的rank
  • 利用張量名字,獲取到對應的tag
  • 使用tag來獲取到對應的程式組
  • 利用張量名字和index得到對應的queue
  • 返回引數
    def recv_helper_thread_args(self, tensor_name, index, dtype,
                                backward, num_iterations):
        # 利用張量名字,獲取到對應的rank
        if backward:
            src_rank = self.send_ranks[tensor_name][index]
        else:
            src_rank = self.receive_ranks[tensor_name][index]

        # 利用張量名字,獲取到對應的tag
        sub_process_group = None
        tag = self.tensor_tags[tensor_name]
        
        # 使用tag來獲取到對應的程式組
        if self.is_gpu_to_gpu_comm(connected_rank=src_rank) and tensor_name != "ack":
            min_rank = min(self.rank, src_rank)
            max_rank = max(self.rank, src_rank)
            if src_rank > self.rank:
                sub_process_group = \
                    self.process_groups[min_rank][max_rank][tag]['backward']
            else:
                sub_process_group = \
                    self.process_groups[min_rank][max_rank][tag]['forward']
            assert sub_process_group

        # 得到對應的queue
        if backward:
            queue = self.backward_receive_queues[tensor_name][index]
        else:
            queue = self.forward_receive_queues[tensor_name][index]
        tensor_shape = self.tensor_shapes[tensor_name]

        # 返回引數
        return (queue, self.counter, self.local_rank, tensor_name,
                src_rank, tag, tensor_shape, dtype, sub_process_group,
                num_iterations)

    def send_helper_thread_args(self, tensor_name, index,
                                backward, num_iterations):
        # 利用張量名字得到對應的rank
        if backward:
            dst_rank = self.receive_ranks[tensor_name][index]
            num_ranks_in_connected_stage = self.num_ranks_in_previous_stage
        else:
            dst_rank = self.send_ranks[tensor_name][index]
            num_ranks_in_connected_stage = self.num_ranks_in_next_stage

        # 使用tag來獲取到對應的程式組
        sub_process_group = None
        tag = self.tensor_tags[tensor_name]
        if self.is_gpu_to_gpu_comm(connected_rank=dst_rank) and tensor_name != "ack":
            min_rank = min(self.rank, dst_rank)
            max_rank = max(self.rank, dst_rank)
            if dst_rank > self.rank:
                sub_process_group = \
                     self.process_groups[min_rank][max_rank][tag]['forward']
            else:
                sub_process_group = \
                    self.process_groups[min_rank][max_rank][tag]['backward']
            assert sub_process_group

        # 得到對應的queue
        if backward:
            queue = self.backward_send_queues[tensor_name][index]
        else:
            queue = self.forward_send_queues[tensor_name][index]

        # 返回引數
        return (queue, self.counter, self.local_rank, tensor_name, self.rank,
                dst_rank, tag, sub_process_group, num_iterations)

0x04 功能函式

以下功能函式就是最終被使用完成 流水線 RPC 邏輯的函式。

這裡有一個通過queue完成的解耦合:

  • recv 和 send 就會對於 queue 進行操作,往queue裡面新增或者提取張量。
  • 助手執行緒會呼叫 _recv 和 _send 對 queue 進行操作。

所以我們要先看看這個Queue的實現,可以看到,無論是 add 還是 remove,都使用了 threading.Condition,就說明幾個執行緒可以在 Queue 上通過 add / remove 實現等待,阻塞,即生產者和消費者。

class Queue:
    def __init__(self):
        self.queue = []
        self.cv = threading.Condition()

    def add(self, tensor):
        self.cv.acquire()
        self.queue.append(tensor)
        self.cv.notify()
        self.cv.release()

    def remove(self):
        self.cv.acquire()
        while len(self.queue) == 0:
            self.cv.wait()
        tensor = self.queue.pop(0)
        self.cv.release()
        return tensor

4.1 傳送邏輯

傳送的邏輯如下:

  1. 訓練程式碼會呼叫StageRuntime.run_backward。
  2. StageRuntime.run_backward 方法會呼叫 StageRuntime.send_tensors_backward 來傳送張量 tensor_name。
  3. send_tensors_backward 會呼叫 CommunicationHandler.send 來向 CommunicationHandler 的成員變數backward_send_queues[tensor_name] [index] 新增這個張量。每個張量對應了若干個queue。這裡就是個解耦合。
  4. send 函式 會呼叫 backward_send_queues.add,這裡會通知阻塞在queue上的 send_helper_thread 進行工作。
  5. 在 CommunicationHandler 的執行緒 send_helper_thread 中,之前就阻塞在queue這裡,此時會從 backward_send_queues[tensor_name] [index] 之中提取張量。
  6. send_helper_thread 會呼叫 _send 來傳送張量。
  7. 而最終呼叫的是 dist.send,就是PyTorch P2P。

具體如下圖:

 StageRuntime            CommunicationHandler              send_helper_thread

      +                           +                                 +
      |                           |                                 |
      | 1                         |                                 |
      v                           |                                 |
 run_backward                     |                                 |
      |                           |                                 |
      | 2                         |                                 |
      |                           |                    wait on backward_send_queues
      v                  3        v                                 |
send_tensors_backward +--------> send                               |
                                  |                                 |
                                  |                                 |
                                  |  4                              |
                                  v               5                 v
               backward_send_queues.add(tensor) +----> tensor = queue.remove()
                                                notify              |
                                                                    |
                                                                    | 6
                                                                    v
                                                                  _send
                                                                    |
                                                                    | 7
                                                                    |
                                                                    v
                                                                 dist.send

4.2 接受邏輯

接受邏輯如下:

  1. StageRuntime 訓練程式碼中呼叫 run_backward。
  2. run_backward 呼叫 receive_tensors_backward。
  3. receive_tensors_backward 呼叫 self.gradients[output_name] = self.comm_handler.recv 獲取梯度。CommunicationHandler 的 recv 函式會阻塞在 backward_receive_queues[tensor_name] [index] 之上
  4. 同時,CommunicationHandler 的 recv_helper_thread 執行緒呼叫 _recv 接受其他stage點傳來的張量。
  5. _recv呼叫 dist.recv 或者 dist.broadcast 接受張量。
  6. _recv 向 backward_receive_queues[tensor_name] [index] 新增張量。這樣就通知阻塞的 CommunicationHandler 的 recv 函式進行工作
  7. CommunicationHandler 的 recv 函式會從backward_receive_queues[tensor_name] [index] 提取梯度,然後返回給 StageRuntime。就是 3 的返回。

具體如下圖:

    StageRuntime             CommunicationHandler           recv_helper_thread
          +                            +                            +
          |                            |                            |
          | 1                          |                            |
          |                            |                            | 4
          v                            |                            v
    run_backward                       |                         _recv
          |                            |                            |
          |                            |                            |
          |                            |                            | 5
          |                            |                            |
          | 2                          |                            v
          |                            |                  dist.recv / dist.broadcast
          |                            |                            |
          v                  3         v                            |
receive_tensors_backward +--------->  recv                          |
          +                            |                            |
          |                            |                            |
          |                            |                            |
          |                            |                            |
          |                            v                            |
          |                 backward_receive_queues.remove()        |
          |                            |                            |
          |                            |                            |
          |                            |                            |
          |                            |                            |
          |               wait on backward_receive_queues           |
          |                            |                            |
          |                            |                            |
          |                            |                            |
          |                            |                 6          v
          |                  backward_receive_queues <-------+ queue.add(tensor)
          |                            |               notify
          |                            |  7
          v                  3 return  |
gradients[output_name] <---------------+

4.3 recv

這裡其實就是從對應的queue之中,依據張量名字來獲取對應的張量。

    def recv(self, tensor_name, forward_minibatch_id,
             backward_minibatch_id, backward=False):
        if backward:
            index = (backward_minibatch_id + self.rank_in_stage) % \
                len(self.backward_receive_queues[tensor_name])
            tensor = self.backward_receive_queues[tensor_name][
                index].remove()
            return tensor
        else:
            # 前向時候,需要知道從前一層的哪一個index獲取
            index = self.get_messaging_index(sending=False)
            tensor = self.forward_receive_queues[tensor_name][
                index].remove()
            if tensor.dtype == torch.float32:
                tensor = tensor.requires_grad_()
            return tensor

在執行時 receive_tensors_forward,receive_tensors_backward 函式中,會呼叫到 recv 函式,從對應的queue 拿到已經存的張量。比如:

    def receive_tensors_backward(self):
        # Receive all required gradients from downstream
        # machines.
        for output_name in self.send_ranks:
             if output_name in self.target_tensor_names:
                continue

             self.gradients[output_name] = \
                self.comm_handler.recv( # 這裡使用了
                    output_name,
                    forward_minibatch_id=self.forward_minibatch_id,
                    backward_minibatch_id=self.backward_minibatch_id,
                    backward=True)

             self.backward_stats.stats['receive_tensors_size'] += \
                 (self.gradients[output_name].element_size() *
                  self.gradients[output_name].nelement())

4.4 send

這裡是把張量放置在對應的queue之中。

    def send(self, tensor_name, tensor, forward_minibatch_id,
             backward_minibatch_id, backward=False):
        if backward:
            # 後向時候,需要知道傳送給前一層的哪一個index
            index = self.get_messaging_index(sending=True)
            dst_rank = self.receive_ranks[tensor_name][index]
            self.backward_send_queues[tensor_name][index].add(tensor)
        else:
            index = (forward_minibatch_id + self.rank_in_stage) % \
                len(self.send_ranks[tensor_name])
            self.forward_send_queues[tensor_name][index].add(tensor)


send_tensors_backward,send_tensors_forward 之中會使用,比如:

    def send_tensors_backward(self):
        # Send all required gradients upstream.
        for input_name in self.receive_ranks:
            if input_name in self.target_tensor_names:
                continue

            self.comm_handler.send(
                input_name,
                self.gradients[input_name],
                forward_minibatch_id=self.forward_minibatch_id,
                backward_minibatch_id=self.backward_minibatch_id,
                backward=True)

            self.backward_stats.stats['send_tensors_size'] += \
                (self.gradients[input_name].element_size() *
                 self.gradients[input_name].nelement())

        if self.num_ranks_in_previous_stage > 0:
            # Used to track where to send tensors in the
            # backward pass.
            self.comm_handler.increment_messaging_index(
                sending=True)

4.5 _recv

_recv 引數中,sub_process_group 就是上面程式碼中構建的。

如果在同一個節點上,就使用dist.broadcast,否則使用dist.recv。

def _recv(tensor_name, src_rank, tensor_shape=None, dtype=torch.float32,
          tensor=None, tag=None, sub_process_group=None):
    """
    Receives tensor by calling PyTorch's recv() call.

    Tensor will be copied to GPU prior to return.
    """
    assert tag is not None
    if tensor is None:
        assert tensor_shape is not None
        assert dtype is not None
        assert dtype != torch.float16

    if sub_process_group is not None:
        # Receive tensor shape.
        received_tensor_shape = torch.zeros(len(tensor_shape),
                                            dtype=torch.int)
        dist.broadcast(tensor=received_tensor_shape,
                       src=src_rank,
                       group=sub_process_group)
        received_tensor_shape = list(map(lambda x: int(x),
                                         received_tensor_shape))

        # Receive tensor.
        tensor = torch.zeros(received_tensor_shape, dtype=dtype).cuda()
        dist.broadcast(tensor=tensor,
                       src=src_rank,
                       group=sub_process_group)
    else:
        # Receive tensor shape.
        received_tensor_shape = torch.zeros(len(tensor_shape),
                                            dtype=torch.int)
        dist.recv(tensor=received_tensor_shape,
                  src=src_rank,
                  tag=tag)
        received_tensor_shape = list(map(lambda x: int(x),
                                         received_tensor_shape))

        # Receive tensor.
        tensor = torch.zeros(received_tensor_shape, dtype=dtype)
        dist.recv(tensor=tensor,
                  src=src_rank,
                  tag=tag)
        tensor = tensor.cuda()

    assert tensor.is_cuda
    return tensor

在 recv_helper_thread 之中會呼叫 _recv。

def recv_helper_thread(queue, counter, local_rank, tensor_name,
                       src_rank, tag, tensor_shape, dtype,
                       sub_process_group, num_iterations):
    torch.cuda.set_device(local_rank)
    # This method is to be executed from a helper daemon thread.
    for i in range(num_iterations):
        tensor = _recv(
            tensor_name, src_rank, tensor_shape=tensor_shape,
            dtype=dtype, tag=tag,
            sub_process_group=sub_process_group)
        queue.add(tensor) # 獲取到張量之後,放入queue
    counter.decrement()

4.6 _send

如果在同一個節點上,就使用dist.broadcast,否則使用dist.send。

def _send(tensor, tensor_name, src_rank, dst_rank, tag, sub_process_group=None):
    """
    Sends tensor by calling PyTorch's send() call.

    If tensor is being sent not via broadcast(), it will
    be first copied to the CPU.
    """
    if sub_process_group is not None:
        assert tensor.is_cuda

        # Send tensor shape.
        tensor_shape = torch.tensor(tensor.shape, dtype=torch.int)
        dist.broadcast(tensor=tensor_shape, src=src_rank,
                      group=sub_process_group)

        # Send tensor.
        contiguous_tensor = tensor.detach().clone()
        dist.broadcast(tensor=contiguous_tensor.contiguous(),
                       src=src_rank,
                       group=sub_process_group)
    else:
        assert tensor.is_cuda
        tensor = tensor.cpu()

        # Send tensor shape.
        tensor_shape = torch.tensor(tensor.shape, dtype=torch.int)
        dist.send(tensor=tensor_shape, dst=dst_rank, tag=tag)

        # Send tensor.
        dist.send(tensor=tensor, dst=dst_rank, tag=tag)

recv_helper_thread 使用 _send獲取張量。

def send_helper_thread(queue, counter, local_rank, tensor_name,
                       src_rank, dst_rank, tag,
                       sub_process_group, num_iterations):
    torch.cuda.set_device(local_rank)
    # This method is to be executed from a helper daemon thread.
    for i in range(num_iterations):
        tensor = queue.remove()
        # 從queue提取張量,傳送出去。
        _send(tensor, tensor_name, src_rank, dst_rank,
              tag=tag,
              sub_process_group=sub_process_group)
    counter.decrement()

至此,通訊模組已經分析完畢,下一篇終於要介紹 1F1B 了。

0xFF 參考

相關文章