[原始碼解析] 深度學習流水線並行 PipeDream(2)--- 計算分割槽

羅西的思考發表於2021-09-03

[原始碼解析] 深度學習流水線並行 PipeDream(2)--- 計算分割槽

0x00 摘要

在前文中,我們介紹了PipeDream的總體架構和Profile階段,本文我們繼續介紹計算分割槽階段。其功能是:依據profile結果確定所有層的執行時間,然後使用動態規劃對模型進行劃分,將模型劃分為不同的stage,以及得到每個stage的replication數。計算結果具體如下圖所示:

流水線並行其他文章連結如下:

[原始碼解析] 深度學習流水線並行Gpipe(1)---流水線基本實現

[原始碼解析] 深度學習流水線並行GPipe (2) ----- 梯度累積

[原始碼解析] 深度學習流水線並行之PipeDream(1)--- Profile階段

0x01 前言

1.1 Profile檔案

我們首先看看profile檔案 profiler/translation/profiles/gnmt/graph.txt 內容,這裡只是做摘錄。

node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
node2 -- Input1 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
node6 -- Dropout(p=0.2) -- forward_compute_time=0.077, backward_compute_time=0.196, activation_size=12582912.0, parameter_size=0.000
node7 -- LSTM(2048, 1024) -- forward_compute_time=3.190, backward_compute_time=5.348, activation_size=[6291456.0; 131072.0; 131072.0], parameter_size=50364416.000
node8 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
node9 -- __getitem__(1) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=131072.0, parameter_size=0.000
node10 -- Dropout(p=0.2) -- forward_compute_time=0.064, backward_compute_time=0.128, activation_size=6291456.0, parameter_size=0.000
node11 -- LSTM(1024, 1024) -- forward_compute_time=2.491, backward_compute_time=4.203, activation_size=[6291456.0; 131072.0; 131072.0], parameter_size=33587200.000
node12 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
node13 -- __getitem__(1) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=131072.0, parameter_size=0.000
node14 -- Add -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
node15 -- Dropout(p=0.2) -- forward_compute_time=0.059, backward_compute_time=0.121, activation_size=6291456.0, parameter_size=0.000
node16 -- LSTM(1024, 1024) -- forward_compute_time=2.492, backward_compute_time=4.201, activation_size=[6291456.0; 131072.0; 131072.0], parameter_size=33587200.000
node17 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
......
	node1 -- node4
	node4 -- node5
	node2 -- node5
	node5 -- node6
	node6 -- node7
	node7 -- node8
	node7 -- node9
	node8 -- node10
	node10 -- node11
	node11 -- node12
	node11 -- node13
	node12 -- node14
	node8 -- node14
	node14 -- node15
	node15 -- node16
	node16 -- node17
	node16 -- node18
	node17 -- node19
	node14 -- node19
......

1.2 總體思路

在前文我們也提到了幾個挑戰,其中有:

  • 如何高效劃分流水線
    • 模型特質和硬體拓撲會降低效率。分配演算法也必須考慮模型特質和硬體拓撲
    • 機器間的過度通訊會降低硬體效率。
  • 如何防止流水線瓶頸
    • 由木桶原理我們可以知道,一個流水線管道的吞吐量由這個流水線上最慢環節的吞吐量決定。所以需要確保流水線中所有階段都大致花費相同的計算時間,否則最慢的階段將會成為整個流水線的瓶頸

因此當跨機器將層劃分為不同的階段時,PipeDream的自動劃分演算法必須確保每個階段大致執行相同的總工作量。同時還必須確保各階段之間通訊的資料量儘可能小,以避免通訊中斷。

PipeDream的自動劃分演算法總體目標是輸出一個平衡的管道,演算法如下:

  • 將DNN層劃分為多個階段,以便每個階段以大致相同的速率完成,即花費大致相同的計算時間。
  • 嘗試以拓撲感知的方式儘量減少worker之間的通訊(例如,如果可能,向更高頻寬的鏈路傳送較大的輸出)。
  • 因為DNN並不總可以在可用的workers做平均分配,為了進一步改進負載平衡,PipeDream允許複製一個stage,即在這個stage上使用多個worker進行資料並行。

這個劃分問題等價於最小化流水線的最慢階段所花費的時間,並且具有最優子問題屬性:在給定worker工作量前提下,吞吐量最大化的流水線由一系列子流水線構成,其中每一個子流水線針對較小worker工作量來最大化自己的輸出。因此PipeDream使用動態規劃來尋找最優解。

這裡給出對應的架構圖如下:

我們下面先看看計算分割槽之前的準備工作:圖相關工作和構建反鏈。

0x02 圖相關

圖的定義位於 graph/graph.py 檔案之中,主要資料結構有兩個:Graph 和 Node。

2.1 Graph

Graph就是圖的資料結構,其主要成員包括:

  • nodes :圖內節點;
  • edges :圖內每個節點的輸出邊;
  • in_edges :圖的每個節點的輸入邊;
  • _predecessors :每個節點的前序節點;
  • _successors :每個節點的後序節點;
  • _antichain_dag :反鏈DAG;
class Graph(object):
    def __init__(self, node=None):
        self.nodes = {} # 節點
        if node is not None:
            self.nodes[node.node_id] = node
        self.edges = {} # 出邊
        self.in_edges = {} # 入邊

        self._predecessors = {} #每個節點的前序節點 
        self._successors = {} # 每個節點的後序節點
        self._augmented_antichains = {}
        self._deaugmented_augmented_antichains = {}
        self._next_antichains = {}
        self._antichain_dag = None # 反鏈DAG

        if node is not None:
            self.in_edges[node.node_id] = list()

節點定義如下,裡面就是從profile獲取到的結構,比如:

  • forward_compute_time : 前向傳播時間;
  • backward_compute_time :反向傳播時間;
  • activation_size : 啟用值大小;
  • parameter_size : 引數大小;
class Node(object):
    def __init__(self, node_id, node_desc="", forward_compute_time=0.0,
                 backward_compute_time=0.0, activation_size=0.0, parameter_size=0.0,
                 stage_id=None):
        self.node_id = node_id
        self.node_desc = node_desc
        self.forward_compute_time = forward_compute_time
        self.backward_compute_time = backward_compute_time
        self.activation_size = activation_size
        self.parameter_size = parameter_size
        self.stage_id = stage_id
        self.depth = None
        self.height = None

我們列印出執行時看看,可以發現 Graph 的具體情況。

gr = {Graph} 
 # 邊
 edges = {dict: 39}
  'node1' = {list: 1} 
   0 = {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  'node4' = {list: 1} 
   0 = {Node} node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
   ......

 # 輸入邊 
 in_edges = {dict: 44} 
  'node4' = {list: 1} 
   0 = {Node} node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
  'node5' = {list: 2} 
   0 = {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
   1 = {Node} node2 -- Input1 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
   ......
  
 # 節點 
 nodes = {dict: 48}
  'node1' = {Node} node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
  'node4' = {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  'node5' = {Node} node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
 ......

# 前置節點
_predecessors = {dict: 36} 
 'node4' = {set: 0} set()
  __len__ = {int} 0
 'node5' = {set: 1} {<graph.graph.Node object at 0x7fb055e4bf28>}
  {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  __len__ = {int} 1
 'node6' = {set: 2} {<graph.graph.Node object at 0x7fb055e4bf98>, <graph.graph.Node object at 0x7fb055e4bf28>}
  {Node} node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
  {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  __len__ = {int} 2
 'node7' = {set: 3} {<graph.graph.Node object at 0x7fb055e4bf98>, <graph.graph.Node object at 0x7fb055e4bf28>, <graph.graph.Node object at 0x7fb055e670f0>}
  {Node} node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
  {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  {Node} node6 -- Dropout(p=0.2) -- forward_compute_time=0.077, backward_compute_time=0.196, activation_size=12582912.0, parameter_size=0.000
  __len__ = {int} 3

 # 其他變數
  _antichain_dag = {NoneType} None
  _augmented_antichains = {dict: 0} {}
  _deaugmented_augmented_antichains = {dict: 0} {}
  _next_antichains = {dict: 0} {}
  _successors = {dict: 0} {}

2.2 構建圖

圖是由profile檔案的字串構建出來。找出來profile檔案內容我們就可以知道,具體是針對每行進行不同處理。

node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
	node1 -- node4
	node4 -- node5
	node2 -- node5

構建圖具體程式碼如下:

@staticmethod
def from_str(graph_str):
    gr = Graph()
    graph_str_lines = graph_str.strip().split('\n') 
    for graph_str_line in graph_str_lines: # 逐行處理
        if not graph_str_line.startswith('\t'):
            node = Node.from_str(graph_str_line.strip()) # 構建節點
            gr.nodes[node.node_id] = node
        else:
            # 構建邊
            [in_node_id, node_id] = graph_str_line.strip().split(" -- ")
            if node_id not in gr.in_edges: # 每個節點的輸入邊
                gr.in_edges[node_id] = [gr.nodes[in_node_id]]
            else:
                gr.in_edges[node_id].append(gr.nodes[in_node_id])
            if in_node_id not in gr.edges: # 每個節點的輸出邊
                gr.edges[in_node_id] = [gr.nodes[node_id]]
            else:
                gr.edges[in_node_id].append(gr.nodes[node_id])
    return gr

構建節點具體程式碼如下:

    @staticmethod
    def from_str(node_str):
        node_str_tokens = node_str.strip().split(" -- ")
        node_id = node_str_tokens[0] # 節點名字
        node_desc = node_str_tokens[1] # 節點描述
        node_metadata = node_str_tokens[2] # 後設資料
        stage_id = None
        if len(node_str_tokens) > 3:
            stage_id = int(node_str_tokens[3].split("=")[1]) # 階段資訊
        [forward_compute_time, backward_compute_time, activation_size, parameter_size] = node_metadata.split(", ")
        forward_compute_time = float(forward_compute_time.split("=")[1]) # 前向傳播計算時間
        backward_compute_time = float(backward_compute_time.split("=")[1]) # 後向傳播計算時間
        if "[" in activation_size:
            activation_size = activation_size.split("=")[1] # 啟用值大小
            activation_size = sum([float(x) for x in activation_size.lstrip("[").rstrip("]").split("; ")])
        else:
            activation_size = float(activation_size.split("=")[1])
        parameter_size = float(parameter_size.split("=")[1]) # 引數大小
        # 構建節點
        return Node(node_id, node_desc, forward_compute_time=forward_compute_time,
                    backward_compute_time=backward_compute_time, activation_size=activation_size,
                    parameter_size=parameter_size, stage_id=stage_id)

2.3 反鏈

在有向無環圖中,有如下的一些概念:

  • 鏈 :一條鏈是一些點的集合,在此鏈上的任意兩個點x, y,滿足以下條件:或者 x 能到達 y ,或者 y 能到達 x 。也可以認為是某一個偏序集S的全序子集(所謂全序是指其中任意兩個元素可以比較)

  • 反鏈 :一條反鏈也是一些點的集合,在此鏈上任意兩個點x, y,滿足如下條件: x 不能到達 y,且 y 也不能到達 x。也可以認為是某一個偏序集S的子集,其中任意兩個元素不可比較。

在PipeDream的圖資料結構之中,也有反鏈的概念。反鏈節點定義如下:

class AntichainNode(Node):
    def __init__(self, node_id, antichain, node_desc=""):
        self.antichain = antichain
        self.output_activation_size = 0.0
        super(AntichainNode, self).__init__(node_id, node_desc)

因為此處過於複雜,所以我們會在下面用一節專門分析。

0x03 構建反鏈

因為本節概念比較繞,所以我們先提前劇透。

尋找某節點後續反鏈的目的就是找到下一個圖分割點 A(可能是若干node的組合),為了確定 A 的執行時間(或者其他資訊),我們需要找到 A 的增強反鏈

此處具體程式碼位於optimizer_graph_hierarchical.py 檔案。

我們利用如下邏輯來演示:

+-------+       +-------+
| node1 |       | node2 |
+---+---+       +---+---+
    |               |
    |               |
    |               |
    v               v
+---+---+       +---+---+        +-------+        +-------+
| node4 +-----> | node5 +------> | node6 +------->+ node7 |
+-------+       +-------+        +-------+        +-+-+---+
                                                    | |
                                                    | |
                                      +-------------+ |
                                      |               |
                                      v               v
                                 +----+--+        +---+---+
                                 | node9 |        | node8 +-----+
                                 +-------+        +---+---+     |
                                                      |         |
                    +---------------------------------+         |
                    |                                           |
                    v                                           |
               +----+---+       +--------+        +--------+    |
               | node10 +-----> | node11 +------> | node12 |    |
               +--------+       +---+----+        +----+---+    |
                                    |                  |        |
                                    |                  |        |
                                    v                  v        |
                                +---+----+        +----+---+    |
                                | node13 |        | node14 +<---+
                                +--------+        +-+----+-+
                                                    |    |
                                             +------+    +---+
                                             |               |
                                             v               v
                                        +----+---+        +--+-----+
                                        | node15 |        | node19 |
                                        +--------+        +--------+

3.1 main函式入口

我們首先從 main 函式看起。main函式第一部分是構建反鏈和拓撲排序,具體如下:

  • 從圖中移除source節點。目的是排除干擾,因為input必然在第一層,沒必要讓優化器再來選擇把輸入放在哪裡,所以先去除,後續轉換模型時候會再加上。
  • 對圖的輸出進行處理,移除沒有用到的輸出。
  • 得到反鏈DAG。
  • 對反鏈DAG進行拓撲排序,得到一個排序好的節點列表。

具體程式碼如下:

def main(all_num_machines, profile_filename, network_bandwidths, memory_size,
         straight_pipeline, use_memory_constraint, use_fewer_machines,
         activation_compression_ratio, output_directory,
         print_configuration=True, verbose=False):
    gr = graph.Graph.from_str(open(profile_filename, 'r').read())

    # Zero out all metadata associated with inputs in graph, since the optimizer
    # shouldn't really get a choice with where to place the input (should always
    # be in the first stage).
    # 排除干擾,因為input必然在第一層,沒必要讓優化器再來選擇把輸入放在哪裡,所以先去除,後續會再加上。
    sources = gr.sources() # 對圖的輸入進行處理
    nodes_to_remove = OrderedDict()
    for source in sources:
        if source.node_desc.startswith("Input"): # 只處理input
            source.forward_compute_time = 0.0
            source.backward_compute_time = 0.0
            source.activation_size = 0.0
            source.parameter_size = 0.0
            nodes_to_remove[source] = []
            for out_node in gr.edges[source.node_id]:
                nodes_to_remove[source].append(out_node) # 記錄這些刪除source對應了哪些out節點,因為後續還要處理
            gr.remove_node(source) # 在圖中移除這些input source

    # Remove all unneeded sinks that are not used, makes code generation and
    # optimization easier.
    sinks = gr.sinks() # 對圖的輸出進行處理,移除沒有用到的輸出
    for sink in sinks:
        if sink.node_desc.startswith("__getitem__"):
            gr.remove_node(sink)

    antichain_gr = gr.antichain_dag() # 得到反鏈DAG
    states = antichain_gr.topological_sort() # 拓撲排序,得到一個排序好的節點列表

    # 後續程式碼暫時省略

這裡再取出反鏈節點定義如下,可以看出來和程式碼對應關係。

class AntichainNode(Node):
    def __init__(self, node_id, antichain, node_desc=""):
        self.antichain = antichain
        self.output_activation_size = 0.0
        super(AntichainNode, self).__init__(node_id, node_desc)

3.2 增強反鏈

首先要介紹先增強反鏈概念。每個節點的增強反鏈包括:本身節點 + 部分前序節點

這個前序節點的選取演算法是:

  1. 獲取本節點的全部前序節點列表;
  2. 如果一個前序節點的"出邊目的節點"不在全部前序節點列表,且"出邊目的節點"不為本身,則選取此前序節點為增強反鏈的一部分。

從下面圖例中可以看出來,如果某一個節點 A,其前置節點中有一個分叉節點 Z,且這個分叉之中,有一個分叉繞過了節點 A,則對於節點 A,他的增強反鏈就是 [A, Z]。

對於增強反鏈概念,可以理解為:對於節點 A,他只有把節點 Z 一起考慮,才能唯一確定自己節點的執行時間。因為如果思考節點 A 的執行時間,我理解的大致思路是:

  • 因為各個階段可以流水線並行,所以 A 的執行時間應該是以下三個時間的最大值:A的計算時間,A的輸入時間,A的輸出時間。
  • A 的輸入時間是以下兩個時間的最大值: X --> A 節點輸出時間,Z --> A 節點的輸出時間。
  • 但是因為不清楚 Z 的內部執行機制,所以不能確定 Z 的兩個輸出之間是否有依賴關係,比如 "必須先完成 Z--> D,才能輸出 Z--> A", 所以,也需要考慮 Z --> D 的傳輸時間。

所以,需要把 [ A,Z ] 放在一起作為一個狀態考慮,事實上 PipeDream 就是這麼處理的,用 [ A,Z ] 這個狀態來統一計算

因為作為一個狀態考慮,所以給節點 A 計算輸出啟用值大小,具體是通過遍歷其反鏈(增強反鏈)來計算,就是把其增強反鏈的前序節點給自己的輸出都疊加起來

    +-----+            +-----+
    |  X  |            |  Z  |
    +--+--+            +--+-++
       |                  | |
       |                  | |
       +------+   +-------+ |
              |   |         |
              v   v         |
             ++---++        |
             |  A  |        |
             ++-+--+        |
              | |           |
    +---------+ |           |
    |           |           |
    v           v           v
+---+-+      +--+--+      +-+---+
|  B  |      |  C  |      |  D  |
+-----+      +-----+      +-----+

在程式碼之中,_augmented_antichains 是增強反鏈,也是一個字典類,key是節點名字,value是 key 節點的增強反鏈,比如:

augment_antichain函式作用就是對每個節點,找到其增強反鏈。

def augment_antichain(self, antichain):
    # 引數 antichain 是一個節點列表
    antichain_key = tuple(sorted(antichain))
    # 如果key已經在擴大反鏈之中,就直接返回對應key的增強反鏈
    if antichain_key in self._augmented_antichains:
        return self._augmented_antichains[antichain_key]
    extra_nodes = set()
    all_predecessors = set()
    # 遍歷引數list之中的反鏈節點,獲取每個節點的前置節點,歸併在all_predecessors之中。
    for antichain_node in antichain:
        predecessors = self.predecessors(antichain_node)
        all_predecessors = all_predecessors.union(predecessors)
    # 遍歷引數list之中的反鏈節點
    for antichain_node in antichain:
        # 獲取每個反鏈節點的前置節點列表
        predecessors = self.predecessors(antichain_node)
        # 遍歷每個前置節點
        for predecessor in predecessors:
            # 看每個前置節點的出邊,如果出邊不在前置節點列表之中,且 出邊節點不等於本反鏈節點
            for out_node in self.edges[predecessor.node_id]:
                if out_node not in predecessors and out_node.node_id != antichain_node:
                    # 把這個前置節點插入到附加節點列表中
                    extra_nodes.add(predecessor.node_id)
    # 最終把個附加節點列表插入到增強節點之中
    self._augmented_antichains[antichain_key] = list(extra_nodes) + antichain
    return self._augmented_antichains[antichain_key]

比如對應下圖中的邏輯,初始化之後,_augmented_antichains 就是

_augmented_antichains = {dict: 1} 
 ('node4',) = {list: 1} ['node4']

後續迭代node 5之後,_augmented_antichains 就是

_augmented_antichains = {dict: 2} 
 ('node4',) = {list: 1} ['node4']
 ('node5',) = {list: 1} ['node5']
 __len__ = {int} 2

繼續迭代,增強反鏈為:

_augmented_antichains = {dict: 7} 
('node4',) = {list: 1} ['node4'] # node4的增強反鏈只有自己
('node5',) = {list: 1} ['node5'] # node5的增強反鏈只有自己
('node6',) = {list: 1} ['node6']
('node7',) = {list: 1} ['node7']
('node8',) = {list: 1} ['node8']
('node10',) = {list: 2} ['node8', 'node10'] # node10的增強反鏈是'node8', 'node10'
('node14',) = {list: 1} ['node14']
('node11',) = {list: 2} ['node8', 'node11'] # node11的增強反鏈是'node8', 'node11'
('node15',) = {list: 2} ['node14', 'node15']
('node19',) = {list: 1} ['node19']
('node12',) = {list: 2} ['node8', 'node12']
('node16',) = {list: 2} ['node14', 'node16']
('node23',) = {list: 2} ['node20', 'node23']
('node17',) = {list: 2} ['node14', 'node17']  

圖例中可以看出來,因為有 node 8的出邊 [node 8,node 14] 存在,對於 node 10, node 11, node 12 來說,他們必須把 node 8 加入自己的增強反鏈之中。

對於 node 10,我們可以認為,必須結合 node 8之後,node 10 才能確定 node 10 的執行時間。下面圖上標記出來了 node 10 的 augmented 反鏈(本身節點 + 部分前序節點)。

+-------+       +-------+
| node1 |       | node2 |
+---+---+       +---+---+
    |               |
    |               |
    |               |
    v               v
+---+---+       +---+---+        +-------+        +-------+
| node4 +-----> | node5 +------> | node6 +------->+ node7 |
+-------+       +-------+        +-------+        +-+-+---+
                                                    | |
                                                    | |
                                      +-------------+ |
                                      |               |
                                      v               v  augmented
                                 +----+--+        +---+---+
                                 | node9 |        | node8 +-----+
                                 +-------+        +---+---+     |
                                                      |         |
                    +---------------------------------+         |
                    |                                           |
                    v                                           |
               +----+---+       +--------+        +--------+    |
     antichain | node10 +-----> | node11 +------> | node12 |    |
               +--------+       +---+----+        +----+---+    |
             augmented              |                  |        |
                                    |                  |        |
                                    v                  v        |
                                +---+----+        +----+---+    |
                                | node13 |        | node14 +<---+
                                +--------+        +-+----+-+
                                                    |    |
                                             +------+    +---+
                                             |               |
                                             v               v
                                        +----+---+        +--+-----+
                                        | node15 |        | node19 |
                                        +--------+        +--------+

3.3 後續反鏈

在程式碼之中,_next_antichains 是一個字典類,key是節點名字,value是 key 節點的後續反鏈。

比如,對於 node A 來說,下一個反鏈是 [ node B, node C ],其中 node B 和 node C 彼此之間無法排序。尋找反鏈的目的就是找到下一個圖分割點

    +-----+            +-----+
    |  X  |            |  Z  |
    +--+--+            +--+-++
       |                  | |
       |                  | |
       +------+   +-------+ |
              |   |         |
              v   v         |
             ++---++        |
             |  A  |        |
             ++-+--+        |
              | |           |
    +---------+ |           |
    |           |           |
    v           v           v
+---+-+      +--+--+      +-+---+
|  B  |      |  C  |      |  D  |
+-----+      +-----+      +-----+

對於每個節點 antichain ,next_antichains 函式獲取其後續反鏈。

    def next_antichains(self, antichain):
        # 構建antichain的反鏈key,其實就是 antichain 自己作為key
        antichain_key = tuple(sorted(antichain))
        # 如果key已經在後續反鏈之中,則返回這個後續反鏈
        if antichain_key in self._next_antichains:
            return self._next_antichains[antichain_key]

        next_antichains = []
        antichain_set = set(antichain)
        # 獲取 antichain 的增強反鏈
        augmented_antichain = self.augment_antichain(antichain)
        # 遍歷增強反鏈
        for augmented_antichain_node in augmented_antichain:
            # 遍歷增強反鏈某節點的出邊
            next_nodes = self.edges[augmented_antichain_node] if augmented_antichain_node in self.edges else []
            # 遍歷增強反鏈某節點的出邊
            for next_node in next_nodes:
                # 如果出邊節點已經在反鏈集合之中,跳過,進入下一迴圈
                if next_node.node_id in antichain_set:
                    continue
                # 如果出邊節點是後續反鏈,則假如到反鏈列表   
                if self.is_next_antichain(augmented_antichain, next_node.node_id):
                    next_antichain = self.construct_antichain(augmented_antichain,
                                                              augmented_antichain_node,
                                                              next_node.node_id)
                    next_antichains.append(next_antichain)
        # 最終把反鏈列表設定為key對應的反鏈            
        self._next_antichains[antichain_key] = next_antichains
        return self._next_antichains[antichain_key]

is_next_antichain 方法用來判斷某新節點是否為後續反鏈。

def is_next_antichain(self, augmented_antichain, new_node):
    successors = self.successors(new_node)
    augmented_antichain_set = set(augmented_antichain)
    # 遍歷新節點的後續節點
    for successor in successors:
        # 如果後續節點有一個在增強節點之中,就返回false,說明不是後續反鏈
        if successor.node_id in augmented_antichain_set:
            return False
    # 否則就是後續反鏈      
    return True

_next_antichains舉例如下,大家可以結合之前的增強反鏈對比看看。

  • 以 node 10 為例,其增強節點為:[ node 8,node 10 ],
  • 遍歷這些增強節點,看每一個增強節點的出邊。8 的出邊 [ node 10,node 14 ],10 的出邊是 [ node 11]。
  • 所以有三個點 node 10,node 11,node 14 可以繼續看。其中node 10 已經在[ node 8,node 10 ]之中,所以不考慮。
  • 用 14 呼叫 is_next_antichain。
    • is_next_antichain 之中,augmented_antichain 為 [ node 8, node 10],new_node 是 node 14。
    • 得到 successors 集合為 [ node31,node16,node23,node44,node48 ....] 等22個節點,這些節點都不在 [ node 8, node 10] 之中,所以 is_next_antichain 為 true,14 是後續反鏈節點之一。
  • 用 11 呼叫 is_next_antichain。
    • is_next_antichain 之中,augmented_antichain 為 [ node 8, node 10],new_node 是 node 11。
    • 得到 successors 集合為 [ node16,node40,node23,....] 等節點,這些節點都不在 [ node 8, node 10] 之中,所以 is_next_antichain 為 true,11 是後續反鏈節點之一。

所以 node 10 的後續反鏈是 [ ['node14'] ,[ 'node11'] ]。

對比 看看,node 10 的增強反鏈是 ['node8', 'node10'],

_next_antichains = {dict: 99} 
 ('node4',) = {list: 1} [['node5']]
 ('node5',) = {list: 1} [['node6']]
 ('node6',) = {list: 1} [['node7']]
 ('node7',) = {list: 1} [['node8']]
 ('node8',) = {list: 2} [['node10'], ['node14']]
 ('node10',) = {list: 2} [['node14'], ['node11']] # 這裡
 ('node14',) = {list: 2} [['node15'], ['node19']]
 ('node11',) = {list: 2} [['node14'], ['node12']]
 ('node15',) = {list: 2} [['node19'], ['node16']]
 ('node19',) = {list: 1} [['node23']]
 ('node12',) = {list: 2} [['node14'], ['node14']]
 ('node16',) = {list: 2} [['node19'], ['node17']]

具體如下圖,可以看出來,node 11和 node 14確實是 node 10的後續反鏈,就是在這兩個節點上可以對於圖進行分割。

可以這麼理解:對於 node 10 來說,下一個反鏈是 [ node 11, node 14],其中 node 11 和 node 14 彼此之間無法排序尋找後續反鏈的目的就是找到下一個圖分割點

+-------+       +-------+
| node1 |       | node2 |
+---+---+       +---+---+
    |               |
    |               |
    |               |
    v               v
+---+---+       +---+---+        +-------+        +-------+
| node4 +-----> | node5 +------> | node6 +------->+ node7 |
+-------+       +-------+        +-------+        +-+-+---+
                                                    | |
                                                    | |
                                      +-------------+ |
                                      |               |
                                      v               v  augmented
                                 +----+--+        +---+---+
                                 | node9 |        | node8 +-----+
                                 +-------+        +---+---+     |
                                                      |         |
                    +---------------------------------+         |
                    |                                           |
                    v              next                         |
               +----+---+       +--------+        +--------+    |
     antichain | node10 +-----> | node11 +------> | node12 |    |
               +--------+       +---+----+        +----+---+    |
             augmented              |                  |        |
                                    |                  |        |
                                    v             next v        |
                                +---+----+        +----+---+    |
                                | node13 |        | node14 +<---+
                                +--------+        +-+----+-+
                                                    |    |
                                             +------+    +---+
                                             |               |
                                             v               v
                                        +----+---+        +--+-----+
                                        | node15 |        | node19 |
                                        +--------+        +--------+

3.4 總體構建

antichain_dag 的目的是依據 增強反鏈列表後續反鏈列表來構建一個反鏈 DAG。

我們以上面的圖例進行講解,以 node 8 為例。

def antichain_dag(self):
    if self._antichain_dag is not None:
        return self._antichain_dag

    antichain_dag = Graph()
    antichain_id = 0
    antichain = [self.sources()[0].node_id] # 獲取source第一個節點。
    # 構建首節點,同時利用 augment_antichain 來往_augmented_antichains 之中新增首節點。
    source_node = AntichainNode("antichain_%d" % antichain_id, self.augment_antichain(antichain))
    antichain_dag.source = source_node
    antichain_queue = [antichain] # 把第一個節點插入queue
    antichain_mapping = {tuple(sorted(antichain)): source_node}

    # 如果queue之中還有節點
    while len(antichain_queue) > 0:
        antichain = antichain_queue.pop(0) # 彈出第一個節點,賦值為 antichain,這裡為 node 8
        # key就是由 antichain 節點名字構建,比如 antichain_key = {tuple: 1} node8
        antichain_key = tuple(sorted(antichain)) 
        # 如果 antichain_key 已經位於self._next_antichains之中,即 antichain_key 的後續反鏈已經被記錄,就跳過去
        if antichain_key in self._next_antichains:  
            continue
        # 獲取 antichain 的後續反鏈,對於8,這裡是[[10],[14]]
        next_antichains = self.next_antichains(antichain)
        # 遍歷後續反鏈[10,14]
        for next_antichain in next_antichains:
            # 下一個反鏈節點的key 10
            next_antichain_key = tuple(sorted(next_antichain))
            if next_antichain_key not in antichain_mapping: # 如果存在,就跳過
                antichain_id += 1
                # 下一反鏈節點 10 被設定為其增強節點 [ 8, 10 ]
                next_antichain_node = AntichainNode("antichain_%d" % antichain_id, self.augment_antichain(next_antichain))
                # 設定 antichain_mapping
                antichain_mapping[next_antichain_key] = next_antichain_node
            # 向 反鏈DAG 插入邊:    
            antichain_dag.add_edge(antichain_mapping[antichain_key],
                                   antichain_mapping[next_antichain_key])
            # 把最新反鏈節點插入queue,下次迭代使用
            antichain_queue.append(next_antichain)

    self._antichain_dag = antichain_dag
    return antichain_dag

這裡其實目的是設定 antichain_mapping。

流程是:

  • 從 antichain_queue 彈出第一個節點,賦值為 antichain,這裡為 node 8。
  • 獲取 antichain 的後續反鏈,對於8,這裡是[[10],[14]]。
  • 遍歷後續反鏈 [10,14]。
  • 以 10 為例,設定下一個反鏈節點的key 為 10。
  • 下一反鏈節點 10 被設定為其增強節點 [ 8, 10 ],即 ('node10',) = {AntichainNode} antichain_5 -- ['node8', 'node10']。

可以看到,尋找某節點後續反鏈的目的就是找到下一個圖分割點 A,然後為了確定 A 的執行時間(或者其他資訊),需要找到 A 的增強反鏈(一些增強反鏈就是一些狀態),A 的 antichain_mapping 就是其增強反鏈

antichain_mapping 示例如下:

antichain_mapping = {dict: 99} 
 ('node4',) = {AntichainNode} antichain_0 -- ['node4']
 ('node5',) = {AntichainNode} antichain_1 -- ['node5']
 ('node6',) = {AntichainNode} antichain_2 -- ['node6']
 ('node7',) = {AntichainNode} antichain_3 -- ['node7']
 ('node8',) = {AntichainNode} antichain_4 -- ['node8']
 ('node10',) = {AntichainNode} antichain_5 -- ['node8', 'node10'] # 最新設定
 ('node14',) = {AntichainNode} antichain_6 -- ['node14']
 ('node11',) = {AntichainNode} antichain_7 -- ['node8', 'node11']
 ('node15',) = {AntichainNode} antichain_8 -- ['node14', 'node15']
 ('node19',) = {AntichainNode} antichain_9 -- ['node19']
 ('node12',) = {AntichainNode} antichain_10 -- ['node8', 'node12']
 ('node16',) = {AntichainNode} antichain_11 -- ['node14', 'node16']
 ('node23',) = {AntichainNode} antichain_12 -- ['node20', 'node23']
 ('node17',) = {AntichainNode} antichain_13 -- ['node14', 'node17']

antichain_dag 示例如下,可以認為就是增強反鏈DAG

antichain_dag = {Graph}
	nodes = {dict: 99} 
   'antichain_0' = {AntichainNode} antichain_0 -- ['node4']
   'antichain_1' = {AntichainNode} antichain_1 -- ['node5']
   'antichain_2' = {AntichainNode} antichain_2 -- ['node6']
   'antichain_3' = {AntichainNode} antichain_3 -- ['node7']
   'antichain_4' = {AntichainNode} antichain_4 -- ['node8']
   'antichain_5' = {AntichainNode} antichain_5 -- ['node8', 'node10']
   'antichain_6' = {AntichainNode} antichain_6 -- ['node14']
   'antichain_7' = {AntichainNode} antichain_7 -- ['node8', 'node11']
   'antichain_8' = {AntichainNode} antichain_8 -- ['node14', 'node15']
   'antichain_9' = {AntichainNode} antichain_9 -- ['node19']
   'antichain_10' = {AntichainNode} antichain_10 -- ['node8', 'node12']
   'antichain_11' = {AntichainNode} antichain_11 -- ['node14', 'node16']
   'antichain_12' = {AntichainNode} antichain_12 -- ['node20', 'node23']
   'antichain_13' = {AntichainNode} antichain_13 -- ['node14', 'node17']
   'antichain_14' = {AntichainNode} antichain_14 -- ['node20', 'node30', 'node23']
   'antichain_15' = {AntichainNode} antichain_15 -- ['node20', 'node36', 'node23']
   'antichain_16' = {AntichainNode} antichain_16 -- ['node20', 'node43', 'node23']
   'antichain_17' = {AntichainNode} antichain_17 -- ['node20', 'node23', 'node24']

3.5 拓撲排序

得到了增強反鏈之後,需要進行拓撲排序之後才能使用

antichain_gr = gr.antichain_dag()
states = antichain_gr.topological_sort()

得出拓撲排序的目的是:如果按照拓撲序列的頂點次序,在到達某節點之前,可以保證它的所有前序活動都已經完成,從而整個工程順序執行,不會衝突

在圖論中,拓撲排序(Topological Sorting)是一個有向無環圖(DAG, Directed Acyclic Graph)的所有頂點的線性序列。且該序列必須滿足下面兩個條件:

  1. 每個頂點出現且只出現一次。
  2. 若存在一條從頂點 A 到頂點 B 的路徑,那麼在序列中頂點 A 出現在頂點 B 的前面。

有向無環圖(DAG)才有拓撲排序,非DAG圖沒有拓撲排序一說。一個有向無環圖可以有一個或多個拓撲排序序列。

例如,下面這個圖:

+--------+                  +--------+
|        +----------------> |        |
|   1    |                  |   4    +------------+
|        |    +-----------> |        |            |
+-----+--+    |             +---+----+            |
      |       |                 |                 v
      |       |                 |              +--+--+
      |       |                 |        +---> |  5  |
      |       |                 |        |     +-----+
      v       |                 |        |
              |                 v        |
+--------+    |             +---+-----+  |
|        +----+             |         |  |
|    2   +----------------->+    3    +--+
|        |                  |         |
+--------+                  +---------+

得到拓撲排序後的結果是 { 1, 2, 4, 3, 5 }。

這裡的拓撲排序演算法使用的是深度優先排序。

    def topological_sort(self):
        # Algorithm from https://en.wikipedia.org/wiki/Topological_sorting
        self.sorted_nodes = []
        self.marked_nodes = set()
        self.temporarily_marked_nodes = set()
        nodes = list(self.nodes.values())
        nodes.sort(key=lambda x: x.node_desc)
        for node in nodes:
            if node.node_id in self.marked_nodes:
                continue
            self.topological_sort_helper(node.node_id)
        return [self.nodes[node_id] for node_id in self.sorted_nodes]

    def topological_sort_helper(self, node_id):
        if node_id in self.marked_nodes:
            return
        if node_id in self.temporarily_marked_nodes:
            raise Exception("Graph has a cycle")
        self.temporarily_marked_nodes.add(node_id)
        if node_id in self.edges:
            out_nodes = list(self.edges[node_id])
            out_nodes.sort(key=lambda x: (x.node_desc, x.height))
            for out_node in out_nodes:
                self.topological_sort_helper(out_node.node_id)
        self.marked_nodes.add(node_id)
        self.temporarily_marked_nodes.remove(node_id)
        self.sorted_nodes.insert(0, node_id)

最終結果舉例如下,可以和上面的反鏈DAG antichain_dag 比對,看看異同:

states = {list: 99} 
 00 = {AntichainNode} antichain_0 -- ['node4']
 01 = {AntichainNode} antichain_1 -- ['node5']
 02 = {AntichainNode} antichain_2 -- ['node6']
 03 = {AntichainNode} antichain_3 -- ['node7']
 04 = {AntichainNode} antichain_4 -- ['node8']
 05 = {AntichainNode} antichain_5 -- ['node8', 'node10']
 06 = {AntichainNode} antichain_7 -- ['node8', 'node11']
 07 = {AntichainNode} antichain_10 -- ['node8', 'node12']
 08 = {AntichainNode} antichain_6 -- ['node14']
 09 = {AntichainNode} antichain_8 -- ['node14', 'node15']
 10 = {AntichainNode} antichain_11 -- ['node14', 'node16']
 11 = {AntichainNode} antichain_13 -- ['node14', 'node17']
 12 = {AntichainNode} antichain_9 -- ['node19']
 13 = {AntichainNode} antichain_12 -- ['node20', 'node23']
 14 = {AntichainNode} antichain_18 -- ['node23', 'node20', 'node26']
 15 = {AntichainNode} antichain_17 -- ['node23', 'node20', 'node24']
 16 = {AntichainNode} antichain_32 -- ['node23', 'node20', 'node28']
 17 = {AntichainNode} antichain_31 -- ['node23', 'node20', 'node26', 'node24']
 18 = {AntichainNode} antichain_63 -- ['node23', 'node20', 'node26', 'node28']
 19 = {AntichainNode} antichain_33 -- ['node20', 'node26', 'node29']
 20 = {AntichainNode} antichain_16 -- ['node20', 'node43', 'node23']
 21 = {AntichainNode} antichain_30 -- ['node23', 'node20', 'node43', 'node26']
 22 = {AntichainNode} antichain_29 -- ['node23', 'node20', 'node43', 'node24']
 23 = {AntichainNode} antichain_59 -- ['node23', 'node20', 'node43', 'node28']

我們 也可以和如下增強反鏈比對,看到 states 就是對增強反鏈DAG進行拓撲排序之後的結果,按照這個順序進行訓練是符合邏輯的。

_augmented_antichains = {dict: 99} 
 ('node4',) = {list: 1} ['node4']
 ('node5',) = {list: 1} ['node5']
 ('node6',) = {list: 1} ['node6']
 ('node7',) = {list: 1} ['node7']
 ('node8',) = {list: 1} ['node8']
 ('node10',) = {list: 2} ['node8', 'node10']
 ('node14',) = {list: 1} ['node14']
 ('node11',) = {list: 2} ['node8', 'node11']
 ('node15',) = {list: 2} ['node14', 'node15']
 ('node19',) = {list: 1} ['node19']
 ('node12',) = {list: 2} ['node8', 'node12']
 ('node16',) = {list: 2} ['node14', 'node16']
 ('node23',) = {list: 2} ['node20', 'node23']
 ('node17',) = {list: 2} ['node14', 'node17']
 ('node23', 'node30') = {list: 3} ['node20', 'node30', 'node23']
 ('node23', 'node36') = {list: 3} ['node20', 'node36', 'node23']
 ('node23', 'node43') = {list: 3} ['node20', 'node43', 'node23']
 ('node24',) = {list: 3} ['node23', 'node20', 'node24']
 ('node26',) = {list: 3} ['node23', 'node20', 'node26']
 ('node23', 'node30', 'node36') = {list: 4} ['node20', 'node36', 'node30', 'node23']
 ('node23', 'node30', 'node43') = {list: 4} ['node20', 'node43', 'node30', 'node23']
 ('node31',) = {list: 3} ['node20', 'node26', 'node31']
 ('node24', 'node30') = {list: 4} ['node23', 'node20', 'node30', 'node24']
 ('node26', 'node30') = {list: 4} ['node23', 'node20', 'node30', 'node26']
 ('node23', 'node36', 'node43') = {list: 4} ['node20', 'node43', 'node36', 'node23']
 ('node37',) = {list: 4} ['node32', 'node20', 'node26', 'node37']
 ('node24', 'node36') = {list: 4} ['node23', 'node20', 'node36', 'node24']
 ('node26', 'node36') = {list: 4} ['node23', 'node20', 'node36', 'node26']
 ('node44',) = {list: 2} ['node40', 'node44']
 ('node24', 'node43') = {list: 4} ['node23', 'node20', 'node43', 'node24']
 ('node26', 'node43') = {list: 4} ['node23', 'node20', 'node43', 'node26']
 ('node24', 'node26') = {list: 4} ['node23', 'node20', 'node26', 'node24']

3.6 總結

因為目前的演算法比較複雜,所以我們暫時總結一下目前為止的工作:

  • 計算出了每個節點的增強反鏈,最終得到增強反鏈組合 _augmented_antichains
  • 計算出了每個節點的後續反鏈。尋找某節點後續反鏈的目的就是找到下一個圖分割點 A,然後為了確定 A 的執行時間(或者其他資訊),需要找到 A 的增強反鏈(一些增強反鏈就是一些狀態)。_next_antichains 是後續反鏈組合。
  • antichain_dag 函式依據 _next_antichains_augmented_antichains 進行處理,構建一個反鏈 DAG,就是變數 antichain_dag。
  • 得到了增強反鏈DAG之後,需要進行拓撲排序之後才能使用。得出拓撲排序的目的是:如果按照拓撲序列的頂點次序,在到達某節點之前,可以保證它的所有前序活動都已經完成,從而整個工程順序執行,不會衝突
  • states 就是對增強反鏈DAG進行拓撲排序之後的結果,按照這個順序進行訓練是符合邏輯的。所以後續工作就是在 states 基礎上執行。

0x04 計算分割槽

至此,圖已經依據後續反鏈被分割成若干狀態(states),每個狀態很重要的一個屬性是其增強反鏈。states 就是對增強反鏈進行拓撲排序之後的結果,按照這個順序進行訓練是符合邏輯的。

自動分割槽演算法具體分為兩部分。

  • compute_partitioning 是使用動態規劃演算法對於這些狀態得出一個最優化結果,但是沒有做具體分割槽。
  • analyze_partitioning 是利用最優化結果來做具體分割槽,排序後得到了一個偏序結果。

下面我們逐一分析。

4.1 main函式的邏輯

main函式接下來與計算分割槽相關的邏輯如下:

  • 為每個狀態設定index。
  • 給每個狀態計算出輸出啟用值大小,具體是通過遍歷其反鏈(增強反鏈),可以認為就是其必要前序節點給自己的輸出。
  • 給每個狀態計算其資訊,比如計算時間,啟用大小,引數大小等等,都是通過前置節點完成的 。
  • 得到總體輸出大小 output_activation_sizes & 所有前置節點id,後面計算分割槽時候需要。
  • 依據profile估計出系統內部的計算時間,compute_times_row 是 i 節點到 後續節點(i+1, i+2, ...)的計算時間,下面類似。
  • 依據profile估計出系統內部的啟用值大小。
  • 依據profile估計出系統內部的引數大小。
  • 遍歷機器集&網路頻寬組合。流水線可以是straight(數目為1)或者並行(數目為num_machines),依據目前的資訊,以及機器數量,網路頻寬等,使用動態規劃演算法計算分割槽。假如機器集&網路頻寬組合有兩個,則會用每個組合進行一次動態規劃演算法,最後 all_As.append(A) 這裡就是兩個動態規劃的結果,就是考慮到各種必要因素之後的最優結果

具體程式碼如下:

def main(all_num_machines, profile_filename, network_bandwidths, memory_size,
         straight_pipeline, use_memory_constraint, use_fewer_machines,
         activation_compression_ratio, output_directory,
         print_configuration=True, verbose=False):
    gr = graph.Graph.from_str(open(profile_filename, 'r').read())

    # Zero out all metadata associated with inputs in graph, since the optimizer
    # shouldn't really get a choice with where to place the input (should always
    # be in the first stage).
    # 排除干擾,因為input必然在第一層,沒必要讓優化器再來選擇把輸入放在哪裡,所以先去除,後續會再加上。
    sources = gr.sources() # 對圖的輸入進行處理
    nodes_to_remove = OrderedDict()
    for source in sources:
        if source.node_desc.startswith("Input"): # 只處理input
            source.forward_compute_time = 0.0
            source.backward_compute_time = 0.0
            source.activation_size = 0.0
            source.parameter_size = 0.0
            nodes_to_remove[source] = []
            for out_node in gr.edges[source.node_id]:
                nodes_to_remove[source].append(out_node) # 記錄這些刪除source對應了哪些out節點,因為後續還要處理
            gr.remove_node(source) # 在圖中移除這些input source

    # Remove all unneeded sinks that are not used, makes code generation and
    # optimization easier.
    sinks = gr.sinks() # 對圖的輸出進行處理,移除沒有用到的輸出
    for sink in sinks:
        if sink.node_desc.startswith("__getitem__"):
            gr.remove_node(sink)

    antichain_gr = gr.antichain_dag() # 得到反鏈DAG
    states = antichain_gr.topological_sort() # 拓撲排序,得到一個排序好的節點列表

    ###########################################################################
    # 之前程式碼在上節分析過,我們本節從這裡繼續分析
    ###########################################################################
    
    states_indices = {} # 為每個狀態設定index
    for i in range(len(states)):
        states_indices[states[i]] = i
        
##################################### 執行時如下        
#states_indices = {dict: 99} 
# antichain_0 -- ['node4'] = {int} 0
# antichain_1 -- ['node5'] = {int} 1
# antichain_2 -- ['node6'] = {int} 2
# antichain_3 -- ['node7'] = {int} 3
# antichain_4 -- ['node8'] = {int} 4
# ......
         
    # 給每個狀態計算出輸出啟用值大小,具體是通過遍歷其反鏈(增強反鏈),可以認為就是其必要前序節點給自己的輸出
    for i in range(len(states)):
        for antichain_node in states[i].antichain:
            states[i].output_activation_size += gr.nodes[antichain_node].activation_size
       
    # 給每個狀態計算其資訊,比如計算時間,啟用大小,引數大小等等,都是通過前置節點完成的      
    for i in range(len(states)):
        antichain = states[i].antichain
        all_predecessors = gr.all_predecessors(antichain)
        states[i].compute_time = 0.0
        states[i].activation_size = 0.0
        states[i].parameter_size = 0.0
        for predecessor in all_predecessors: # 計算所有前置節點的資訊
            states[i].compute_time += ((predecessor.forward_compute_time +
                                        predecessor.backward_compute_time) / 1000.0)
            states[i].activation_size += predecessor.activation_size
            states[i].parameter_size += predecessor.parameter_size
    gr.reset()

    # 得到總體輸出大小 & 所有前置節點id,後面計算分割槽時候需要
    output_activation_sizes = [state.output_activation_size for state in states]
    all_predecessor_ids = [[states_indices[predecessor] for predecessor in
                            antichain_gr.predecessors(states[i].node_id)]
                           for i in range(len(states))]

##################################### 執行時如下      
# output_activation_sizes = {list: 99} 
# 00 = {float} 6291456.0
# 01 = {float} 12582912.0
# 02 = {float} 12582912.0
# 03 = {float} 6553600.0    
# .....
# all_predecessor_ids = {list: 99} 
#  00 = {list: 0} []
#  01 = {list: 1} [0]
#  02 = {list: 2} [0, 1]
#  03 = {list: 3} [0, 1, 2]
#  04 = {list: 4} [0, 1, 2, 3]
#  05 = {list: 5} [2, 3, 4, 0, 1]
#  06 = {list: 6} [2, 3, 4, 0, 1, 5]
#  07 = {list: 7} [6, 2, 3, 4, 0, 1, 5]
# ......
    
    compute_times = [] # 初始化計算時間
    activation_sizes = [] # 初始化啟用值大小
    parameter_sizes = [] # 初始化引數值大小
    for i in range(len(states)+1): # 具體計算每一個節點的資訊,去除他之前節點的影響
        compute_times_row = []
        activation_sizes_row = []
        parameter_sizes_row = []
        for j in range(len(states)): # 去除之前的節點
            if i == 0: # 列表中第一個節點
                compute_times_row.append(states[j].compute_time) # i 到 j 的計算時間
                activation_sizes_row.append(states[j].activation_size)
                parameter_sizes_row.append(states[j].parameter_size)
            else: # 列表中後續節點
                if j > (i-1):
                    compute_times_row.append(states[j].compute_time -
                        states[i-1].compute_time) # i 到 j 的計算時間
                    activation_sizes_row.append(states[j].activation_size -
                        states[i-1].activation_size)
                    parameter_sizes_row.append(states[j].parameter_size -
                        states[i-1].parameter_size)
                else:
                    compute_times_row.append(None)
                    activation_sizes_row.append(None)
                    parameter_sizes_row.append(None)
        compute_times.append(compute_times_row) # 依據profile估計出系統內部的計算時間,compute_times_row 是 i 節點到 後續節點(i+1, i+2, ...)的計算時間,下面類似
        activation_sizes.append(activation_sizes_row) # 依據profile估計出系統內部的啟用值大小
        parameter_sizes.append(parameter_sizes_row) # 依據profile估計出系統內部的引數大小

##################################### 執行時如下  
# compute_times = {list: 100} 
# 000 = {list: 99} [0.0070220000000000005, 0.012285, 0.012558, 0.021096000000,...
# 001 = {list: 99} [None, 0.005263, 0.005535999999999999, 0.014074000000000003, ...
# 002 = {list: 99} [None, None, 0.00027299999999999894, 0.008811000000000003, ...
# 003 = {list: 99} [None, None, None, 0.008538000000000004, 0.008538, ...
# 004 = {list: 99} [None, None, None, None, -3.469446951953614e-18, 0.000191999999...

    counter = 1
    all_As = []
    num_machines_in_machine = 1 #第一個節點就是1
    # all_num_machines, network_bandwidths 是使用者在輸入中指定
    # 遍歷機器集&網路頻寬組合。流水線可以是straight(數目為1)或者並行(數目為num_machines)
    for num_machines, network_bandwidth in zip(all_num_machines, network_bandwidths):
        print("Solving optimization problem with %d machines with inter-machine bandwidth of %.2f GB/s" % (num_machines, network_bandwidth / 10**9))
        import numpy as np
        print(np.array(compute_times))
        # 依據目前的資訊,以及機器數量,網路頻寬等計算分割槽
        A = compute_partitioning(compute_times, activation_sizes, parameter_sizes,
                                 output_activation_sizes, all_predecessor_ids,
                                 num_machines, num_machines_in_machine,
                                 network_bandwidth,
                                 final_level=(counter==len(network_bandwidths)))
        num_machines_in_machine = num_machines # 因為計算完了,所以設定為本階段的機器數目
        for i in range(len(compute_times)): # 遍歷機器
            for j in range(len(compute_times[0])): # 後續機器
                compute_times[i][j] = A[i][j][-1][0] # 記錄計算時間(本階段最後一個機器的計算時間)
        counter += 1
        all_As.append(A) # 新增邏輯關係,就是裡面包括了不同階段的優化邏輯
    print(np.array(compute_times))
    
    # 省略後續程式碼

其中compute_times 是一個計算時間的二維陣列,也可以認為是矩陣,具體舉例如下。

[w12,w13,w14,w15], // 第一個節點到後續節點的計算時間

[None, w23,w24,w25], // 第二個節點到後續節點的計算時間

[None, None, w34, w35], // 第三個節點到後續節點的計算時間

[None, None, None, w45], // 第四個節點到後續節點的計算時間

activation_sizes 和 parameter_sizes 與之類似。

4.2 動態規劃

4.2.1 總體思路

這裡有一些動態規劃的演算法需要分析。

分割演算法試圖減少模型的整體訓練時間。對於流水線系統,這個問題等價於最小化流水線最慢階段所花費的時間。該問題具有最優化子問題性質;在給定機器計數的情況下,使吞吐量最大化的管道由子管道組成,這些子管道分別使自己這個子管道的吞吐量最大化。因此,我們可以用動態規劃來尋找這個問題的最優解。

分割槽演算法獲取profiling步驟的輸出,並計算:

1)將層劃分為多個階段,

2)每個階段的複製因子(worker數),

3)保持訓練管道繁忙的最佳動態小批量數。

PipeDream的優化器假設機器拓撲是分層的,並且可以被組織成多個級別,如下圖所示。一個級別內的頻寬是相同的,而跨級別的頻寬是不同的。我們假設 k 級由 mk 個 k-1層元件構成 ,這些元件通過頻寬為Bk的鏈路連線。在下圖中,m2=2,m1=4。此外,我們定義m0為1。即 4 個 m0 構成一個 m1, 2個 m1 構成一個 m2。

層 0 就是綠色矩形,代表最底層的計算裝置,比如GPU,4個GPU構成了一個層1(虛線矩形,代表一個伺服器),2個層1構成了一個層2(就是下圖全部模組)。

PipeDream的優化器從最低層到最高層逐步解決動態規劃問題。直觀地說,這個過程在伺服器中找到最佳分割槽,然後使用這些分割槽在伺服器之間最優地分割模型

4.2.2 具體分析

假設 A(j, m) 表示使用m臺機器在第1層和第j層之間的最佳管道中,最慢階段所用的時間

我們演算法的目標是找到 A(N,M) 和相應的劃分。讓T( i → j,m) 表示跨越層 i 到 j 的單級所用的時間,此時間在m臺機器上覆制。

其中:

  • max中的左項是在此階段中所有層的總計算時間,右項是此階段中所有層的總通訊時間。

  • 因為計算和通訊可以重疊,所以不需要相加,直接取最大數值。

  • 由1到j的由m個機器組成的最佳流水線可以是單個階段複製m次,也可以由多個階段組成。

當最佳管道包含多個階段時,它可以被分解成一個最優的子管道(由從1到 i 的 由m − m′ 個機器組成)和後續的一個單獨階段(由i+1到j 的被 m' 個機器複製組成)。因此,利用最優子問題的性質,我們得到

其中,max中:

  • 第一項是第1層和第i層之間的最優子管道(由m-m'個機器組成)的最慢階段所用的時間。

  • 第二項是在層 i 和 i + 1 之間傳遞啟用和梯度所用的時間。

  • 第三項是最後單個階段的時間(由 m' 個資料並行的機器組成)。

我們具體看看如何計算,假設一個圖邏輯如下:

                       +----------------+
+-----+                |                +--------+
|     +------------->  |  k[m_prime]    |        |          +-----+
|  i  |                |                |        +--------->+     |
|     +----+           +----------------+                   |  j  |
+-----+    |                                      +-------->+     |
           |           +----------------+         |         +-----+
           |           |                |         |
           +-------->  |  k[m-m_prime]  +---------+
                       |                |
                       +----------------+

在 (A [i] [k] [m-m_prime] [0], last_stage_time, output_transfer_time, input_transfer_time )之中選一個最大的:

  • A [i] [k] [m-m_prime] [0] :i 到 k 之間的計算時間,是已經計算好的子問題
  • last_stage_time :last_stage_time 是 (k 到 j 的計算時間) + 傳輸時間
    • 其中compute_times[k + 1] [j] 是k 到 j 的計算時間,compute_times[k + 1] 就對應了k的輸出
    • 傳輸時間是依據k 到 j 的下一階段引數大小(parameter_sizes[k + 1 ] [j])計算得出。
    • 即:last_stage_time = compute_times[k + 1] +(parameter_sizes[k + 1 ] [j])
  • input_transfer_time :使用 k 的輸出啟用大小計算出來的傳輸時間(就是 j 的輸入)。
  • output_transfer_time :使用 j 的輸出啟用大小計算出來的傳輸時間。

因為傳輸和計算是可以重疊的,所以可以這樣取最大數值

最後得到的 A 就是動態規劃優化的結果,其中每一個元素 A[i][j][m] 是個三元組 (min_pipeline_time, optimal_split, optimal_num_machines)A[i][j][m] 表示節點 i 到 節點 j 之間的計算結果。三元組就是 (最小流水線時間,i 到 j 之間那個最佳分割點,最優機器數目)。

大致階段如下圖所示:

                                                       +----------------+
                                                       | i              |
                                                       |                |
                                                       |                |
                                                       +--+------+------+
                                                          |      |
                                                          |      +----------+
                                  A[i][k][m+m_prime][0]   |                 |
                                                          |                 |
                                                          v                 v
                                        +-----------------+-------+    +----+--------+
                                        | k[m-m_prime]            |    | k[m_prime]  |
                                        |                         |    |             |
last_stage_time = compute_times[k+1][j] |                         |    |             |
            + (parameter_sizes[k+1][j]) | output_activation_sizes |    |             |
                                        |                         |    |             |
                                        |                         |    |             |
                                        +-----------------+-------+    +-----+-------+
                                     input_transfer_time  |                  |
                                                          |      +-----------+
                                                          |      |
                                                          |      |
                                                          v      v
                                             +------------+------+------+
                                             | j                        |
                                             |                          |
                                             |                          |
                                             |                          |
                                             |  output_activation_sizes |
                                             |                          |
                                             +------------------+-------+
                                          output_transfer_time  |
                                                                |
                                                                |
                                                                v

具體程式碼如下:

def compute_partitioning(compute_times, activation_sizes, parameter_sizes,
                         output_activation_sizes, all_predecessor_ids,
                         num_machines, num_machines_within_machine,
                         bandwidth, final_level=True):
    # 初始化
    A = []
    for i in range(len(compute_times)): # 遍歷所有節點
        row_A = []
        for j in range(len(compute_times[0])): # 所有後續節點(即第一個節點的所有後續節點)
            row_row_A = []
            for m in range(num_machines): # 機器數目
                row_row_A.append((None, None, None))
            row_A.append(row_row_A)
        A.append(row_A)

    # 得到計算時間
    for i in range(len(compute_times)): # 遍歷所有節點
        for j in range(i, len(compute_times[0])): # 所有後續節點
            cum_compute_time = compute_times[i][j] # i --> j 的計算時間
            cum_activation_size = activation_sizes[i][j] # i --> j 的啟用大小
            cum_parameter_size = parameter_sizes[i][j] # i --> j 的引數大小
            max_m = 1 if straight_pipeline else num_machines # 線性還是並行流水線
            for m in range(max_m): # 遍歷流水線下一階段的機器
                # 儲存的資料大小
                stashed_data_size = math.ceil((num_machines - (m+1)) / (m+1)) * \
                                              (cum_activation_size + cum_parameter_size)
                # memory_size 是使用者傳進來的引數,就是每個機器有效的記憶體  
                # use_memory_constraint 也是使用者傳進來的引數,就是使用的記憶體限制
                if use_memory_constraint and stashed_data_size > memory_size:
                    continue
                # 資料並行通訊時間依據引數尺寸,頻寬,下一階段機器數量計算    
                data_parallel_communication_time = (4 * m * cum_parameter_size) / (bandwidth * (m+1))
                # 除以本階段機器數量,如果本階段機器多,當然就是分開計算了
                data_parallel_communication_time /= num_machines_within_machine

                if cum_compute_time is None:
                    # 需要計算下一階段中,每個機器的計算時間,所以還要除以(m+1)
                    A[i][j][m] = (None, None, None) # 直接賦值
                else:
                    # 三元組,分別是[(計算時間 + 通訊時間), None,(m+1)],對應的意義是 min_pipeline_time, optimal_split, optimal_num_machines,就對應了前面的公式 2
                    A[i][j][m] = (sum([cum_compute_time,
                                       data_parallel_communication_time]) / (m+1), None, (m+1))

    # 需要得到最小計算時間                
    min_machines = 1
    max_i = len(compute_times) if not final_level else 1
    for i in range(max_i): # 遍歷節點
        for m in range(min_machines, num_machines): # 遍歷下一階段機器的可能選擇
            for j in range(i+1, len(compute_times[0])): # 遍歷 i 的後續節點
                (min_pipeline_time, optimal_split, optimal_num_machines) = A[i][j][m]
                if use_fewer_machines and m > 0 and ( # 如果設定了用盡量少的機器,則如果小於min_pipeline_time,就設定新的 min_pipeline_time
                    min_pipeline_time is None or A[i][j][m-1][0] < min_pipeline_time):
                    (min_pipeline_time, optimal_split, optimal_num_machines) = A[i][j][m-1]
                # 遍歷 j 節點的前置機器 k,注意,j 是 i 的後續節點之一
                # 就是在 i --> k --> j 之間找到一個計算時間最小的,其中A[i][k][m-m_prime][0]已經是一個最優子問題了
                for k in all_predecessor_ids[j]:
                    # 如果k已經在之前計算過了,就跳過
                    if i > 0 and k in all_predecessor_ids[i-1]:
                        continue
                    # 設定質數    
                    max_m_prime = 2 if straight_pipeline else (m+1)
                    for m_prime in range(1, max_m_prime): # prime就是看看如何分割
                        # 輸入傳輸時間 input_transfer_time 使用 k 的輸出啟用尺寸計算
                        input_transfer_time = (2.0 * output_activation_sizes[k]) / \
                            (bandwidth * m_prime)
                        # 輸出傳輸時間 output_transfer_time 使用 j 的輸出啟用尺寸計算
                        output_transfer_time = None
                        if j < len(output_activation_sizes) -1:
                            output_transfer_time = (2.0 *
                                output_activation_sizes[j]) / (bandwidth * m_prime)
                        # last_stage_time 設定為 k 到 j 的計算時間, compute_times[k+1] 就對應了k的輸出
                        last_stage_time = compute_times[k+1][j]
                        if last_stage_time is None:
                            continue
                        # 設定為 k 到 j 的下一階段引數尺寸
                        last_stage_parameter_size = parameter_sizes[k+1][j]
                        # 設定為 k 到 j 的儲存資料尺寸
                        stashed_data_size = (activation_sizes[k+1][j]) + last_stage_parameter_size
                        # 依據機器資料計算
                        stashed_data_size *= math.ceil((num_machines - (m+1)) / m_prime)
                        # 超過機器記憶體就跳過
                        if use_memory_constraint and stashed_data_size > memory_size:
                            continue
                        # 加上傳輸時間,所以 last_stage_time 是 (k 到 j 的計算時間) + 傳輸時間
                        last_stage_time = sum([last_stage_time,
                                               ((4 * (m_prime - 1) *
                                                last_stage_parameter_size) / (bandwidth * m_prime))])
                        last_stage_time /= m_prime

                        # 如果從i到k沒有邊,則跳過
                        if A[i][k][m-m_prime][0] is None:
                            continue
                        # 如果i到k已經有計算時間,則選一個較大的    
                        pipeline_time = max(A[i][k][m-m_prime][0], last_stage_time)
                        if activation_compression_ratio is not None: # 如果壓縮
                            # 在(A[i][k][m-m_prime][0], last_stage_time, output_transfer_time, input_transfer_time 之中選一個最大的)
                            input_transfer_time /= activation_compression_ratio
                            # output_transfer_time 也壓縮
                            if output_transfer_time is not None:
                                output_transfer_time /= activation_compression_ratio
                            # 選一個大的    
                            pipeline_time = max(pipeline_time, input_transfer_time)
                            if output_transfer_time is not None:
                                pipeline_time = max(pipeline_time, output_transfer_time)
                                
                        # 如果比min_pipeline_time小,則設定 min_pipeline_time,為了下一次迴圈
                        if min_pipeline_time is None or min_pipeline_time > pipeline_time:
                            optimal_split = (k, m-m_prime) # 選一個優化分割點
                            optimal_num_machines = m_prime
                            min_pipeline_time = pipeline_time
                # 設定            
                A[i][j][m] = (min_pipeline_time, optimal_split, optimal_num_machines)

    return A

all_As 就是動態規劃的結果,示例如下:

all_As = {list: 2}  
 0 = {list: 100} 
  000 = {list: 99} 
   00 = {list: 5} [(0.0070220000000000005, None, 1), (0.1689894, None, 2), (0.14943257777777777, None, 3), (0.1258643, None, 4), (0.107310576, None, 5)]
   01 = {list: 5} [(0.012285, None, 1), (0.0070220000000000005, (0, 0), 1), (0.0865995, (0, 0), 2), (0.07639255555555556, (0, 0), 3), (0.06429175000000001, (0, 0), 4)]
   02 = {list: 5} [(0.012558, None, 1), (0.0070220000000000005, (0, 0), 1), (0.0070220000000000005, (1, 1), 1), (0.0070220000000000005, (1, 1), 2), (0.0070220000000000005, (1, 1), 3)]
   03 = {list: 5} [(0.021096, None, 1), (0.012285, (1, 0), 1), (0.008538, (2, 1), 1), (0.008538, (2, 2), 1), (0.008538, (2, 3), 1)]
   ......
  __len__ = {int} 100
  
1 = {list: 100} 
 000 = {list: 99} 
  00 = {list: 5} [(0.107310576, None, 1), (0.080131832, None, 2), (0.05930489777777778, None, 3), (0.046685052000000005, None, 4), (0.03840710336000001, None, 5)]
  01 = {list: 5} [(0.06429175000000001, None, 1), (0.072057299, None, 2), (0.05690740466666667, None, 3), (0.0460065055, None, 4), (0.03840166136, None, 5)]
  02 = {list: 5} [(0.0070220000000000005, None, 1), (0.043422424, None, 2), (0.037817488, None, 3), (0.031689068, None, 4), (0.026947711359999998, None, 5)]
  03 = {list: 5} [(0.008538, None, 1), (0.0419991328, (2, 0), 1), (0.043422424, (2, 1), 1), (0.0396227304, None, 4), (0.033697556608, None, 5)]
 ......
  __len__ = {int} 100
 __len__ = {int} 2

4.2.3 區別

我們接下來要分析程式碼作者兩個相似名字變數之間的區別。

activation_sizes :某個節點所有前置節點的activation_size 之和。

for predecessor in all_predecessors:
    states[i].compute_time += ((predecessor.forward_compute_time +
                                predecessor.backward_compute_time) / 1000.0)
    states[i].activation_size += predecessor.activation_size
    states[i].parameter_size += predecessor.parameter_size

用來計算stashed資料大小,用來看看是否超過了節點配置的記憶體額度

stashed_data_size = (activation_sizes[k+1][j]) + last_stage_parameter_size
stashed_data_size *= math.ceil((num_machines - (m+1)) / m_prime)
if use_memory_constraint and stashed_data_size > memory_size:
		continue

output_activation_sizes : 某個節點所有增強反鏈的activation_size之和。

for i in range(len(states)):
    for antichain_node in states[i].antichain:
        states[i].output_activation_size += gr.nodes[antichain_node].activation_size

用來計算輸出傳播時間和輸入傳播時間

input_transfer_time = (2.0 * output_activation_sizes[k]) / \
    (bandwidth * m_prime)
output_transfer_time = None
if j < len(output_activation_sizes) -1:
    output_transfer_time = (2.0 *
        output_activation_sizes[j]) / (bandwidth * m_prime)

0x05 分析分割槽

5.1 main函式邏輯

前面計算分割槽只是得到了一個動態規劃優化結果,需要在analyze_partitioning之中進行分析劃分之後,賦予到各個層(stage)

main函式接下來與計算分割槽相關的邏輯如下:

  • states是反鏈DAG的結果,all_As 就是動態規劃得到的優化結果,可能是多個
  • splits 初始化時候就只有一個二元組元素:最初的劃分 (0, len(states))。
  • 遍歷all_As的動態優化結果,對於每個動態優化結果,遍歷其各個邏輯關係,呼叫 analyze_partitioning 對分割槽進行分析,在splits分割中遍歷,splits會逐步更新(分割點逐步逐階段細化),analyze_partitioning 返回一個 partial_splits。
  • 遍歷 partial_splits,對於每一個分割點,獲取其增強反鏈(states)的所有前置節點,給這些節點打上stage_id。這裡是從前往後遍歷,所以stage_id數值是逐步增加。
  • 把圖寫到檔案之中。後續 convert_graph_to_model.py 會把這個檔案轉換成模型。
  • 做分析對比。

具體程式碼如下:

def main(all_num_machines, profile_filename, network_bandwidths, memory_size,
         straight_pipeline, use_memory_constraint, use_fewer_machines,
         activation_compression_ratio, output_directory,
         print_configuration=True, verbose=False):
    gr = graph.Graph.from_str(open(profile_filename, 'r').read())

    # Zero out all metadata associated with inputs in graph, since the optimizer
    # shouldn't really get a choice with where to place the input (should always
    # be in the first stage).
    # 排除干擾,因為input必然在第一層,沒必要讓優化器再來選擇把輸入放在哪裡,所以先去除,後續會再加上。
    sources = gr.sources() # 對圖的輸入進行處理
    nodes_to_remove = OrderedDict()
    for source in sources:
        if source.node_desc.startswith("Input"): # 只處理input
            source.forward_compute_time = 0.0
            source.backward_compute_time = 0.0
            source.activation_size = 0.0
            source.parameter_size = 0.0
            nodes_to_remove[source] = []
            for out_node in gr.edges[source.node_id]:
                nodes_to_remove[source].append(out_node) # 記錄這些刪除source對應了哪些out節點,因為後續還要處理
            gr.remove_node(source) # 在圖中移除這些input source

    # Remove all unneeded sinks that are not used, makes code generation and
    # optimization easier.
    sinks = gr.sinks() # 對圖的輸出進行處理,移除沒有用到的輸出
    for sink in sinks:
        if sink.node_desc.startswith("__getitem__"):
            gr.remove_node(sink)

    antichain_gr = gr.antichain_dag() # 得到反鏈DAG
    states = antichain_gr.topological_sort() # 拓撲排序,得到一個排序好的節點列表

    ###########################################################################
    # 計算階段
    ###########################################################################
    states_indices = {} # 為每個狀態設定index
    for i in range(len(states)):
        states_indices[states[i]] = i
        
##################################### 執行時如下        
#states_indices = {dict: 99} 
# antichain_0 -- ['node4'] = {int} 0
# antichain_1 -- ['node5'] = {int} 1
# antichain_2 -- ['node6'] = {int} 2
# antichain_3 -- ['node7'] = {int} 3
# antichain_4 -- ['node8'] = {int} 4
# ......
         
    # 給每個狀態計算出輸出啟用值大小,具體是通過遍歷其反鏈(增強反鏈),可以認為就是其必要前序節點給自己的輸出
    for i in range(len(states)):
        for antichain_node in states[i].antichain:
            states[i].output_activation_size += gr.nodes[antichain_node].activation_size
       
    # 給每個狀態計算其資訊,比如計算時間,啟用大小,引數大小等等,都是通過前置節點完成的      
    for i in range(len(states)):
        antichain = states[i].antichain
        all_predecessors = gr.all_predecessors(antichain)
        states[i].compute_time = 0.0
        states[i].activation_size = 0.0
        states[i].parameter_size = 0.0
        for predecessor in all_predecessors: # 計算所有前置節點的資訊
            states[i].compute_time += ((predecessor.forward_compute_time +
                                        predecessor.backward_compute_time) / 1000.0)
            states[i].activation_size += predecessor.activation_size
            states[i].parameter_size += predecessor.parameter_size
    gr.reset()

    # 得到總體輸出大小 & 所有前置節點id,後面計算分割槽時候需要
    output_activation_sizes = [state.output_activation_size for state in states]
    all_predecessor_ids = [[states_indices[predecessor] for predecessor in
                            antichain_gr.predecessors(states[i].node_id)]
                           for i in range(len(states))]

##################################### 執行時如下      
# output_activation_sizes = {list: 99} 
# 00 = {float} 6291456.0
# 01 = {float} 12582912.0
# 02 = {float} 12582912.0
# 03 = {float} 6553600.0    
# .....
# all_predecessor_ids = {list: 99} 
#  00 = {list: 0} []
#  01 = {list: 1} [0]
#  02 = {list: 2} [0, 1]
#  03 = {list: 3} [0, 1, 2]
#  04 = {list: 4} [0, 1, 2, 3]
#  05 = {list: 5} [2, 3, 4, 0, 1]
#  06 = {list: 6} [2, 3, 4, 0, 1, 5]
#  07 = {list: 7} [6, 2, 3, 4, 0, 1, 5]
# ......
    
    compute_times = [] # 初始化計算時間
    activation_sizes = [] # 初始化啟用值大小
    parameter_sizes = [] # 初始化引數值大小
    for i in range(len(states)+1): # 具體計算每一個節點的資訊,去除他之前節點的影響
        compute_times_row = []
        activation_sizes_row = []
        parameter_sizes_row = []
        for j in range(len(states)): # 去除之前的節點
            if i == 0: # 列表中第一個節點
                compute_times_row.append(states[j].compute_time) # i 到 j 的計算時間
                activation_sizes_row.append(states[j].activation_size)
                parameter_sizes_row.append(states[j].parameter_size)
            else: # 列表中後續節點
                if j > (i-1):
                    compute_times_row.append(states[j].compute_time -
                        states[i-1].compute_time) # i 到 j 的計算時間
                    activation_sizes_row.append(states[j].activation_size -
                        states[i-1].activation_size)
                    parameter_sizes_row.append(states[j].parameter_size -
                        states[i-1].parameter_size)
                else:
                    compute_times_row.append(None)
                    activation_sizes_row.append(None)
                    parameter_sizes_row.append(None)
        compute_times.append(compute_times_row) # 依據profile估計出系統內部的計算時間,compute_times_row 是 i 節點到 後續節點(i+1, i+2, ...)的計算時間,下面類似
        activation_sizes.append(activation_sizes_row) # 依據profile估計出系統內部的啟用值大小
        parameter_sizes.append(parameter_sizes_row) # 依據profile估計出系統內部的引數大小

##################################### 執行時如下  
# compute_times = {list: 100} 
# 000 = {list: 99} [0.0070220000000000005, 0.012285, 0.012558, 0.021096000000,...
# 001 = {list: 99} [None, 0.005263, 0.005535999999999999, 0.014074000000000003, ...
# 002 = {list: 99} [None, None, 0.00027299999999999894, 0.008811000000000003, ...
# 003 = {list: 99} [None, None, None, 0.008538000000000004, 0.008538, ...
# 004 = {list: 99} [None, None, None, None, -3.469446951953614e-18, 0.000191999999...

    counter = 1
    all_As = []
    num_machines_in_machine = 1 #第一個節點就是1
    # all_num_machines, network_bandwidths 是使用者在輸入中指定
    # 遍歷機器集&網路頻寬組合。流水線可以是straight(數目為1)或者並行(數目為num_machines)
    for num_machines, network_bandwidth in zip(all_num_machines, network_bandwidths):
        print("Solving optimization problem with %d machines with inter-machine bandwidth of %.2f GB/s" % (num_machines, network_bandwidth / 10**9))
        import numpy as np
        print(np.array(compute_times))
        # 依據目前的資訊,以及機器數量,網路頻寬等計算分割槽
        A = compute_partitioning(compute_times, activation_sizes, parameter_sizes,
                                 output_activation_sizes, all_predecessor_ids,
                                 num_machines, num_machines_in_machine,
                                 network_bandwidth,
                                 final_level=(counter==len(network_bandwidths)))
        num_machines_in_machine = num_machines # 因為計算完了,所以設定為本階段的機器數目
        for i in range(len(compute_times)): # 遍歷機器
            for j in range(len(compute_times[0])): # 後續機器
                compute_times[i][j] = A[i][j][-1][0] # 記錄計算時間(本階段最後一個機器的計算時間)
        counter += 1
        all_As.append(A) # 新增邏輯關係,就是裡面包括了不同階段的優化邏輯
    print(np.array(compute_times))
    
    ###########################################################################
    # 我們從這裡繼續分析
    ###########################################################################
    
    # 分析階段
    # 在 analyze_partitioning 內部做了具體分析
    # 這裡最重要的是對 gr.all_predecessors 做設定,就是設定 gr 之中每個node的stage_id,這樣就是利用stage_id把初始流水線重新劃分
    splits = [(0, len(states))] # 如何分割,states是反鏈DAG的結果,所以 splits 初始化時候就只有一個二元組元素:最初的劃分 (0, len(states))
    i = len(all_As) - 1 # all_As 就是動態規劃得到的優化結果
    while i >= 0: # 遍歷優化的出來的各個邏輯關係
        print("======================================")
        print("Level %d" % (i+1))
        print("======================================")
        new_splits = []
        stage_id = 0 # 在後續的convert_graph_to_model.py 之中會使用到
        for (start, end) in splits: # 在分割中遍歷,splits會逐步更新
            # 依據新的splits中的二元組重新計算
            partial_splits = \
                analyze_partitioning(all_As[i], states, start, end,
                                     network_bandwidths[i], all_num_machines[i],
                                     activation_compression_ratio,
                                     print_configuration, verbose)
            start_point = start # 起始點
            for split in partial_splits: # 遍歷分析得出的節點
                new_splits.append((start_point, split)) # 新增一個新的二元祖
                if i == 0:
                    predecessors = gr.all_predecessors(states[split-1].antichain)
                    for predecessor in predecessors:
                        if predecessor.stage_id is None:
                            predecessor.set_stage_id(stage_id) # 設定所在階段
                start_point = split # 下一個階段
                stage_id += 1 # 增加所在階段
            new_splits.append((start_point, end)) # 新增一個新的二元祖
            if i == 0:                
                predecessors = gr.all_predecessors(states[end-1].antichain)
                for predecessor in predecessors:
                    if predecessor.stage_id is None:
                        predecessor.set_stage_id(stage_id) # 設定所在階段
            stage_id += 1 # 增加所在階段
        
        print("Total number of stages: %d" % stage_id)
        splits = new_splits # 加入新的分割
        i -= 1

    # 以下是為了把圖寫到檔案之中。後續convert_graph_to_model.py會把這個檔案轉換成模型 
    for source in nodes_to_remove: # 之前移除了input節點,現在需要加回到圖中
        for out_node in nodes_to_remove[source]: # input對應的哪些輸出
            source.stage_id = 0
            gr.add_edge(source, out_node)

    if output_directory is not None:
        total_num_machines = 1
        for num_machines in all_num_machines:
            total_num_machines *= num_machines
        gr.to_dot(os.path.join(output_directory, "gpus=%d" % total_num_machines))
        gr_str = str(gr)
        with open(os.path.join(output_directory, "gpus=%d.txt" % total_num_machines), 'w') as f:
            f.write(gr_str)

    # 以下是為了做分析對比        
    # 計算資料並行需要的時間,以便接下來做比較,這個時間要比動態規劃時間長。        
    total_time = states[-1].compute_time # 最後一個階段的計算時間,是沒有經過優化的最初計算時間
    total_parameter_size = states[-1].parameter_size
    data_parallel_total_time = total_time # 先賦值為最後一階段的計算時間
    num_machines_in_machine = 1 # 本階段的機器數目
    # 遍歷流水線上各個階段,因為沒有優化,所以就是嚴格按照使用者原始配置的流水線階段來逐一計算
    for (num_machines, network_bandwidth) in zip(all_num_machines, network_bandwidths):
        # 計算傳輸時間。num_machines是下一階段流水線機器數目,所以頻寬需要乘以這個數字
        data_parallel_communication_time = (
            (4 * (num_machines - 1) * total_parameter_size) /
            (network_bandwidth * num_machines)) / num_machines_in_machine
        # 總時間需要加上傳輸時間
        data_parallel_total_time = sum(
            [data_parallel_total_time, data_parallel_communication_time]) / num_machines
        # 下個迭代中,本階段的機器數目需要設定為num_machines
        num_machines_in_machine = num_machines

    # 這個是用動態規劃演算法得出來的優化時間    
    pipeline_parallel_total_time = A[0][len(states)-1][num_machines-1][0]

    # 可以看到使用者需要注意哪些資料
    if verbose:
        print()
        print("Time taken by single-stage pipeline:", total_time)
        print("Time per stage in pipeline:", pipeline_parallel_total_time)
        print("Throughput increase (compared to single machine):",
              total_time / pipeline_parallel_total_time)
        dp_str = ",".join([str(elem) for elem in all_num_machines])
        print(("[Note that single-machine and (%s)-machine DP might not fit "
               "given memory constraints]") % dp_str)
        print("Throughput increase of (%s)-machine DP compared to single "
              "machine:" % dp_str, total_time / data_parallel_total_time)
        print("Throughput increase (compared to (%s)-machine DP):" % dp_str,
              data_parallel_total_time / pipeline_parallel_total_time)
    return pipeline_parallel_total_time, data_parallel_total_time    

5.2 分析階段

分析階段具體可以參見下面註釋。

def analyze_partitioning(A, states, start, end, network_bandwidth, num_machines,
                         activation_compression_ratio, print_configuration, verbose):
    # start,end 是本組節點的起始點,終止點
    metadata = A[start][end-1][num_machines-1] # 這是個三元組  (min_pipeline_time, optimal_split, optimal_num_machines)
    next_split = metadata[1] # metadata[1] 是 optimal_split,即 (k, m-m_prime)
    remaining_machines_left = num_machines
    splits = []
    replication_factors = []
    prev_split = end - 1 # 前一個分割點
    
    while next_split is not None: #是否繼續分割
        num_machines_used = metadata[2] # optimal_num_machines
        if verbose:
            print("-------------------------------------")
            print("Number of machines used: %d..." % num_machines_used)
            print("Split between layers %d and %d..." % (next_split[0], next_split[0] + 1))
            print("Split before antichain %s..." % (states[next_split[0]+1].antichain))
        splits.append(next_split[0]+1) # 得到了 k + 1,這是關鍵點,因為最後返回的是splits
        compute_time = states[prev_split-1].compute_time - \
            states[next_split[0]].compute_time
        parameter_size = states[prev_split-1].parameter_size - \
            states[next_split[0]].parameter_size

        dp_communication_time = (4 * (num_machines_used - 1) * parameter_size) \
            / (network_bandwidth * num_machines_used)
        pp_communication_time_input = ( # 下個階段的資料輸入時間
            2.0 * states[next_split[0]].output_activation_size *
            (1.0 / float(num_machines_used))) / network_bandwidth
        pp_communication_time_output = ( # 上個階段的資料輸出時間
            2.0 * states[prev_split-1].output_activation_size *
            (1.0 / float(num_machines_used))) / network_bandwidth
        # 如果需要壓縮,就進行壓縮
        if activation_compression_ratio is not None:
            pp_communication_time_input /= activation_compression_ratio
            pp_communication_time_output /= activation_compression_ratio
        if activation_compression_ratio is None:
            pp_communication_time_input = 0.0
            pp_communication_time_output = 0.0

        compute_time /= num_machines_used # 本階段計算時間
        dp_communication_time /= num_machines_used # 資料並行時間

        if verbose:
            print(("Compute time = %f, Data-parallel communication time = %f, "
                   "Pipeline-parallel communication time = %f...") % (
                compute_time, dp_communication_time,
                max(pp_communication_time_input, pp_communication_time_output)))
        prev_split = splits[-1] # 設定新的前一分割點
        # next_split 格式是 (k, m-m_prime),就是 optimal_split 的格式
        # A[i][j][m] 格式是 (min_pipeline_time, optimal_split, optimal_num_machines)
        metadata = A[start][next_split[0]][next_split[1]]
        next_split = metadata[1] # 設定新的下一次分割點,就是 optimal_split
        replication_factors.append(num_machines_used) # 每個階段的 replication factor
        remaining_machines_left -= num_machines_used # 剩餘機器
    if verbose:
        print("-------------------------------------")
        print("Number of machines used: %d..." % metadata[2])

    #     
    num_machines_used = metadata[2]
    remaining_machines_left -= num_machines_used # 剩餘的機器
    compute_time = states[prev_split-1].compute_time 
    parameter_size = states[prev_split-1].parameter_size
    dp_communication_time = ((4 * (num_machines_used - 1) * parameter_size) /
                             (network_bandwidth * num_machines_used)) 
    compute_time /= num_machines_used # 計算時間
    dp_communication_time /= num_machines_used # 資料並行通訊時間

    if verbose:
        print("Compute time = %f, Data-parallel communication time = %f..." %
              (compute_time, dp_communication_time))
        print("-------------------------------------")
    if print_configuration:
        print("Number of machines in budget not used: %d..." %
              remaining_machines_left)
        print()
        print("(Split start, split end) / compute time taken per stage "
              "/ replication factor per stage:")
    # 下面就是列印 (Split start, split end) / compute time taken per stage / replication factor per stage    
    prev_split = start
    splits.reverse() # 
    splits.append(end)
    replication_factors.append(num_machines_used)
    replication_factors.reverse()
    for i in range(len(splits)):
        time = 0.0
        if prev_split > 0:
            time = states[splits[i]-1].compute_time - states[prev_split-1].compute_time
        else:
            time = states[splits[i]-1].compute_time
        if print_configuration:
            print((prev_split, splits[i]), time, replication_factors[i])
        prev_split = splits[i]
    if print_configuration:
        print()
    return splits[:-1] # 最後一個不返回

我們還是用樣例進行說明。

這裡是從後面進行分割,舉例分析如下,這裡設定了總機器數目為10:

回憶在計算分割槽之中,A[i][j][m] = (min_pipeline_time, optimal_split, optimal_num_machines),optimal_split = (k, m-m_prime) 是一個本階段優化點。

所以在本函式之中,start = 0, end = 99,所以 metadata 為A[0][99][10],即 (0.01903199999999998, (95, 8), 1),next_split = (95, 8),prev_split = end - 1 = 98。

next_split 就是下一個分割點,splits 是目前的分割序列。

第一輪while迴圈

因為next_split = (95, 8),所以 splits = append(next_split[0]+1) = [96],因此計算 states[prev_split-1] - states[next_split[0]] = state[97] - state[95]。這樣把0~99分成了 0 ~95 和 96 ~ 99。

然後 prev_split = 96,去找A[ 0 ] [ 95] [8] 得到 meta = (0.019031999999999993, (78, 7), 1),next_split = (78, 7)。

所以下一輪從78這個分割點開始分割。

第二輪while迴圈

因為next_split = (78, 7),所以 splits = [96, 79],這就是新的分割序列。,因此計算 states[96-1] - states[next_split[0]] = state[96] - state[78]。這樣就使用 splits = [96, 79] 把0~99分成了 0 ~78,79 ~ 95 和 96 ~ 99。

然後 prev_split =79,去找A[ 0 ] [ 78 ] [ 7 ] 得到 meta = (0.011081, (48, 6), 1),next_split = (48, 6)。

所以下一輪從 48 這個分割點開始分割,以此類推。

while迴圈之後,得到 splits = [96, 79, 49, 15, 12, 7, 5, 3, 1]。

於是下面程式碼需要把順序調整過來。

prev_split = start
splits.reverse()
splits.append(end)
replication_factors.append(num_machines_used)
replication_factors.reverse()

得到:splits = { 1,3,5,7,12,15,49,79,96 }。然後加上 end = 99。

最後返回 splits[:-1],即返回 { 1,3,5,7,12,15,49,79,96 },去掉剛剛新增的end。

而依據 { 1,3,5,7,12,15,49,79,96 } 得到的最終分割序列 是 [(0, 1), (1, 3), (3, 5), (5, 7), (7, 12), (12, 15), (15, 49), (49, 79), (79, 96), (96, 99)],這個列表會在後續"設定stage"之中會用到。

5.3 設定stage

目前我們得到了一個理想分割序列,但是事情沒有結束,我們回憶一下分割槽演算法的目的:依據profile結果確定所有層的執行時間,然後使用動態規劃對模型進行劃分,將模型劃分為不同的stage,以及得到每個stage的replication數。

所以,分析的最終目的是給模型的每一個子層分配一個stage,如果某些子層屬於同一個stage,這些子層最終就被分配到同一個worker(節點)上執行

因為這裡涉及到多個子網,所以我們依然用例項來分析。

如果分成了兩個子網,假設:

all_num_machines = [5,5]
network_bandwidths = [800000000, 1000000000]

初始化 splits = [0,99]。

第一輪 while 中,i = 1,

對於 splits 結果[(0, 99)] 遍歷,每一段應用analyze_partitioning,得到 partial_splits 為 [3, 6, 30, 75, 99]。

最後,splits 更新為:[(0, 3), (3, 6), (6, 30), (30, 75), (75, 99)]。

此時不會設定stage_id

第二輪 while 中,i = 0,

對於第一輪的 splits 結果 [(0, 3), (3, 6), (6, 30), (30, 75), (75, 99)] 進行遍歷,對於這裡的每一段也應用 analyze_partitioning,比如對 (0,3) 應用analyze_partitioning,對 (3,6) 應用 analyze_partitioning,對(6,30) 也應用 analyze_partitioning,......,最後得到新的 partial_splits 為 [1, 2, 3, 4, 5, 6, 8, 10, 13, 28, 30, 45, 49, 51, 75, 79, 96, 99]。

最後,splits 更新為:[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 8), (8, 10), (10, 13), (13, 28), (28, 30), (30, 45), (45, 49), (49, 51), (51, 75), (75, 79), (79, 96), (96, 99)]。

這個列表就是理想分割序列

在此輪中,得到了partial_splits之後,會遍歷 for split in partial_splits: 然後對於每一個 split,利用

states[split-1].antichain 獲取其增強反鏈的所有前置節點,給這些節點打上 split 對應的 stage_id

回憶一下增強反鏈的意義:

  • 每個節點的增強反鏈包括:本身節點 + 部分前序節點
  • 對於增強反鏈概念,可以理解為:對於節點 A,他只有把節點 Z 一起考慮,才能唯一確定自己節點的執行時間

所以,對於 split = 1,1 - 1 = 0,於是就得到 states[0].antichain ,就是 'node4',那麼 'node4' 自己被打上了一個stage_id=0,說明 'node4' 被分到了一個 "與stage_id=0 所對應" 的 worker 節點上訓練。

如果有疑問,我們回憶一下state如何構建,就是有序的 "節點組合"。

antichain_gr = gr.antichain_dag()
states = antichain_gr.topological_sort()

具體如下。

states = {list: 99} 
 00 = {AntichainNode} antichain_0 -- ['node4'] # states[0].antichain
 01 = {AntichainNode} antichain_1 -- ['node5']
 02 = {AntichainNode} antichain_2 -- ['node6']
 03 = {AntichainNode} antichain_3 -- ['node7']
 04 = {AntichainNode} antichain_4 -- ['node8']
 05 = {AntichainNode} antichain_5 -- ['node8', 'node10']
 06 = {AntichainNode} antichain_7 -- ['node8', 'node11']
 07 = {AntichainNode} antichain_10 -- ['node8', 'node12']
 08 = {AntichainNode} antichain_6 -- ['node14']
 09 = {AntichainNode} antichain_8 -- ['node14', 'node15']
 10 = {AntichainNode} antichain_11 -- ['node14', 'node16']
 11 = {AntichainNode} antichain_13 -- ['node14', 'node17']
 12 = {AntichainNode} antichain_9 -- ['node19']
 13 = {AntichainNode} antichain_12 -- ['node20', 'node23']
 14 = {AntichainNode} antichain_18 -- ['node23', 'node20', 'node26']
 15 = {AntichainNode} antichain_17 -- ['node23', 'node20', 'node24']
 16 = {AntichainNode} antichain_32 -- ['node23', 'node20', 'node28']
 17 = {AntichainNode} antichain_31 -- ['node23', 'node20', 'node26', 'node24']
 18 = {AntichainNode} antichain_63 -- ['node23', 'node20', 'node26', 'node28']
 19 = {AntichainNode} antichain_33 -- ['node20', 'node26', 'node29']
 20 = {AntichainNode} antichain_16 -- ['node20', 'node43', 'node23']
 21 = {AntichainNode} antichain_30 -- ['node23', 'node20', 'node43', 'node26']
 22 = {AntichainNode} antichain_29 -- ['node23', 'node20', 'node43', 'node24']
 23 = {AntichainNode} antichain_59 -- ['node23', 'node20', 'node43', 'node28']

設定stage 具體程式碼如下:

splits = [(0, len(states))]
i = len(all_As) - 1
while i >= 0:
    new_splits = []
    stage_id = 0
    for (start, end) in splits:
        partial_splits = \
            analyze_partitioning(all_As[i], states, start, end,
                                 network_bandwidths[i], all_num_machines[i],
                                 activation_compression_ratio,
                                 print_configuration, verbose)
        start_point = start
        for split in partial_splits: # 遍歷這個偏序列表
            new_splits.append((start_point, split))
            if i == 0: # 最終的while
                # 針對每個節點,找到每個節點的所有反鏈
                predecessors = gr.all_predecessors(states[split-1].antichain)
                for predecessor in predecessors:
                    if predecessor.stage_id is None:
                        predecessor.set_stage_id(stage_id) # 打上stage id
            start_point = split
            stage_id += 1
        new_splits.append((start_point, end))
        if i == 0: # 最終的while
            predecessors = gr.all_predecessors(states[end-1].antichain)
            for predecessor in predecessors:
                if predecessor.stage_id is None:
                    predecessor.set_stage_id(stage_id) # 打上stage id
        stage_id += 1
    splits = new_splits
    i -= 1

5.4 總結

我們總結一下計算分割槽和分析分割槽所做的工作:

  • 反鏈DAG圖已經被分割成若干狀態(states),每個狀態很重要的一個屬性是其增強反鏈。states 就是對增強反鏈進行拓撲排序之後的結果,按照這個順序進行訓練是符合邏輯的。

  • compute_partitioning 是使用動態規劃演算法對於這些 states 狀態得出一個最優化結果但是這個計算分割槽只是得到了一個動態規劃優化結果,需要在analyze_partitioning之中進行分析劃分之後,賦予到各個層(stage)

  • analyze_partitioning 是利用動態規劃演算法的最優化結果來做具體分割槽,排序後得到了一個偏序結果,就是理想分割序列。

  • 依據 analyze_partitioning 的結果,給模型的每一個子層分配一個stage,如果某些子層屬於同一個stage,這些子層最終就被分配到同一個worker(節點)上執行

0x06 輸出

輸出檔案如下(摘錄部分),可以看到,關鍵之處在於給每一個節點加上了stage,具體如何使用我們將在下一篇進行分析。比如:

stage_id=0 對應的是 node4。

stage_id=1 對應的是 node5,node6。

stage_id=2 對應的是 node7。

stage_id=3 對應的是 node8,node10,node11,node12。

......

具體如下:

node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000 -- stage_id=0
node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000 -- stage_id=1
node6 -- Dropout(p=0.2) -- forward_compute_time=0.077, backward_compute_time=0.196, activation_size=12582912.0, parameter_size=0.000 -- stage_id=1
node7 -- LSTM(2048, 1024) -- forward_compute_time=3.190, backward_compute_time=5.348, activation_size=6553600.0, parameter_size=50364416.000 -- stage_id=2
node8 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=3
node10 -- Dropout(p=0.2) -- forward_compute_time=0.064, backward_compute_time=0.128, activation_size=6291456.0, parameter_size=0.000 -- stage_id=3
node11 -- LSTM(1024, 1024) -- forward_compute_time=2.491, backward_compute_time=4.203, activation_size=6553600.0, parameter_size=33587200.000 -- stage_id=3
node12 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=3
node14 -- Add -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=4
node15 -- Dropout(p=0.2) -- forward_compute_time=0.059, backward_compute_time=0.121, activation_size=6291456.0, parameter_size=0.000 -- stage_id=4
node16 -- LSTM(1024, 1024) -- forward_compute_time=2.492, backward_compute_time=4.201, activation_size=6553600.0, parameter_size=33587200.000 -- stage_id=4
node17 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=5
node19 -- Add -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=5
	node1 -- node4
	node4 -- node5
	node2 -- node5
	node5 -- node6
	node6 -- node7
	node7 -- node8
	node8 -- node10
	node10 -- node11
	node11 -- node12
	node12 -- node14
	node8 -- node14
	node14 -- node15
	node15 -- node16
	node16 -- node17
	node17 -- node19

0xFF 參考

[原始碼解析] 深度學習流水線並行之PipeDream(1)--- Profile階段

相關文章