[原始碼解析] 深度學習流水線並行 PipeDream(2)--- 計算分割槽
0x00 摘要
在前文中,我們介紹了PipeDream的總體架構和Profile階段,本文我們繼續介紹計算分割槽階段。其功能是:依據profile結果確定所有層的執行時間,然後使用動態規劃對模型進行劃分,將模型劃分為不同的stage,以及得到每個stage的replication數。計算結果具體如下圖所示:
流水線並行其他文章連結如下:
[原始碼解析] 深度學習流水線並行Gpipe(1)---流水線基本實現
[原始碼解析] 深度學習流水線並行GPipe (2) ----- 梯度累積
[原始碼解析] 深度學習流水線並行之PipeDream(1)--- Profile階段
0x01 前言
1.1 Profile檔案
我們首先看看profile檔案 profiler/translation/profiles/gnmt/graph.txt 內容,這裡只是做摘錄。
node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
node5 -- EmuBidirLSTM( (bidir): LSTM(1024, 1024, bidirectional=True) (layer1): LSTM(1024, 1024) (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
node2 -- Input1 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
node6 -- Dropout(p=0.2) -- forward_compute_time=0.077, backward_compute_time=0.196, activation_size=12582912.0, parameter_size=0.000
node7 -- LSTM(2048, 1024) -- forward_compute_time=3.190, backward_compute_time=5.348, activation_size=[6291456.0; 131072.0; 131072.0], parameter_size=50364416.000
node8 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
node9 -- __getitem__(1) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=131072.0, parameter_size=0.000
node10 -- Dropout(p=0.2) -- forward_compute_time=0.064, backward_compute_time=0.128, activation_size=6291456.0, parameter_size=0.000
node11 -- LSTM(1024, 1024) -- forward_compute_time=2.491, backward_compute_time=4.203, activation_size=[6291456.0; 131072.0; 131072.0], parameter_size=33587200.000
node12 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
node13 -- __getitem__(1) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=131072.0, parameter_size=0.000
node14 -- Add -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
node15 -- Dropout(p=0.2) -- forward_compute_time=0.059, backward_compute_time=0.121, activation_size=6291456.0, parameter_size=0.000
node16 -- LSTM(1024, 1024) -- forward_compute_time=2.492, backward_compute_time=4.201, activation_size=[6291456.0; 131072.0; 131072.0], parameter_size=33587200.000
node17 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
......
node1 -- node4
node4 -- node5
node2 -- node5
node5 -- node6
node6 -- node7
node7 -- node8
node7 -- node9
node8 -- node10
node10 -- node11
node11 -- node12
node11 -- node13
node12 -- node14
node8 -- node14
node14 -- node15
node15 -- node16
node16 -- node17
node16 -- node18
node17 -- node19
node14 -- node19
......
1.2 總體思路
在前文我們也提到了幾個挑戰,其中有:
- 如何高效劃分流水線。
- 模型特質和硬體拓撲會降低效率。分配演算法也必須考慮模型特質和硬體拓撲。
- 機器間的過度通訊會降低硬體效率。
- 如何防止流水線瓶頸。
- 由木桶原理我們可以知道,一個流水線管道的吞吐量由這個流水線上最慢環節的吞吐量決定。所以需要確保流水線中所有階段都大致花費相同的計算時間,否則最慢的階段將會成為整個流水線的瓶頸。
因此當跨機器將層劃分為不同的階段時,PipeDream的自動劃分演算法必須確保每個階段大致執行相同的總工作量。同時還必須確保各階段之間通訊的資料量儘可能小,以避免通訊中斷。
PipeDream的自動劃分演算法總體目標是輸出一個平衡的管道,演算法如下:
- 將DNN層劃分為多個階段,以便每個階段以大致相同的速率完成,即花費大致相同的計算時間。
- 嘗試以拓撲感知的方式儘量減少worker之間的通訊(例如,如果可能,向更高頻寬的鏈路傳送較大的輸出)。
- 因為DNN並不總可以在可用的workers做平均分配,為了進一步改進負載平衡,PipeDream允許複製一個stage,即在這個stage上使用多個worker進行資料並行。
這個劃分問題等價於最小化流水線的最慢階段所花費的時間,並且具有最優子問題屬性:在給定worker工作量前提下,吞吐量最大化的流水線由一系列子流水線構成,其中每一個子流水線針對較小worker工作量來最大化自己的輸出。因此PipeDream使用動態規劃來尋找最優解。
這裡給出對應的架構圖如下:
我們下面先看看計算分割槽之前的準備工作:圖相關工作和構建反鏈。
0x02 圖相關
圖的定義位於 graph/graph.py 檔案之中,主要資料結構有兩個:Graph 和 Node。
2.1 Graph
Graph就是圖的資料結構,其主要成員包括:
- nodes :圖內節點;
- edges :圖內每個節點的輸出邊;
- in_edges :圖的每個節點的輸入邊;
- _predecessors :每個節點的前序節點;
- _successors :每個節點的後序節點;
- _antichain_dag :反鏈DAG;
class Graph(object):
def __init__(self, node=None):
self.nodes = {} # 節點
if node is not None:
self.nodes[node.node_id] = node
self.edges = {} # 出邊
self.in_edges = {} # 入邊
self._predecessors = {} #每個節點的前序節點
self._successors = {} # 每個節點的後序節點
self._augmented_antichains = {}
self._deaugmented_augmented_antichains = {}
self._next_antichains = {}
self._antichain_dag = None # 反鏈DAG
if node is not None:
self.in_edges[node.node_id] = list()
節點定義如下,裡面就是從profile獲取到的結構,比如:
- forward_compute_time : 前向傳播時間;
- backward_compute_time :反向傳播時間;
- activation_size : 啟用值大小;
- parameter_size : 引數大小;
class Node(object):
def __init__(self, node_id, node_desc="", forward_compute_time=0.0,
backward_compute_time=0.0, activation_size=0.0, parameter_size=0.0,
stage_id=None):
self.node_id = node_id
self.node_desc = node_desc
self.forward_compute_time = forward_compute_time
self.backward_compute_time = backward_compute_time
self.activation_size = activation_size
self.parameter_size = parameter_size
self.stage_id = stage_id
self.depth = None
self.height = None
我們列印出執行時看看,可以發現 Graph 的具體情況。
gr = {Graph}
# 邊
edges = {dict: 39}
'node1' = {list: 1}
0 = {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
'node4' = {list: 1}
0 = {Node} node5 -- EmuBidirLSTM( (bidir): LSTM(1024, 1024, bidirectional=True) (layer1): LSTM(1024, 1024) (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
......
# 輸入邊
in_edges = {dict: 44}
'node4' = {list: 1}
0 = {Node} node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
'node5' = {list: 2}
0 = {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
1 = {Node} node2 -- Input1 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
......
# 節點
nodes = {dict: 48}
'node1' = {Node} node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
'node4' = {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
'node5' = {Node} node5 -- EmuBidirLSTM( (bidir): LSTM(1024, 1024, bidirectional=True) (layer1): LSTM(1024, 1024) (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
......
# 前置節點
_predecessors = {dict: 36}
'node4' = {set: 0} set()
__len__ = {int} 0
'node5' = {set: 1} {<graph.graph.Node object at 0x7fb055e4bf28>}
{Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
__len__ = {int} 1
'node6' = {set: 2} {<graph.graph.Node object at 0x7fb055e4bf98>, <graph.graph.Node object at 0x7fb055e4bf28>}
{Node} node5 -- EmuBidirLSTM( (bidir): LSTM(1024, 1024, bidirectional=True) (layer1): LSTM(1024, 1024) (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
{Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
__len__ = {int} 2
'node7' = {set: 3} {<graph.graph.Node object at 0x7fb055e4bf98>, <graph.graph.Node object at 0x7fb055e4bf28>, <graph.graph.Node object at 0x7fb055e670f0>}
{Node} node5 -- EmuBidirLSTM( (bidir): LSTM(1024, 1024, bidirectional=True) (layer1): LSTM(1024, 1024) (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
{Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
{Node} node6 -- Dropout(p=0.2) -- forward_compute_time=0.077, backward_compute_time=0.196, activation_size=12582912.0, parameter_size=0.000
__len__ = {int} 3
# 其他變數
_antichain_dag = {NoneType} None
_augmented_antichains = {dict: 0} {}
_deaugmented_augmented_antichains = {dict: 0} {}
_next_antichains = {dict: 0} {}
_successors = {dict: 0} {}
2.2 構建圖
圖是由profile檔案的字串構建出來。找出來profile檔案內容我們就可以知道,具體是針對每行進行不同處理。
node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
node5 -- EmuBidirLSTM( (bidir): LSTM(1024, 1024, bidirectional=True) (layer1): LSTM(1024, 1024) (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
node1 -- node4
node4 -- node5
node2 -- node5
構建圖具體程式碼如下:
@staticmethod
def from_str(graph_str):
gr = Graph()
graph_str_lines = graph_str.strip().split('\n')
for graph_str_line in graph_str_lines: # 逐行處理
if not graph_str_line.startswith('\t'):
node = Node.from_str(graph_str_line.strip()) # 構建節點
gr.nodes[node.node_id] = node
else:
# 構建邊
[in_node_id, node_id] = graph_str_line.strip().split(" -- ")
if node_id not in gr.in_edges: # 每個節點的輸入邊
gr.in_edges[node_id] = [gr.nodes[in_node_id]]
else:
gr.in_edges[node_id].append(gr.nodes[in_node_id])
if in_node_id not in gr.edges: # 每個節點的輸出邊
gr.edges[in_node_id] = [gr.nodes[node_id]]
else:
gr.edges[in_node_id].append(gr.nodes[node_id])
return gr
構建節點具體程式碼如下:
@staticmethod
def from_str(node_str):
node_str_tokens = node_str.strip().split(" -- ")
node_id = node_str_tokens[0] # 節點名字
node_desc = node_str_tokens[1] # 節點描述
node_metadata = node_str_tokens[2] # 後設資料
stage_id = None
if len(node_str_tokens) > 3:
stage_id = int(node_str_tokens[3].split("=")[1]) # 階段資訊
[forward_compute_time, backward_compute_time, activation_size, parameter_size] = node_metadata.split(", ")
forward_compute_time = float(forward_compute_time.split("=")[1]) # 前向傳播計算時間
backward_compute_time = float(backward_compute_time.split("=")[1]) # 後向傳播計算時間
if "[" in activation_size:
activation_size = activation_size.split("=")[1] # 啟用值大小
activation_size = sum([float(x) for x in activation_size.lstrip("[").rstrip("]").split("; ")])
else:
activation_size = float(activation_size.split("=")[1])
parameter_size = float(parameter_size.split("=")[1]) # 引數大小
# 構建節點
return Node(node_id, node_desc, forward_compute_time=forward_compute_time,
backward_compute_time=backward_compute_time, activation_size=activation_size,
parameter_size=parameter_size, stage_id=stage_id)
2.3 反鏈
在有向無環圖中,有如下的一些概念:
-
鏈 :一條鏈是一些點的集合,在此鏈上的任意兩個點x, y,滿足以下條件:或者 x 能到達 y ,或者 y 能到達 x 。也可以認為是某一個偏序集S的全序子集(所謂全序是指其中任意兩個元素可以比較)
-
反鏈 :一條反鏈也是一些點的集合,在此鏈上任意兩個點x, y,滿足如下條件: x 不能到達 y,且 y 也不能到達 x。也可以認為是某一個偏序集S的子集,其中任意兩個元素不可比較。
在PipeDream的圖資料結構之中,也有反鏈的概念。反鏈節點定義如下:
class AntichainNode(Node):
def __init__(self, node_id, antichain, node_desc=""):
self.antichain = antichain
self.output_activation_size = 0.0
super(AntichainNode, self).__init__(node_id, node_desc)
因為此處過於複雜,所以我們會在下面用一節專門分析。
0x03 構建反鏈
因為本節概念比較繞,所以我們先提前劇透。
尋找某節點後續反鏈的目的就是找到下一個圖分割點 A(可能是若干node的組合),為了確定 A 的執行時間(或者其他資訊),我們需要找到 A 的增強反鏈。
此處具體程式碼位於optimizer_graph_hierarchical.py
檔案。
我們利用如下邏輯來演示:
+-------+ +-------+
| node1 | | node2 |
+---+---+ +---+---+
| |
| |
| |
v v
+---+---+ +---+---+ +-------+ +-------+
| node4 +-----> | node5 +------> | node6 +------->+ node7 |
+-------+ +-------+ +-------+ +-+-+---+
| |
| |
+-------------+ |
| |
v v
+----+--+ +---+---+
| node9 | | node8 +-----+
+-------+ +---+---+ |
| |
+---------------------------------+ |
| |
v |
+----+---+ +--------+ +--------+ |
| node10 +-----> | node11 +------> | node12 | |
+--------+ +---+----+ +----+---+ |
| | |
| | |
v v |
+---+----+ +----+---+ |
| node13 | | node14 +<---+
+--------+ +-+----+-+
| |
+------+ +---+
| |
v v
+----+---+ +--+-----+
| node15 | | node19 |
+--------+ +--------+
3.1 main函式入口
我們首先從 main 函式看起。main函式第一部分是構建反鏈和拓撲排序,具體如下:
- 從圖中移除source節點。目的是排除干擾,因為input必然在第一層,沒必要讓優化器再來選擇把輸入放在哪裡,所以先去除,後續轉換模型時候會再加上。
- 對圖的輸出進行處理,移除沒有用到的輸出。
- 得到反鏈DAG。
- 對反鏈DAG進行拓撲排序,得到一個排序好的節點列表。
具體程式碼如下:
def main(all_num_machines, profile_filename, network_bandwidths, memory_size,
straight_pipeline, use_memory_constraint, use_fewer_machines,
activation_compression_ratio, output_directory,
print_configuration=True, verbose=False):
gr = graph.Graph.from_str(open(profile_filename, 'r').read())
# Zero out all metadata associated with inputs in graph, since the optimizer
# shouldn't really get a choice with where to place the input (should always
# be in the first stage).
# 排除干擾,因為input必然在第一層,沒必要讓優化器再來選擇把輸入放在哪裡,所以先去除,後續會再加上。
sources = gr.sources() # 對圖的輸入進行處理
nodes_to_remove = OrderedDict()
for source in sources:
if source.node_desc.startswith("Input"): # 只處理input
source.forward_compute_time = 0.0
source.backward_compute_time = 0.0
source.activation_size = 0.0
source.parameter_size = 0.0
nodes_to_remove[source] = []
for out_node in gr.edges[source.node_id]:
nodes_to_remove[source].append(out_node) # 記錄這些刪除source對應了哪些out節點,因為後續還要處理
gr.remove_node(source) # 在圖中移除這些input source
# Remove all unneeded sinks that are not used, makes code generation and
# optimization easier.
sinks = gr.sinks() # 對圖的輸出進行處理,移除沒有用到的輸出
for sink in sinks:
if sink.node_desc.startswith("__getitem__"):
gr.remove_node(sink)
antichain_gr = gr.antichain_dag() # 得到反鏈DAG
states = antichain_gr.topological_sort() # 拓撲排序,得到一個排序好的節點列表
# 後續程式碼暫時省略
這裡再取出反鏈節點定義如下,可以看出來和程式碼對應關係。
class AntichainNode(Node):
def __init__(self, node_id, antichain, node_desc=""):
self.antichain = antichain
self.output_activation_size = 0.0
super(AntichainNode, self).__init__(node_id, node_desc)
3.2 增強反鏈
首先要介紹先增強反鏈概念。每個節點的增強反鏈包括:本身節點 + 部分前序節點。
這個前序節點的選取演算法是:
- 獲取本節點的全部前序節點列表;
- 如果一個前序節點的"出邊目的節點"不在全部前序節點列表,且"出邊目的節點"不為本身,則選取此前序節點為增強反鏈的一部分。
從下面圖例中可以看出來,如果某一個節點 A,其前置節點中有一個分叉節點 Z,且這個分叉之中,有一個分叉繞過了節點 A,則對於節點 A,他的增強反鏈就是 [A, Z]。
對於增強反鏈概念,可以理解為:對於節點 A,他只有把節點 Z 一起考慮,才能唯一確定自己節點的執行時間。因為如果思考節點 A 的執行時間,我理解的大致思路是:
- 因為各個階段可以流水線並行,所以 A 的執行時間應該是以下三個時間的最大值:A的計算時間,A的輸入時間,A的輸出時間。
- A 的輸入時間是以下兩個時間的最大值: X --> A 節點輸出時間,Z --> A 節點的輸出時間。
- 但是因為不清楚 Z 的內部執行機制,所以不能確定 Z 的兩個輸出之間是否有依賴關係,比如 "必須先完成 Z--> D,才能輸出 Z--> A", 所以,也需要考慮 Z --> D 的傳輸時間。
所以,需要把 [ A,Z ] 放在一起作為一個狀態考慮,事實上 PipeDream 就是這麼處理的,用 [ A,Z ] 這個狀態來統一計算。
因為作為一個狀態考慮,所以給節點 A 計算輸出啟用值大小,具體是通過遍歷其反鏈(增強反鏈)來計算,就是把其增強反鏈的前序節點給自己的輸出都疊加起來。
+-----+ +-----+
| X | | Z |
+--+--+ +--+-++
| | |
| | |
+------+ +-------+ |
| | |
v v |
++---++ |
| A | |
++-+--+ |
| | |
+---------+ | |
| | |
v v v
+---+-+ +--+--+ +-+---+
| B | | C | | D |
+-----+ +-----+ +-----+
在程式碼之中,_augmented_antichains
是增強反鏈,也是一個字典類,key是節點名字,value是 key 節點的增強反鏈,比如:
augment_antichain函式作用就是對每個節點,找到其增強反鏈。
def augment_antichain(self, antichain):
# 引數 antichain 是一個節點列表
antichain_key = tuple(sorted(antichain))
# 如果key已經在擴大反鏈之中,就直接返回對應key的增強反鏈
if antichain_key in self._augmented_antichains:
return self._augmented_antichains[antichain_key]
extra_nodes = set()
all_predecessors = set()
# 遍歷引數list之中的反鏈節點,獲取每個節點的前置節點,歸併在all_predecessors之中。
for antichain_node in antichain:
predecessors = self.predecessors(antichain_node)
all_predecessors = all_predecessors.union(predecessors)
# 遍歷引數list之中的反鏈節點
for antichain_node in antichain:
# 獲取每個反鏈節點的前置節點列表
predecessors = self.predecessors(antichain_node)
# 遍歷每個前置節點
for predecessor in predecessors:
# 看每個前置節點的出邊,如果出邊不在前置節點列表之中,且 出邊節點不等於本反鏈節點
for out_node in self.edges[predecessor.node_id]:
if out_node not in predecessors and out_node.node_id != antichain_node:
# 把這個前置節點插入到附加節點列表中
extra_nodes.add(predecessor.node_id)
# 最終把個附加節點列表插入到增強節點之中
self._augmented_antichains[antichain_key] = list(extra_nodes) + antichain
return self._augmented_antichains[antichain_key]
比如對應下圖中的邏輯,初始化之後,_augmented_antichains 就是
_augmented_antichains = {dict: 1}
('node4',) = {list: 1} ['node4']
後續迭代node 5之後,_augmented_antichains 就是
_augmented_antichains = {dict: 2}
('node4',) = {list: 1} ['node4']
('node5',) = {list: 1} ['node5']
__len__ = {int} 2
繼續迭代,增強反鏈為:
_augmented_antichains = {dict: 7}
('node4',) = {list: 1} ['node4'] # node4的增強反鏈只有自己
('node5',) = {list: 1} ['node5'] # node5的增強反鏈只有自己
('node6',) = {list: 1} ['node6']
('node7',) = {list: 1} ['node7']
('node8',) = {list: 1} ['node8']
('node10',) = {list: 2} ['node8', 'node10'] # node10的增強反鏈是'node8', 'node10'
('node14',) = {list: 1} ['node14']
('node11',) = {list: 2} ['node8', 'node11'] # node11的增強反鏈是'node8', 'node11'
('node15',) = {list: 2} ['node14', 'node15']
('node19',) = {list: 1} ['node19']
('node12',) = {list: 2} ['node8', 'node12']
('node16',) = {list: 2} ['node14', 'node16']
('node23',) = {list: 2} ['node20', 'node23']
('node17',) = {list: 2} ['node14', 'node17']
圖例中可以看出來,因為有 node 8的出邊 [node 8,node 14] 存在,對於 node 10, node 11, node 12 來說,他們必須把 node 8 加入自己的增強反鏈之中。
對於 node 10,我們可以認為,必須結合 node 8之後,node 10 才能確定 node 10 的執行時間。下面圖上標記出來了 node 10 的 augmented 反鏈(本身節點 + 部分前序節點)。
+-------+ +-------+
| node1 | | node2 |
+---+---+ +---+---+
| |
| |
| |
v v
+---+---+ +---+---+ +-------+ +-------+
| node4 +-----> | node5 +------> | node6 +------->+ node7 |
+-------+ +-------+ +-------+ +-+-+---+
| |
| |
+-------------+ |
| |
v v augmented
+----+--+ +---+---+
| node9 | | node8 +-----+
+-------+ +---+---+ |
| |
+---------------------------------+ |
| |
v |
+----+---+ +--------+ +--------+ |
antichain | node10 +-----> | node11 +------> | node12 | |
+--------+ +---+----+ +----+---+ |
augmented | | |
| | |
v v |
+---+----+ +----+---+ |
| node13 | | node14 +<---+
+--------+ +-+----+-+
| |
+------+ +---+
| |
v v
+----+---+ +--+-----+
| node15 | | node19 |
+--------+ +--------+
3.3 後續反鏈
在程式碼之中,_next_antichains 是一個字典類,key是節點名字,value是 key 節點的後續反鏈。
比如,對於 node A 來說,下一個反鏈是 [ node B, node C ],其中 node B 和 node C 彼此之間無法排序。尋找反鏈的目的就是找到下一個圖分割點。
+-----+ +-----+
| X | | Z |
+--+--+ +--+-++
| | |
| | |
+------+ +-------+ |
| | |
v v |
++---++ |
| A | |
++-+--+ |
| | |
+---------+ | |
| | |
v v v
+---+-+ +--+--+ +-+---+
| B | | C | | D |
+-----+ +-----+ +-----+
對於每個節點 antichain ,next_antichains 函式獲取其後續反鏈。
def next_antichains(self, antichain):
# 構建antichain的反鏈key,其實就是 antichain 自己作為key
antichain_key = tuple(sorted(antichain))
# 如果key已經在後續反鏈之中,則返回這個後續反鏈
if antichain_key in self._next_antichains:
return self._next_antichains[antichain_key]
next_antichains = []
antichain_set = set(antichain)
# 獲取 antichain 的增強反鏈
augmented_antichain = self.augment_antichain(antichain)
# 遍歷增強反鏈
for augmented_antichain_node in augmented_antichain:
# 遍歷增強反鏈某節點的出邊
next_nodes = self.edges[augmented_antichain_node] if augmented_antichain_node in self.edges else []
# 遍歷增強反鏈某節點的出邊
for next_node in next_nodes:
# 如果出邊節點已經在反鏈集合之中,跳過,進入下一迴圈
if next_node.node_id in antichain_set:
continue
# 如果出邊節點是後續反鏈,則假如到反鏈列表
if self.is_next_antichain(augmented_antichain, next_node.node_id):
next_antichain = self.construct_antichain(augmented_antichain,
augmented_antichain_node,
next_node.node_id)
next_antichains.append(next_antichain)
# 最終把反鏈列表設定為key對應的反鏈
self._next_antichains[antichain_key] = next_antichains
return self._next_antichains[antichain_key]
is_next_antichain 方法用來判斷某新節點是否為後續反鏈。
def is_next_antichain(self, augmented_antichain, new_node):
successors = self.successors(new_node)
augmented_antichain_set = set(augmented_antichain)
# 遍歷新節點的後續節點
for successor in successors:
# 如果後續節點有一個在增強節點之中,就返回false,說明不是後續反鏈
if successor.node_id in augmented_antichain_set:
return False
# 否則就是後續反鏈
return True
_next_antichains舉例如下,大家可以結合之前的增強反鏈對比看看。
- 以 node 10 為例,其增強節點為:[ node 8,node 10 ],
- 遍歷這些增強節點,看每一個增強節點的出邊。8 的出邊 [ node 10,node 14 ],10 的出邊是 [ node 11]。
- 所以有三個點 node 10,node 11,node 14 可以繼續看。其中node 10 已經在[ node 8,node 10 ]之中,所以不考慮。
- 用 14 呼叫 is_next_antichain。
- is_next_antichain 之中,augmented_antichain 為 [ node 8, node 10],new_node 是 node 14。
- 得到 successors 集合為 [ node31,node16,node23,node44,node48 ....] 等22個節點,這些節點都不在 [ node 8, node 10] 之中,所以 is_next_antichain 為 true,14 是後續反鏈節點之一。
- 用 11 呼叫 is_next_antichain。
- is_next_antichain 之中,augmented_antichain 為 [ node 8, node 10],new_node 是 node 11。
- 得到 successors 集合為 [ node16,node40,node23,....] 等節點,這些節點都不在 [ node 8, node 10] 之中,所以 is_next_antichain 為 true,11 是後續反鏈節點之一。
所以 node 10 的後續反鏈是 [ ['node14'] ,[ 'node11'] ]。
對比 看看,node 10 的增強反鏈是 ['node8', 'node10'],
_next_antichains = {dict: 99}
('node4',) = {list: 1} [['node5']]
('node5',) = {list: 1} [['node6']]
('node6',) = {list: 1} [['node7']]
('node7',) = {list: 1} [['node8']]
('node8',) = {list: 2} [['node10'], ['node14']]
('node10',) = {list: 2} [['node14'], ['node11']] # 這裡
('node14',) = {list: 2} [['node15'], ['node19']]
('node11',) = {list: 2} [['node14'], ['node12']]
('node15',) = {list: 2} [['node19'], ['node16']]
('node19',) = {list: 1} [['node23']]
('node12',) = {list: 2} [['node14'], ['node14']]
('node16',) = {list: 2} [['node19'], ['node17']]
具體如下圖,可以看出來,node 11和 node 14確實是 node 10的後續反鏈,就是在這兩個節點上可以對於圖進行分割。
可以這麼理解:對於 node 10 來說,下一個反鏈是 [ node 11, node 14],其中 node 11 和 node 14 彼此之間無法排序。尋找後續反鏈的目的就是找到下一個圖分割點。
+-------+ +-------+
| node1 | | node2 |
+---+---+ +---+---+
| |
| |
| |
v v
+---+---+ +---+---+ +-------+ +-------+
| node4 +-----> | node5 +------> | node6 +------->+ node7 |
+-------+ +-------+ +-------+ +-+-+---+
| |
| |
+-------------+ |
| |
v v augmented
+----+--+ +---+---+
| node9 | | node8 +-----+
+-------+ +---+---+ |
| |
+---------------------------------+ |
| |
v next |
+----+---+ +--------+ +--------+ |
antichain | node10 +-----> | node11 +------> | node12 | |
+--------+ +---+----+ +----+---+ |
augmented | | |
| | |
v next v |
+---+----+ +----+---+ |
| node13 | | node14 +<---+
+--------+ +-+----+-+
| |
+------+ +---+
| |
v v
+----+---+ +--+-----+
| node15 | | node19 |
+--------+ +--------+
3.4 總體構建
antichain_dag 的目的是依據 增強反鏈列表 和 後續反鏈列表來構建一個反鏈 DAG。
我們以上面的圖例進行講解,以 node 8 為例。
def antichain_dag(self):
if self._antichain_dag is not None:
return self._antichain_dag
antichain_dag = Graph()
antichain_id = 0
antichain = [self.sources()[0].node_id] # 獲取source第一個節點。
# 構建首節點,同時利用 augment_antichain 來往_augmented_antichains 之中新增首節點。
source_node = AntichainNode("antichain_%d" % antichain_id, self.augment_antichain(antichain))
antichain_dag.source = source_node
antichain_queue = [antichain] # 把第一個節點插入queue
antichain_mapping = {tuple(sorted(antichain)): source_node}
# 如果queue之中還有節點
while len(antichain_queue) > 0:
antichain = antichain_queue.pop(0) # 彈出第一個節點,賦值為 antichain,這裡為 node 8
# key就是由 antichain 節點名字構建,比如 antichain_key = {tuple: 1} node8
antichain_key = tuple(sorted(antichain))
# 如果 antichain_key 已經位於self._next_antichains之中,即 antichain_key 的後續反鏈已經被記錄,就跳過去
if antichain_key in self._next_antichains:
continue
# 獲取 antichain 的後續反鏈,對於8,這裡是[[10],[14]]
next_antichains = self.next_antichains(antichain)
# 遍歷後續反鏈[10,14]
for next_antichain in next_antichains:
# 下一個反鏈節點的key 10
next_antichain_key = tuple(sorted(next_antichain))
if next_antichain_key not in antichain_mapping: # 如果存在,就跳過
antichain_id += 1
# 下一反鏈節點 10 被設定為其增強節點 [ 8, 10 ]
next_antichain_node = AntichainNode("antichain_%d" % antichain_id, self.augment_antichain(next_antichain))
# 設定 antichain_mapping
antichain_mapping[next_antichain_key] = next_antichain_node
# 向 反鏈DAG 插入邊:
antichain_dag.add_edge(antichain_mapping[antichain_key],
antichain_mapping[next_antichain_key])
# 把最新反鏈節點插入queue,下次迭代使用
antichain_queue.append(next_antichain)
self._antichain_dag = antichain_dag
return antichain_dag
這裡其實目的是設定 antichain_mapping。
流程是:
- 從 antichain_queue 彈出第一個節點,賦值為 antichain,這裡為 node 8。
- 獲取 antichain 的後續反鏈,對於8,這裡是[[10],[14]]。
- 遍歷後續反鏈 [10,14]。
- 以 10 為例,設定下一個反鏈節點的key 為 10。
- 下一反鏈節點 10 被設定為其增強節點 [ 8, 10 ],即 ('node10',) = {AntichainNode} antichain_5 -- ['node8', 'node10']。
可以看到,尋找某節點後續反鏈的目的就是找到下一個圖分割點 A,然後為了確定 A 的執行時間(或者其他資訊),需要找到 A 的增強反鏈(一些增強反鏈就是一些狀態),A 的 antichain_mapping 就是其增強反鏈。
antichain_mapping 示例如下:
antichain_mapping = {dict: 99}
('node4',) = {AntichainNode} antichain_0 -- ['node4']
('node5',) = {AntichainNode} antichain_1 -- ['node5']
('node6',) = {AntichainNode} antichain_2 -- ['node6']
('node7',) = {AntichainNode} antichain_3 -- ['node7']
('node8',) = {AntichainNode} antichain_4 -- ['node8']
('node10',) = {AntichainNode} antichain_5 -- ['node8', 'node10'] # 最新設定
('node14',) = {AntichainNode} antichain_6 -- ['node14']
('node11',) = {AntichainNode} antichain_7 -- ['node8', 'node11']
('node15',) = {AntichainNode} antichain_8 -- ['node14', 'node15']
('node19',) = {AntichainNode} antichain_9 -- ['node19']
('node12',) = {AntichainNode} antichain_10 -- ['node8', 'node12']
('node16',) = {AntichainNode} antichain_11 -- ['node14', 'node16']
('node23',) = {AntichainNode} antichain_12 -- ['node20', 'node23']
('node17',) = {AntichainNode} antichain_13 -- ['node14', 'node17']
antichain_dag 示例如下,可以認為就是增強反鏈DAG:
antichain_dag = {Graph}
nodes = {dict: 99}
'antichain_0' = {AntichainNode} antichain_0 -- ['node4']
'antichain_1' = {AntichainNode} antichain_1 -- ['node5']
'antichain_2' = {AntichainNode} antichain_2 -- ['node6']
'antichain_3' = {AntichainNode} antichain_3 -- ['node7']
'antichain_4' = {AntichainNode} antichain_4 -- ['node8']
'antichain_5' = {AntichainNode} antichain_5 -- ['node8', 'node10']
'antichain_6' = {AntichainNode} antichain_6 -- ['node14']
'antichain_7' = {AntichainNode} antichain_7 -- ['node8', 'node11']
'antichain_8' = {AntichainNode} antichain_8 -- ['node14', 'node15']
'antichain_9' = {AntichainNode} antichain_9 -- ['node19']
'antichain_10' = {AntichainNode} antichain_10 -- ['node8', 'node12']
'antichain_11' = {AntichainNode} antichain_11 -- ['node14', 'node16']
'antichain_12' = {AntichainNode} antichain_12 -- ['node20', 'node23']
'antichain_13' = {AntichainNode} antichain_13 -- ['node14', 'node17']
'antichain_14' = {AntichainNode} antichain_14 -- ['node20', 'node30', 'node23']
'antichain_15' = {AntichainNode} antichain_15 -- ['node20', 'node36', 'node23']
'antichain_16' = {AntichainNode} antichain_16 -- ['node20', 'node43', 'node23']
'antichain_17' = {AntichainNode} antichain_17 -- ['node20', 'node23', 'node24']
3.5 拓撲排序
得到了增強反鏈之後,需要進行拓撲排序之後才能使用。
antichain_gr = gr.antichain_dag()
states = antichain_gr.topological_sort()
得出拓撲排序的目的是:如果按照拓撲序列的頂點次序,在到達某節點之前,可以保證它的所有前序活動都已經完成,從而整個工程順序執行,不會衝突。
在圖論中,拓撲排序(Topological Sorting)是一個有向無環圖(DAG, Directed Acyclic Graph)的所有頂點的線性序列。且該序列必須滿足下面兩個條件:
- 每個頂點出現且只出現一次。
- 若存在一條從頂點 A 到頂點 B 的路徑,那麼在序列中頂點 A 出現在頂點 B 的前面。
有向無環圖(DAG)才有拓撲排序,非DAG圖沒有拓撲排序一說。一個有向無環圖可以有一個或多個拓撲排序序列。
例如,下面這個圖:
+--------+ +--------+
| +----------------> | |
| 1 | | 4 +------------+
| | +-----------> | | |
+-----+--+ | +---+----+ |
| | | v
| | | +--+--+
| | | +---> | 5 |
| | | | +-----+
v | | |
| v |
+--------+ | +---+-----+ |
| +----+ | | |
| 2 +----------------->+ 3 +--+
| | | |
+--------+ +---------+
得到拓撲排序後的結果是 { 1, 2, 4, 3, 5 }。
這裡的拓撲排序演算法使用的是深度優先排序。
def topological_sort(self):
# Algorithm from https://en.wikipedia.org/wiki/Topological_sorting
self.sorted_nodes = []
self.marked_nodes = set()
self.temporarily_marked_nodes = set()
nodes = list(self.nodes.values())
nodes.sort(key=lambda x: x.node_desc)
for node in nodes:
if node.node_id in self.marked_nodes:
continue
self.topological_sort_helper(node.node_id)
return [self.nodes[node_id] for node_id in self.sorted_nodes]
def topological_sort_helper(self, node_id):
if node_id in self.marked_nodes:
return
if node_id in self.temporarily_marked_nodes:
raise Exception("Graph has a cycle")
self.temporarily_marked_nodes.add(node_id)
if node_id in self.edges:
out_nodes = list(self.edges[node_id])
out_nodes.sort(key=lambda x: (x.node_desc, x.height))
for out_node in out_nodes:
self.topological_sort_helper(out_node.node_id)
self.marked_nodes.add(node_id)
self.temporarily_marked_nodes.remove(node_id)
self.sorted_nodes.insert(0, node_id)
最終結果舉例如下,可以和上面的反鏈DAG antichain_dag 比對,看看異同:
states = {list: 99}
00 = {AntichainNode} antichain_0 -- ['node4']
01 = {AntichainNode} antichain_1 -- ['node5']
02 = {AntichainNode} antichain_2 -- ['node6']
03 = {AntichainNode} antichain_3 -- ['node7']
04 = {AntichainNode} antichain_4 -- ['node8']
05 = {AntichainNode} antichain_5 -- ['node8', 'node10']
06 = {AntichainNode} antichain_7 -- ['node8', 'node11']
07 = {AntichainNode} antichain_10 -- ['node8', 'node12']
08 = {AntichainNode} antichain_6 -- ['node14']
09 = {AntichainNode} antichain_8 -- ['node14', 'node15']
10 = {AntichainNode} antichain_11 -- ['node14', 'node16']
11 = {AntichainNode} antichain_13 -- ['node14', 'node17']
12 = {AntichainNode} antichain_9 -- ['node19']
13 = {AntichainNode} antichain_12 -- ['node20', 'node23']
14 = {AntichainNode} antichain_18 -- ['node23', 'node20', 'node26']
15 = {AntichainNode} antichain_17 -- ['node23', 'node20', 'node24']
16 = {AntichainNode} antichain_32 -- ['node23', 'node20', 'node28']
17 = {AntichainNode} antichain_31 -- ['node23', 'node20', 'node26', 'node24']
18 = {AntichainNode} antichain_63 -- ['node23', 'node20', 'node26', 'node28']
19 = {AntichainNode} antichain_33 -- ['node20', 'node26', 'node29']
20 = {AntichainNode} antichain_16 -- ['node20', 'node43', 'node23']
21 = {AntichainNode} antichain_30 -- ['node23', 'node20', 'node43', 'node26']
22 = {AntichainNode} antichain_29 -- ['node23', 'node20', 'node43', 'node24']
23 = {AntichainNode} antichain_59 -- ['node23', 'node20', 'node43', 'node28']
我們 也可以和如下增強反鏈比對,看到 states 就是對增強反鏈DAG進行拓撲排序之後的結果,按照這個順序進行訓練是符合邏輯的。
_augmented_antichains = {dict: 99}
('node4',) = {list: 1} ['node4']
('node5',) = {list: 1} ['node5']
('node6',) = {list: 1} ['node6']
('node7',) = {list: 1} ['node7']
('node8',) = {list: 1} ['node8']
('node10',) = {list: 2} ['node8', 'node10']
('node14',) = {list: 1} ['node14']
('node11',) = {list: 2} ['node8', 'node11']
('node15',) = {list: 2} ['node14', 'node15']
('node19',) = {list: 1} ['node19']
('node12',) = {list: 2} ['node8', 'node12']
('node16',) = {list: 2} ['node14', 'node16']
('node23',) = {list: 2} ['node20', 'node23']
('node17',) = {list: 2} ['node14', 'node17']
('node23', 'node30') = {list: 3} ['node20', 'node30', 'node23']
('node23', 'node36') = {list: 3} ['node20', 'node36', 'node23']
('node23', 'node43') = {list: 3} ['node20', 'node43', 'node23']
('node24',) = {list: 3} ['node23', 'node20', 'node24']
('node26',) = {list: 3} ['node23', 'node20', 'node26']
('node23', 'node30', 'node36') = {list: 4} ['node20', 'node36', 'node30', 'node23']
('node23', 'node30', 'node43') = {list: 4} ['node20', 'node43', 'node30', 'node23']
('node31',) = {list: 3} ['node20', 'node26', 'node31']
('node24', 'node30') = {list: 4} ['node23', 'node20', 'node30', 'node24']
('node26', 'node30') = {list: 4} ['node23', 'node20', 'node30', 'node26']
('node23', 'node36', 'node43') = {list: 4} ['node20', 'node43', 'node36', 'node23']
('node37',) = {list: 4} ['node32', 'node20', 'node26', 'node37']
('node24', 'node36') = {list: 4} ['node23', 'node20', 'node36', 'node24']
('node26', 'node36') = {list: 4} ['node23', 'node20', 'node36', 'node26']
('node44',) = {list: 2} ['node40', 'node44']
('node24', 'node43') = {list: 4} ['node23', 'node20', 'node43', 'node24']
('node26', 'node43') = {list: 4} ['node23', 'node20', 'node43', 'node26']
('node24', 'node26') = {list: 4} ['node23', 'node20', 'node26', 'node24']
3.6 總結
因為目前的演算法比較複雜,所以我們暫時總結一下目前為止的工作:
- 計算出了每個節點的增強反鏈,最終得到增強反鏈組合
_augmented_antichains
。 - 計算出了每個節點的後續反鏈。尋找某節點後續反鏈的目的就是找到下一個圖分割點 A,然後為了確定 A 的執行時間(或者其他資訊),需要找到 A 的增強反鏈(一些增強反鏈就是一些狀態)。_next_antichains 是後續反鏈組合。
- antichain_dag 函式依據
_next_antichains
和_augmented_antichains
進行處理,構建一個反鏈 DAG,就是變數 antichain_dag。 - 得到了增強反鏈DAG之後,需要進行拓撲排序之後才能使用。得出拓撲排序的目的是:如果按照拓撲序列的頂點次序,在到達某節點之前,可以保證它的所有前序活動都已經完成,從而整個工程順序執行,不會衝突。
- states 就是對增強反鏈DAG進行拓撲排序之後的結果,按照這個順序進行訓練是符合邏輯的。所以後續工作就是在 states 基礎上執行。
0x04 計算分割槽
至此,圖已經依據後續反鏈被分割成若干狀態(states),每個狀態很重要的一個屬性是其增強反鏈。states 就是對增強反鏈進行拓撲排序之後的結果,按照這個順序進行訓練是符合邏輯的。
自動分割槽演算法具體分為兩部分。
- compute_partitioning 是使用動態規劃演算法對於這些狀態得出一個最優化結果,但是沒有做具體分割槽。
- analyze_partitioning 是利用最優化結果來做具體分割槽,排序後得到了一個偏序結果。
下面我們逐一分析。
4.1 main函式的邏輯
main函式接下來與計算分割槽相關的邏輯如下:
- 為每個狀態設定index。
- 給每個狀態計算出輸出啟用值大小,具體是通過遍歷其反鏈(增強反鏈),可以認為就是其必要前序節點給自己的輸出。
- 給每個狀態計算其資訊,比如計算時間,啟用大小,引數大小等等,都是通過前置節點完成的 。
- 得到總體輸出大小 output_activation_sizes & 所有前置節點id,後面計算分割槽時候需要。
- 依據profile估計出系統內部的計算時間,compute_times_row 是 i 節點到 後續節點(i+1, i+2, ...)的計算時間,下面類似。
- 依據profile估計出系統內部的啟用值大小。
- 依據profile估計出系統內部的引數大小。
- 遍歷機器集&網路頻寬組合。流水線可以是straight(數目為1)或者並行(數目為num_machines),依據目前的資訊,以及機器數量,網路頻寬等,使用動態規劃演算法計算分割槽。假如機器集&網路頻寬組合有兩個,則會用每個組合進行一次動態規劃演算法,最後 all_As.append(A) 這裡就是兩個動態規劃的結果,就是考慮到各種必要因素之後的最優結果。
具體程式碼如下:
def main(all_num_machines, profile_filename, network_bandwidths, memory_size,
straight_pipeline, use_memory_constraint, use_fewer_machines,
activation_compression_ratio, output_directory,
print_configuration=True, verbose=False):
gr = graph.Graph.from_str(open(profile_filename, 'r').read())
# Zero out all metadata associated with inputs in graph, since the optimizer
# shouldn't really get a choice with where to place the input (should always
# be in the first stage).
# 排除干擾,因為input必然在第一層,沒必要讓優化器再來選擇把輸入放在哪裡,所以先去除,後續會再加上。
sources = gr.sources() # 對圖的輸入進行處理
nodes_to_remove = OrderedDict()
for source in sources:
if source.node_desc.startswith("Input"): # 只處理input
source.forward_compute_time = 0.0
source.backward_compute_time = 0.0
source.activation_size = 0.0
source.parameter_size = 0.0
nodes_to_remove[source] = []
for out_node in gr.edges[source.node_id]:
nodes_to_remove[source].append(out_node) # 記錄這些刪除source對應了哪些out節點,因為後續還要處理
gr.remove_node(source) # 在圖中移除這些input source
# Remove all unneeded sinks that are not used, makes code generation and
# optimization easier.
sinks = gr.sinks() # 對圖的輸出進行處理,移除沒有用到的輸出
for sink in sinks:
if sink.node_desc.startswith("__getitem__"):
gr.remove_node(sink)
antichain_gr = gr.antichain_dag() # 得到反鏈DAG
states = antichain_gr.topological_sort() # 拓撲排序,得到一個排序好的節點列表
###########################################################################
# 之前程式碼在上節分析過,我們本節從這裡繼續分析
###########################################################################
states_indices = {} # 為每個狀態設定index
for i in range(len(states)):
states_indices[states[i]] = i
##################################### 執行時如下
#states_indices = {dict: 99}
# antichain_0 -- ['node4'] = {int} 0
# antichain_1 -- ['node5'] = {int} 1
# antichain_2 -- ['node6'] = {int} 2
# antichain_3 -- ['node7'] = {int} 3
# antichain_4 -- ['node8'] = {int} 4
# ......
# 給每個狀態計算出輸出啟用值大小,具體是通過遍歷其反鏈(增強反鏈),可以認為就是其必要前序節點給自己的輸出
for i in range(len(states)):
for antichain_node in states[i].antichain:
states[i].output_activation_size += gr.nodes[antichain_node].activation_size
# 給每個狀態計算其資訊,比如計算時間,啟用大小,引數大小等等,都是通過前置節點完成的
for i in range(len(states)):
antichain = states[i].antichain
all_predecessors = gr.all_predecessors(antichain)
states[i].compute_time = 0.0
states[i].activation_size = 0.0
states[i].parameter_size = 0.0
for predecessor in all_predecessors: # 計算所有前置節點的資訊
states[i].compute_time += ((predecessor.forward_compute_time +
predecessor.backward_compute_time) / 1000.0)
states[i].activation_size += predecessor.activation_size
states[i].parameter_size += predecessor.parameter_size
gr.reset()
# 得到總體輸出大小 & 所有前置節點id,後面計算分割槽時候需要
output_activation_sizes = [state.output_activation_size for state in states]
all_predecessor_ids = [[states_indices[predecessor] for predecessor in
antichain_gr.predecessors(states[i].node_id)]
for i in range(len(states))]
##################################### 執行時如下
# output_activation_sizes = {list: 99}
# 00 = {float} 6291456.0
# 01 = {float} 12582912.0
# 02 = {float} 12582912.0
# 03 = {float} 6553600.0
# .....
# all_predecessor_ids = {list: 99}
# 00 = {list: 0} []
# 01 = {list: 1} [0]
# 02 = {list: 2} [0, 1]
# 03 = {list: 3} [0, 1, 2]
# 04 = {list: 4} [0, 1, 2, 3]
# 05 = {list: 5} [2, 3, 4, 0, 1]
# 06 = {list: 6} [2, 3, 4, 0, 1, 5]
# 07 = {list: 7} [6, 2, 3, 4, 0, 1, 5]
# ......
compute_times = [] # 初始化計算時間
activation_sizes = [] # 初始化啟用值大小
parameter_sizes = [] # 初始化引數值大小
for i in range(len(states)+1): # 具體計算每一個節點的資訊,去除他之前節點的影響
compute_times_row = []
activation_sizes_row = []
parameter_sizes_row = []
for j in range(len(states)): # 去除之前的節點
if i == 0: # 列表中第一個節點
compute_times_row.append(states[j].compute_time) # i 到 j 的計算時間
activation_sizes_row.append(states[j].activation_size)
parameter_sizes_row.append(states[j].parameter_size)
else: # 列表中後續節點
if j > (i-1):
compute_times_row.append(states[j].compute_time -
states[i-1].compute_time) # i 到 j 的計算時間
activation_sizes_row.append(states[j].activation_size -
states[i-1].activation_size)
parameter_sizes_row.append(states[j].parameter_size -
states[i-1].parameter_size)
else:
compute_times_row.append(None)
activation_sizes_row.append(None)
parameter_sizes_row.append(None)
compute_times.append(compute_times_row) # 依據profile估計出系統內部的計算時間,compute_times_row 是 i 節點到 後續節點(i+1, i+2, ...)的計算時間,下面類似
activation_sizes.append(activation_sizes_row) # 依據profile估計出系統內部的啟用值大小
parameter_sizes.append(parameter_sizes_row) # 依據profile估計出系統內部的引數大小
##################################### 執行時如下
# compute_times = {list: 100}
# 000 = {list: 99} [0.0070220000000000005, 0.012285, 0.012558, 0.021096000000,...
# 001 = {list: 99} [None, 0.005263, 0.005535999999999999, 0.014074000000000003, ...
# 002 = {list: 99} [None, None, 0.00027299999999999894, 0.008811000000000003, ...
# 003 = {list: 99} [None, None, None, 0.008538000000000004, 0.008538, ...
# 004 = {list: 99} [None, None, None, None, -3.469446951953614e-18, 0.000191999999...
counter = 1
all_As = []
num_machines_in_machine = 1 #第一個節點就是1
# all_num_machines, network_bandwidths 是使用者在輸入中指定
# 遍歷機器集&網路頻寬組合。流水線可以是straight(數目為1)或者並行(數目為num_machines)
for num_machines, network_bandwidth in zip(all_num_machines, network_bandwidths):
print("Solving optimization problem with %d machines with inter-machine bandwidth of %.2f GB/s" % (num_machines, network_bandwidth / 10**9))
import numpy as np
print(np.array(compute_times))
# 依據目前的資訊,以及機器數量,網路頻寬等計算分割槽
A = compute_partitioning(compute_times, activation_sizes, parameter_sizes,
output_activation_sizes, all_predecessor_ids,
num_machines, num_machines_in_machine,
network_bandwidth,
final_level=(counter==len(network_bandwidths)))
num_machines_in_machine = num_machines # 因為計算完了,所以設定為本階段的機器數目
for i in range(len(compute_times)): # 遍歷機器
for j in range(len(compute_times[0])): # 後續機器
compute_times[i][j] = A[i][j][-1][0] # 記錄計算時間(本階段最後一個機器的計算時間)
counter += 1
all_As.append(A) # 新增邏輯關係,就是裡面包括了不同階段的優化邏輯
print(np.array(compute_times))
# 省略後續程式碼
其中compute_times 是一個計算時間的二維陣列,也可以認為是矩陣,具體舉例如下。
[w12,w13,w14,w15], // 第一個節點到後續節點的計算時間
[None, w23,w24,w25], // 第二個節點到後續節點的計算時間
[None, None, w34, w35], // 第三個節點到後續節點的計算時間
[None, None, None, w45], // 第四個節點到後續節點的計算時間
activation_sizes 和 parameter_sizes 與之類似。
4.2 動態規劃
4.2.1 總體思路
這裡有一些動態規劃的演算法需要分析。
分割演算法試圖減少模型的整體訓練時間。對於流水線系統,這個問題等價於最小化流水線最慢階段所花費的時間。該問題具有最優化子問題性質;在給定機器計數的情況下,使吞吐量最大化的管道由子管道組成,這些子管道分別使自己這個子管道的吞吐量最大化。因此,我們可以用動態規劃來尋找這個問題的最優解。
分割槽演算法獲取profiling步驟的輸出,並計算:
1)將層劃分為多個階段,
2)每個階段的複製因子(worker數),
3)保持訓練管道繁忙的最佳動態小批量數。
PipeDream的優化器假設機器拓撲是分層的,並且可以被組織成多個級別,如下圖所示。一個級別內的頻寬是相同的,而跨級別的頻寬是不同的。我們假設 k 級由 mk 個 k-1層元件構成 ,這些元件通過頻寬為Bk的鏈路連線。在下圖中,m2=2,m1=4。此外,我們定義m0為1。即 4 個 m0 構成一個 m1, 2個 m1 構成一個 m2。
層 0 就是綠色矩形,代表最底層的計算裝置,比如GPU,4個GPU構成了一個層1(虛線矩形,代表一個伺服器),2個層1構成了一個層2(就是下圖全部模組)。
PipeDream的優化器從最低層到最高層逐步解決動態規劃問題。直觀地說,這個過程在伺服器中找到最佳分割槽,然後使用這些分割槽在伺服器之間最優地分割模型。
4.2.2 具體分析
假設 A(j, m) 表示使用m臺機器在第1層和第j層之間的最佳管道中,最慢階段所用的時間。
我們演算法的目標是找到 A(N,M) 和相應的劃分。讓T( i → j,m) 表示跨越層 i 到 j 的單級所用的時間,此時間在m臺機器上覆制。
其中:
-
max中的左項是在此階段中所有層的總計算時間,右項是此階段中所有層的總通訊時間。
-
因為計算和通訊可以重疊,所以不需要相加,直接取最大數值。
-
由1到j的由m個機器組成的最佳流水線可以是單個階段複製m次,也可以由多個階段組成。
當最佳管道包含多個階段時,它可以被分解成一個最優的子管道(由從1到 i 的 由m − m′ 個機器組成)和後續的一個單獨階段(由i+1到j 的被 m' 個機器複製組成)。因此,利用最優子問題的性質,我們得到
其中,max中:
-
第一項是第1層和第i層之間的最優子管道(由m-m'個機器組成)的最慢階段所用的時間。
-
第二項是在層 i 和 i + 1 之間傳遞啟用和梯度所用的時間。
-
第三項是最後單個階段的時間(由 m' 個資料並行的機器組成)。
我們具體看看如何計算,假設一個圖邏輯如下:
+----------------+
+-----+ | +--------+
| +-------------> | k[m_prime] | | +-----+
| i | | | +--------->+ |
| +----+ +----------------+ | j |
+-----+ | +-------->+ |
| +----------------+ | +-----+
| | | |
+--------> | k[m-m_prime] +---------+
| |
+----------------+
在 (A [i] [k] [m-m_prime] [0], last_stage_time, output_transfer_time, input_transfer_time )之中選一個最大的:
- A [i] [k] [m-m_prime] [0] :i 到 k 之間的計算時間,是已經計算好的子問題。
- last_stage_time :last_stage_time 是 (k 到 j 的計算時間) + 傳輸時間。
- 其中compute_times[k + 1] [j] 是k 到 j 的計算時間,compute_times[k + 1] 就對應了k的輸出。
- 傳輸時間是依據k 到 j 的下一階段引數大小(parameter_sizes[k + 1 ] [j])計算得出。
- 即:last_stage_time = compute_times[k + 1] +(parameter_sizes[k + 1 ] [j])
- input_transfer_time :使用 k 的輸出啟用大小計算出來的傳輸時間(就是 j 的輸入)。
- output_transfer_time :使用 j 的輸出啟用大小計算出來的傳輸時間。
因為傳輸和計算是可以重疊的,所以可以這樣取最大數值。
最後得到的 A 就是動態規劃優化的結果,其中每一個元素 A[i][j][m]
是個三元組 (min_pipeline_time, optimal_split, optimal_num_machines)
。 A[i][j][m]
表示節點 i 到 節點 j 之間的計算結果。三元組就是 (最小流水線時間,i 到 j 之間那個最佳分割點,最優機器數目)。
大致階段如下圖所示:
+----------------+
| i |
| |
| |
+--+------+------+
| |
| +----------+
A[i][k][m+m_prime][0] | |
| |
v v
+-----------------+-------+ +----+--------+
| k[m-m_prime] | | k[m_prime] |
| | | |
last_stage_time = compute_times[k+1][j] | | | |
+ (parameter_sizes[k+1][j]) | output_activation_sizes | | |
| | | |
| | | |
+-----------------+-------+ +-----+-------+
input_transfer_time | |
| +-----------+
| |
| |
v v
+------------+------+------+
| j |
| |
| |
| |
| output_activation_sizes |
| |
+------------------+-------+
output_transfer_time |
|
|
v
具體程式碼如下:
def compute_partitioning(compute_times, activation_sizes, parameter_sizes,
output_activation_sizes, all_predecessor_ids,
num_machines, num_machines_within_machine,
bandwidth, final_level=True):
# 初始化
A = []
for i in range(len(compute_times)): # 遍歷所有節點
row_A = []
for j in range(len(compute_times[0])): # 所有後續節點(即第一個節點的所有後續節點)
row_row_A = []
for m in range(num_machines): # 機器數目
row_row_A.append((None, None, None))
row_A.append(row_row_A)
A.append(row_A)
# 得到計算時間
for i in range(len(compute_times)): # 遍歷所有節點
for j in range(i, len(compute_times[0])): # 所有後續節點
cum_compute_time = compute_times[i][j] # i --> j 的計算時間
cum_activation_size = activation_sizes[i][j] # i --> j 的啟用大小
cum_parameter_size = parameter_sizes[i][j] # i --> j 的引數大小
max_m = 1 if straight_pipeline else num_machines # 線性還是並行流水線
for m in range(max_m): # 遍歷流水線下一階段的機器
# 儲存的資料大小
stashed_data_size = math.ceil((num_machines - (m+1)) / (m+1)) * \
(cum_activation_size + cum_parameter_size)
# memory_size 是使用者傳進來的引數,就是每個機器有效的記憶體
# use_memory_constraint 也是使用者傳進來的引數,就是使用的記憶體限制
if use_memory_constraint and stashed_data_size > memory_size:
continue
# 資料並行通訊時間依據引數尺寸,頻寬,下一階段機器數量計算
data_parallel_communication_time = (4 * m * cum_parameter_size) / (bandwidth * (m+1))
# 除以本階段機器數量,如果本階段機器多,當然就是分開計算了
data_parallel_communication_time /= num_machines_within_machine
if cum_compute_time is None:
# 需要計算下一階段中,每個機器的計算時間,所以還要除以(m+1)
A[i][j][m] = (None, None, None) # 直接賦值
else:
# 三元組,分別是[(計算時間 + 通訊時間), None,(m+1)],對應的意義是 min_pipeline_time, optimal_split, optimal_num_machines,就對應了前面的公式 2
A[i][j][m] = (sum([cum_compute_time,
data_parallel_communication_time]) / (m+1), None, (m+1))
# 需要得到最小計算時間
min_machines = 1
max_i = len(compute_times) if not final_level else 1
for i in range(max_i): # 遍歷節點
for m in range(min_machines, num_machines): # 遍歷下一階段機器的可能選擇
for j in range(i+1, len(compute_times[0])): # 遍歷 i 的後續節點
(min_pipeline_time, optimal_split, optimal_num_machines) = A[i][j][m]
if use_fewer_machines and m > 0 and ( # 如果設定了用盡量少的機器,則如果小於min_pipeline_time,就設定新的 min_pipeline_time
min_pipeline_time is None or A[i][j][m-1][0] < min_pipeline_time):
(min_pipeline_time, optimal_split, optimal_num_machines) = A[i][j][m-1]
# 遍歷 j 節點的前置機器 k,注意,j 是 i 的後續節點之一
# 就是在 i --> k --> j 之間找到一個計算時間最小的,其中A[i][k][m-m_prime][0]已經是一個最優子問題了
for k in all_predecessor_ids[j]:
# 如果k已經在之前計算過了,就跳過
if i > 0 and k in all_predecessor_ids[i-1]:
continue
# 設定質數
max_m_prime = 2 if straight_pipeline else (m+1)
for m_prime in range(1, max_m_prime): # prime就是看看如何分割
# 輸入傳輸時間 input_transfer_time 使用 k 的輸出啟用尺寸計算
input_transfer_time = (2.0 * output_activation_sizes[k]) / \
(bandwidth * m_prime)
# 輸出傳輸時間 output_transfer_time 使用 j 的輸出啟用尺寸計算
output_transfer_time = None
if j < len(output_activation_sizes) -1:
output_transfer_time = (2.0 *
output_activation_sizes[j]) / (bandwidth * m_prime)
# last_stage_time 設定為 k 到 j 的計算時間, compute_times[k+1] 就對應了k的輸出
last_stage_time = compute_times[k+1][j]
if last_stage_time is None:
continue
# 設定為 k 到 j 的下一階段引數尺寸
last_stage_parameter_size = parameter_sizes[k+1][j]
# 設定為 k 到 j 的儲存資料尺寸
stashed_data_size = (activation_sizes[k+1][j]) + last_stage_parameter_size
# 依據機器資料計算
stashed_data_size *= math.ceil((num_machines - (m+1)) / m_prime)
# 超過機器記憶體就跳過
if use_memory_constraint and stashed_data_size > memory_size:
continue
# 加上傳輸時間,所以 last_stage_time 是 (k 到 j 的計算時間) + 傳輸時間
last_stage_time = sum([last_stage_time,
((4 * (m_prime - 1) *
last_stage_parameter_size) / (bandwidth * m_prime))])
last_stage_time /= m_prime
# 如果從i到k沒有邊,則跳過
if A[i][k][m-m_prime][0] is None:
continue
# 如果i到k已經有計算時間,則選一個較大的
pipeline_time = max(A[i][k][m-m_prime][0], last_stage_time)
if activation_compression_ratio is not None: # 如果壓縮
# 在(A[i][k][m-m_prime][0], last_stage_time, output_transfer_time, input_transfer_time 之中選一個最大的)
input_transfer_time /= activation_compression_ratio
# output_transfer_time 也壓縮
if output_transfer_time is not None:
output_transfer_time /= activation_compression_ratio
# 選一個大的
pipeline_time = max(pipeline_time, input_transfer_time)
if output_transfer_time is not None:
pipeline_time = max(pipeline_time, output_transfer_time)
# 如果比min_pipeline_time小,則設定 min_pipeline_time,為了下一次迴圈
if min_pipeline_time is None or min_pipeline_time > pipeline_time:
optimal_split = (k, m-m_prime) # 選一個優化分割點
optimal_num_machines = m_prime
min_pipeline_time = pipeline_time
# 設定
A[i][j][m] = (min_pipeline_time, optimal_split, optimal_num_machines)
return A
all_As 就是動態規劃的結果,示例如下:
all_As = {list: 2}
0 = {list: 100}
000 = {list: 99}
00 = {list: 5} [(0.0070220000000000005, None, 1), (0.1689894, None, 2), (0.14943257777777777, None, 3), (0.1258643, None, 4), (0.107310576, None, 5)]
01 = {list: 5} [(0.012285, None, 1), (0.0070220000000000005, (0, 0), 1), (0.0865995, (0, 0), 2), (0.07639255555555556, (0, 0), 3), (0.06429175000000001, (0, 0), 4)]
02 = {list: 5} [(0.012558, None, 1), (0.0070220000000000005, (0, 0), 1), (0.0070220000000000005, (1, 1), 1), (0.0070220000000000005, (1, 1), 2), (0.0070220000000000005, (1, 1), 3)]
03 = {list: 5} [(0.021096, None, 1), (0.012285, (1, 0), 1), (0.008538, (2, 1), 1), (0.008538, (2, 2), 1), (0.008538, (2, 3), 1)]
......
__len__ = {int} 100
1 = {list: 100}
000 = {list: 99}
00 = {list: 5} [(0.107310576, None, 1), (0.080131832, None, 2), (0.05930489777777778, None, 3), (0.046685052000000005, None, 4), (0.03840710336000001, None, 5)]
01 = {list: 5} [(0.06429175000000001, None, 1), (0.072057299, None, 2), (0.05690740466666667, None, 3), (0.0460065055, None, 4), (0.03840166136, None, 5)]
02 = {list: 5} [(0.0070220000000000005, None, 1), (0.043422424, None, 2), (0.037817488, None, 3), (0.031689068, None, 4), (0.026947711359999998, None, 5)]
03 = {list: 5} [(0.008538, None, 1), (0.0419991328, (2, 0), 1), (0.043422424, (2, 1), 1), (0.0396227304, None, 4), (0.033697556608, None, 5)]
......
__len__ = {int} 100
__len__ = {int} 2
4.2.3 區別
我們接下來要分析程式碼作者兩個相似名字變數之間的區別。
activation_sizes :某個節點所有前置節點的activation_size 之和。
for predecessor in all_predecessors:
states[i].compute_time += ((predecessor.forward_compute_time +
predecessor.backward_compute_time) / 1000.0)
states[i].activation_size += predecessor.activation_size
states[i].parameter_size += predecessor.parameter_size
用來計算stashed資料大小,用來看看是否超過了節點配置的記憶體額度。
stashed_data_size = (activation_sizes[k+1][j]) + last_stage_parameter_size
stashed_data_size *= math.ceil((num_machines - (m+1)) / m_prime)
if use_memory_constraint and stashed_data_size > memory_size:
continue
output_activation_sizes : 某個節點所有增強反鏈的activation_size之和。
for i in range(len(states)):
for antichain_node in states[i].antichain:
states[i].output_activation_size += gr.nodes[antichain_node].activation_size
用來計算輸出傳播時間和輸入傳播時間。
input_transfer_time = (2.0 * output_activation_sizes[k]) / \
(bandwidth * m_prime)
output_transfer_time = None
if j < len(output_activation_sizes) -1:
output_transfer_time = (2.0 *
output_activation_sizes[j]) / (bandwidth * m_prime)
0x05 分析分割槽
5.1 main函式邏輯
前面計算分割槽只是得到了一個動態規劃優化結果,需要在analyze_partitioning之中進行分析劃分之後,賦予到各個層(stage)。
main函式接下來與計算分割槽相關的邏輯如下:
- states是反鏈DAG的結果,all_As 就是動態規劃得到的優化結果,可能是多個。
- splits 初始化時候就只有一個二元組元素:最初的劃分 (0, len(states))。
- 遍歷all_As的動態優化結果,對於每個動態優化結果,遍歷其各個邏輯關係,呼叫 analyze_partitioning 對分割槽進行分析,在splits分割中遍歷,splits會逐步更新(分割點逐步逐階段細化),analyze_partitioning 返回一個 partial_splits。
- 遍歷 partial_splits,對於每一個分割點,獲取其增強反鏈(states)的所有前置節點,給這些節點打上stage_id。這裡是從前往後遍歷,所以stage_id數值是逐步增加。
- 把圖寫到檔案之中。後續 convert_graph_to_model.py 會把這個檔案轉換成模型。
- 做分析對比。
具體程式碼如下:
def main(all_num_machines, profile_filename, network_bandwidths, memory_size,
straight_pipeline, use_memory_constraint, use_fewer_machines,
activation_compression_ratio, output_directory,
print_configuration=True, verbose=False):
gr = graph.Graph.from_str(open(profile_filename, 'r').read())
# Zero out all metadata associated with inputs in graph, since the optimizer
# shouldn't really get a choice with where to place the input (should always
# be in the first stage).
# 排除干擾,因為input必然在第一層,沒必要讓優化器再來選擇把輸入放在哪裡,所以先去除,後續會再加上。
sources = gr.sources() # 對圖的輸入進行處理
nodes_to_remove = OrderedDict()
for source in sources:
if source.node_desc.startswith("Input"): # 只處理input
source.forward_compute_time = 0.0
source.backward_compute_time = 0.0
source.activation_size = 0.0
source.parameter_size = 0.0
nodes_to_remove[source] = []
for out_node in gr.edges[source.node_id]:
nodes_to_remove[source].append(out_node) # 記錄這些刪除source對應了哪些out節點,因為後續還要處理
gr.remove_node(source) # 在圖中移除這些input source
# Remove all unneeded sinks that are not used, makes code generation and
# optimization easier.
sinks = gr.sinks() # 對圖的輸出進行處理,移除沒有用到的輸出
for sink in sinks:
if sink.node_desc.startswith("__getitem__"):
gr.remove_node(sink)
antichain_gr = gr.antichain_dag() # 得到反鏈DAG
states = antichain_gr.topological_sort() # 拓撲排序,得到一個排序好的節點列表
###########################################################################
# 計算階段
###########################################################################
states_indices = {} # 為每個狀態設定index
for i in range(len(states)):
states_indices[states[i]] = i
##################################### 執行時如下
#states_indices = {dict: 99}
# antichain_0 -- ['node4'] = {int} 0
# antichain_1 -- ['node5'] = {int} 1
# antichain_2 -- ['node6'] = {int} 2
# antichain_3 -- ['node7'] = {int} 3
# antichain_4 -- ['node8'] = {int} 4
# ......
# 給每個狀態計算出輸出啟用值大小,具體是通過遍歷其反鏈(增強反鏈),可以認為就是其必要前序節點給自己的輸出
for i in range(len(states)):
for antichain_node in states[i].antichain:
states[i].output_activation_size += gr.nodes[antichain_node].activation_size
# 給每個狀態計算其資訊,比如計算時間,啟用大小,引數大小等等,都是通過前置節點完成的
for i in range(len(states)):
antichain = states[i].antichain
all_predecessors = gr.all_predecessors(antichain)
states[i].compute_time = 0.0
states[i].activation_size = 0.0
states[i].parameter_size = 0.0
for predecessor in all_predecessors: # 計算所有前置節點的資訊
states[i].compute_time += ((predecessor.forward_compute_time +
predecessor.backward_compute_time) / 1000.0)
states[i].activation_size += predecessor.activation_size
states[i].parameter_size += predecessor.parameter_size
gr.reset()
# 得到總體輸出大小 & 所有前置節點id,後面計算分割槽時候需要
output_activation_sizes = [state.output_activation_size for state in states]
all_predecessor_ids = [[states_indices[predecessor] for predecessor in
antichain_gr.predecessors(states[i].node_id)]
for i in range(len(states))]
##################################### 執行時如下
# output_activation_sizes = {list: 99}
# 00 = {float} 6291456.0
# 01 = {float} 12582912.0
# 02 = {float} 12582912.0
# 03 = {float} 6553600.0
# .....
# all_predecessor_ids = {list: 99}
# 00 = {list: 0} []
# 01 = {list: 1} [0]
# 02 = {list: 2} [0, 1]
# 03 = {list: 3} [0, 1, 2]
# 04 = {list: 4} [0, 1, 2, 3]
# 05 = {list: 5} [2, 3, 4, 0, 1]
# 06 = {list: 6} [2, 3, 4, 0, 1, 5]
# 07 = {list: 7} [6, 2, 3, 4, 0, 1, 5]
# ......
compute_times = [] # 初始化計算時間
activation_sizes = [] # 初始化啟用值大小
parameter_sizes = [] # 初始化引數值大小
for i in range(len(states)+1): # 具體計算每一個節點的資訊,去除他之前節點的影響
compute_times_row = []
activation_sizes_row = []
parameter_sizes_row = []
for j in range(len(states)): # 去除之前的節點
if i == 0: # 列表中第一個節點
compute_times_row.append(states[j].compute_time) # i 到 j 的計算時間
activation_sizes_row.append(states[j].activation_size)
parameter_sizes_row.append(states[j].parameter_size)
else: # 列表中後續節點
if j > (i-1):
compute_times_row.append(states[j].compute_time -
states[i-1].compute_time) # i 到 j 的計算時間
activation_sizes_row.append(states[j].activation_size -
states[i-1].activation_size)
parameter_sizes_row.append(states[j].parameter_size -
states[i-1].parameter_size)
else:
compute_times_row.append(None)
activation_sizes_row.append(None)
parameter_sizes_row.append(None)
compute_times.append(compute_times_row) # 依據profile估計出系統內部的計算時間,compute_times_row 是 i 節點到 後續節點(i+1, i+2, ...)的計算時間,下面類似
activation_sizes.append(activation_sizes_row) # 依據profile估計出系統內部的啟用值大小
parameter_sizes.append(parameter_sizes_row) # 依據profile估計出系統內部的引數大小
##################################### 執行時如下
# compute_times = {list: 100}
# 000 = {list: 99} [0.0070220000000000005, 0.012285, 0.012558, 0.021096000000,...
# 001 = {list: 99} [None, 0.005263, 0.005535999999999999, 0.014074000000000003, ...
# 002 = {list: 99} [None, None, 0.00027299999999999894, 0.008811000000000003, ...
# 003 = {list: 99} [None, None, None, 0.008538000000000004, 0.008538, ...
# 004 = {list: 99} [None, None, None, None, -3.469446951953614e-18, 0.000191999999...
counter = 1
all_As = []
num_machines_in_machine = 1 #第一個節點就是1
# all_num_machines, network_bandwidths 是使用者在輸入中指定
# 遍歷機器集&網路頻寬組合。流水線可以是straight(數目為1)或者並行(數目為num_machines)
for num_machines, network_bandwidth in zip(all_num_machines, network_bandwidths):
print("Solving optimization problem with %d machines with inter-machine bandwidth of %.2f GB/s" % (num_machines, network_bandwidth / 10**9))
import numpy as np
print(np.array(compute_times))
# 依據目前的資訊,以及機器數量,網路頻寬等計算分割槽
A = compute_partitioning(compute_times, activation_sizes, parameter_sizes,
output_activation_sizes, all_predecessor_ids,
num_machines, num_machines_in_machine,
network_bandwidth,
final_level=(counter==len(network_bandwidths)))
num_machines_in_machine = num_machines # 因為計算完了,所以設定為本階段的機器數目
for i in range(len(compute_times)): # 遍歷機器
for j in range(len(compute_times[0])): # 後續機器
compute_times[i][j] = A[i][j][-1][0] # 記錄計算時間(本階段最後一個機器的計算時間)
counter += 1
all_As.append(A) # 新增邏輯關係,就是裡面包括了不同階段的優化邏輯
print(np.array(compute_times))
###########################################################################
# 我們從這裡繼續分析
###########################################################################
# 分析階段
# 在 analyze_partitioning 內部做了具體分析
# 這裡最重要的是對 gr.all_predecessors 做設定,就是設定 gr 之中每個node的stage_id,這樣就是利用stage_id把初始流水線重新劃分
splits = [(0, len(states))] # 如何分割,states是反鏈DAG的結果,所以 splits 初始化時候就只有一個二元組元素:最初的劃分 (0, len(states))
i = len(all_As) - 1 # all_As 就是動態規劃得到的優化結果
while i >= 0: # 遍歷優化的出來的各個邏輯關係
print("======================================")
print("Level %d" % (i+1))
print("======================================")
new_splits = []
stage_id = 0 # 在後續的convert_graph_to_model.py 之中會使用到
for (start, end) in splits: # 在分割中遍歷,splits會逐步更新
# 依據新的splits中的二元組重新計算
partial_splits = \
analyze_partitioning(all_As[i], states, start, end,
network_bandwidths[i], all_num_machines[i],
activation_compression_ratio,
print_configuration, verbose)
start_point = start # 起始點
for split in partial_splits: # 遍歷分析得出的節點
new_splits.append((start_point, split)) # 新增一個新的二元祖
if i == 0:
predecessors = gr.all_predecessors(states[split-1].antichain)
for predecessor in predecessors:
if predecessor.stage_id is None:
predecessor.set_stage_id(stage_id) # 設定所在階段
start_point = split # 下一個階段
stage_id += 1 # 增加所在階段
new_splits.append((start_point, end)) # 新增一個新的二元祖
if i == 0:
predecessors = gr.all_predecessors(states[end-1].antichain)
for predecessor in predecessors:
if predecessor.stage_id is None:
predecessor.set_stage_id(stage_id) # 設定所在階段
stage_id += 1 # 增加所在階段
print("Total number of stages: %d" % stage_id)
splits = new_splits # 加入新的分割
i -= 1
# 以下是為了把圖寫到檔案之中。後續convert_graph_to_model.py會把這個檔案轉換成模型
for source in nodes_to_remove: # 之前移除了input節點,現在需要加回到圖中
for out_node in nodes_to_remove[source]: # input對應的哪些輸出
source.stage_id = 0
gr.add_edge(source, out_node)
if output_directory is not None:
total_num_machines = 1
for num_machines in all_num_machines:
total_num_machines *= num_machines
gr.to_dot(os.path.join(output_directory, "gpus=%d" % total_num_machines))
gr_str = str(gr)
with open(os.path.join(output_directory, "gpus=%d.txt" % total_num_machines), 'w') as f:
f.write(gr_str)
# 以下是為了做分析對比
# 計算資料並行需要的時間,以便接下來做比較,這個時間要比動態規劃時間長。
total_time = states[-1].compute_time # 最後一個階段的計算時間,是沒有經過優化的最初計算時間
total_parameter_size = states[-1].parameter_size
data_parallel_total_time = total_time # 先賦值為最後一階段的計算時間
num_machines_in_machine = 1 # 本階段的機器數目
# 遍歷流水線上各個階段,因為沒有優化,所以就是嚴格按照使用者原始配置的流水線階段來逐一計算
for (num_machines, network_bandwidth) in zip(all_num_machines, network_bandwidths):
# 計算傳輸時間。num_machines是下一階段流水線機器數目,所以頻寬需要乘以這個數字
data_parallel_communication_time = (
(4 * (num_machines - 1) * total_parameter_size) /
(network_bandwidth * num_machines)) / num_machines_in_machine
# 總時間需要加上傳輸時間
data_parallel_total_time = sum(
[data_parallel_total_time, data_parallel_communication_time]) / num_machines
# 下個迭代中,本階段的機器數目需要設定為num_machines
num_machines_in_machine = num_machines
# 這個是用動態規劃演算法得出來的優化時間
pipeline_parallel_total_time = A[0][len(states)-1][num_machines-1][0]
# 可以看到使用者需要注意哪些資料
if verbose:
print()
print("Time taken by single-stage pipeline:", total_time)
print("Time per stage in pipeline:", pipeline_parallel_total_time)
print("Throughput increase (compared to single machine):",
total_time / pipeline_parallel_total_time)
dp_str = ",".join([str(elem) for elem in all_num_machines])
print(("[Note that single-machine and (%s)-machine DP might not fit "
"given memory constraints]") % dp_str)
print("Throughput increase of (%s)-machine DP compared to single "
"machine:" % dp_str, total_time / data_parallel_total_time)
print("Throughput increase (compared to (%s)-machine DP):" % dp_str,
data_parallel_total_time / pipeline_parallel_total_time)
return pipeline_parallel_total_time, data_parallel_total_time
5.2 分析階段
分析階段具體可以參見下面註釋。
def analyze_partitioning(A, states, start, end, network_bandwidth, num_machines,
activation_compression_ratio, print_configuration, verbose):
# start,end 是本組節點的起始點,終止點
metadata = A[start][end-1][num_machines-1] # 這是個三元組 (min_pipeline_time, optimal_split, optimal_num_machines)
next_split = metadata[1] # metadata[1] 是 optimal_split,即 (k, m-m_prime)
remaining_machines_left = num_machines
splits = []
replication_factors = []
prev_split = end - 1 # 前一個分割點
while next_split is not None: #是否繼續分割
num_machines_used = metadata[2] # optimal_num_machines
if verbose:
print("-------------------------------------")
print("Number of machines used: %d..." % num_machines_used)
print("Split between layers %d and %d..." % (next_split[0], next_split[0] + 1))
print("Split before antichain %s..." % (states[next_split[0]+1].antichain))
splits.append(next_split[0]+1) # 得到了 k + 1,這是關鍵點,因為最後返回的是splits
compute_time = states[prev_split-1].compute_time - \
states[next_split[0]].compute_time
parameter_size = states[prev_split-1].parameter_size - \
states[next_split[0]].parameter_size
dp_communication_time = (4 * (num_machines_used - 1) * parameter_size) \
/ (network_bandwidth * num_machines_used)
pp_communication_time_input = ( # 下個階段的資料輸入時間
2.0 * states[next_split[0]].output_activation_size *
(1.0 / float(num_machines_used))) / network_bandwidth
pp_communication_time_output = ( # 上個階段的資料輸出時間
2.0 * states[prev_split-1].output_activation_size *
(1.0 / float(num_machines_used))) / network_bandwidth
# 如果需要壓縮,就進行壓縮
if activation_compression_ratio is not None:
pp_communication_time_input /= activation_compression_ratio
pp_communication_time_output /= activation_compression_ratio
if activation_compression_ratio is None:
pp_communication_time_input = 0.0
pp_communication_time_output = 0.0
compute_time /= num_machines_used # 本階段計算時間
dp_communication_time /= num_machines_used # 資料並行時間
if verbose:
print(("Compute time = %f, Data-parallel communication time = %f, "
"Pipeline-parallel communication time = %f...") % (
compute_time, dp_communication_time,
max(pp_communication_time_input, pp_communication_time_output)))
prev_split = splits[-1] # 設定新的前一分割點
# next_split 格式是 (k, m-m_prime),就是 optimal_split 的格式
# A[i][j][m] 格式是 (min_pipeline_time, optimal_split, optimal_num_machines)
metadata = A[start][next_split[0]][next_split[1]]
next_split = metadata[1] # 設定新的下一次分割點,就是 optimal_split
replication_factors.append(num_machines_used) # 每個階段的 replication factor
remaining_machines_left -= num_machines_used # 剩餘機器
if verbose:
print("-------------------------------------")
print("Number of machines used: %d..." % metadata[2])
#
num_machines_used = metadata[2]
remaining_machines_left -= num_machines_used # 剩餘的機器
compute_time = states[prev_split-1].compute_time
parameter_size = states[prev_split-1].parameter_size
dp_communication_time = ((4 * (num_machines_used - 1) * parameter_size) /
(network_bandwidth * num_machines_used))
compute_time /= num_machines_used # 計算時間
dp_communication_time /= num_machines_used # 資料並行通訊時間
if verbose:
print("Compute time = %f, Data-parallel communication time = %f..." %
(compute_time, dp_communication_time))
print("-------------------------------------")
if print_configuration:
print("Number of machines in budget not used: %d..." %
remaining_machines_left)
print()
print("(Split start, split end) / compute time taken per stage "
"/ replication factor per stage:")
# 下面就是列印 (Split start, split end) / compute time taken per stage / replication factor per stage
prev_split = start
splits.reverse() #
splits.append(end)
replication_factors.append(num_machines_used)
replication_factors.reverse()
for i in range(len(splits)):
time = 0.0
if prev_split > 0:
time = states[splits[i]-1].compute_time - states[prev_split-1].compute_time
else:
time = states[splits[i]-1].compute_time
if print_configuration:
print((prev_split, splits[i]), time, replication_factors[i])
prev_split = splits[i]
if print_configuration:
print()
return splits[:-1] # 最後一個不返回
我們還是用樣例進行說明。
這裡是從後面進行分割,舉例分析如下,這裡設定了總機器數目為10:
回憶在計算分割槽之中,A[i][j][m] = (min_pipeline_time, optimal_split, optimal_num_machines),optimal_split = (k, m-m_prime)
是一個本階段優化點。
所以在本函式之中,start = 0, end = 99,所以 metadata 為A[0][99][10]
,即 (0.01903199999999998, (95, 8), 1),next_split = (95, 8),prev_split = end - 1 = 98。
next_split 就是下一個分割點,splits 是目前的分割序列。
第一輪while迴圈:
因為next_split = (95, 8),所以 splits = append(next_split[0]+1) = [96],因此計算 states[prev_split-1] - states[next_split[0]] = state[97] - state[95]。這樣把0~99分成了 0 ~95 和 96 ~ 99。
然後 prev_split = 96,去找A[ 0 ] [ 95] [8] 得到 meta = (0.019031999999999993, (78, 7), 1),next_split = (78, 7)。
所以下一輪從78這個分割點開始分割。
第二輪while迴圈:
因為next_split = (78, 7),所以 splits = [96, 79],這就是新的分割序列。,因此計算 states[96-1] - states[next_split[0]] = state[96] - state[78]。這樣就使用 splits = [96, 79] 把0~99分成了 0 ~78,79 ~ 95 和 96 ~ 99。
然後 prev_split =79,去找A[ 0 ] [ 78 ] [ 7 ] 得到 meta = (0.011081, (48, 6), 1),next_split = (48, 6)。
所以下一輪從 48 這個分割點開始分割,以此類推。
while迴圈之後,得到 splits = [96, 79, 49, 15, 12, 7, 5, 3, 1]。
於是下面程式碼需要把順序調整過來。
prev_split = start
splits.reverse()
splits.append(end)
replication_factors.append(num_machines_used)
replication_factors.reverse()
得到:splits = { 1,3,5,7,12,15,49,79,96 }。然後加上 end = 99。
最後返回 splits[:-1],即返回 { 1,3,5,7,12,15,49,79,96 },去掉剛剛新增的end。
而依據 { 1,3,5,7,12,15,49,79,96 } 得到的最終分割序列 是 [(0, 1), (1, 3), (3, 5), (5, 7), (7, 12), (12, 15), (15, 49), (49, 79), (79, 96), (96, 99)],這個列表會在後續"設定stage"之中會用到。
5.3 設定stage
目前我們得到了一個理想分割序列,但是事情沒有結束,我們回憶一下分割槽演算法的目的:依據profile結果確定所有層的執行時間,然後使用動態規劃對模型進行劃分,將模型劃分為不同的stage,以及得到每個stage的replication數。
所以,分析的最終目的是給模型的每一個子層分配一個stage,如果某些子層屬於同一個stage,這些子層最終就被分配到同一個worker(節點)上執行。
因為這裡涉及到多個子網,所以我們依然用例項來分析。
如果分成了兩個子網,假設:
all_num_machines = [5,5]
network_bandwidths = [800000000, 1000000000]
初始化 splits = [0,99]。
第一輪 while 中,i = 1,
對於 splits 結果[(0, 99)] 遍歷,每一段應用analyze_partitioning,得到 partial_splits 為 [3, 6, 30, 75, 99]。
最後,splits 更新為:[(0, 3), (3, 6), (6, 30), (30, 75), (75, 99)]。
此時不會設定stage_id。
第二輪 while 中,i = 0,
對於第一輪的 splits 結果 [(0, 3), (3, 6), (6, 30), (30, 75), (75, 99)] 進行遍歷,對於這裡的每一段也應用 analyze_partitioning,比如對 (0,3) 應用analyze_partitioning,對 (3,6) 應用 analyze_partitioning,對(6,30) 也應用 analyze_partitioning,......,最後得到新的 partial_splits 為 [1, 2, 3, 4, 5, 6, 8, 10, 13, 28, 30, 45, 49, 51, 75, 79, 96, 99]。
最後,splits 更新為:[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 8), (8, 10), (10, 13), (13, 28), (28, 30), (30, 45), (45, 49), (49, 51), (51, 75), (75, 79), (79, 96), (96, 99)]。
這個列表就是理想分割序列。
在此輪中,得到了partial_splits之後,會遍歷 for split in partial_splits:
然後對於每一個 split,利用
states[split-1].antichain
獲取其增強反鏈的所有前置節點,給這些節點打上 split 對應的 stage_id。
回憶一下增強反鏈的意義:
- 每個節點的增強反鏈包括:本身節點 + 部分前序節點。
- 對於增強反鏈概念,可以理解為:對於節點 A,他只有把節點 Z 一起考慮,才能唯一確定自己節點的執行時間。
所以,對於 split = 1,1 - 1 = 0,於是就得到 states[0].antichain
,就是 'node4',那麼 'node4' 自己被打上了一個stage_id=0,說明 'node4' 被分到了一個 "與stage_id=0 所對應" 的 worker 節點上訓練。
如果有疑問,我們回憶一下state如何構建,就是有序的 "節點組合"。
antichain_gr = gr.antichain_dag()
states = antichain_gr.topological_sort()
具體如下。
states = {list: 99}
00 = {AntichainNode} antichain_0 -- ['node4'] # states[0].antichain
01 = {AntichainNode} antichain_1 -- ['node5']
02 = {AntichainNode} antichain_2 -- ['node6']
03 = {AntichainNode} antichain_3 -- ['node7']
04 = {AntichainNode} antichain_4 -- ['node8']
05 = {AntichainNode} antichain_5 -- ['node8', 'node10']
06 = {AntichainNode} antichain_7 -- ['node8', 'node11']
07 = {AntichainNode} antichain_10 -- ['node8', 'node12']
08 = {AntichainNode} antichain_6 -- ['node14']
09 = {AntichainNode} antichain_8 -- ['node14', 'node15']
10 = {AntichainNode} antichain_11 -- ['node14', 'node16']
11 = {AntichainNode} antichain_13 -- ['node14', 'node17']
12 = {AntichainNode} antichain_9 -- ['node19']
13 = {AntichainNode} antichain_12 -- ['node20', 'node23']
14 = {AntichainNode} antichain_18 -- ['node23', 'node20', 'node26']
15 = {AntichainNode} antichain_17 -- ['node23', 'node20', 'node24']
16 = {AntichainNode} antichain_32 -- ['node23', 'node20', 'node28']
17 = {AntichainNode} antichain_31 -- ['node23', 'node20', 'node26', 'node24']
18 = {AntichainNode} antichain_63 -- ['node23', 'node20', 'node26', 'node28']
19 = {AntichainNode} antichain_33 -- ['node20', 'node26', 'node29']
20 = {AntichainNode} antichain_16 -- ['node20', 'node43', 'node23']
21 = {AntichainNode} antichain_30 -- ['node23', 'node20', 'node43', 'node26']
22 = {AntichainNode} antichain_29 -- ['node23', 'node20', 'node43', 'node24']
23 = {AntichainNode} antichain_59 -- ['node23', 'node20', 'node43', 'node28']
設定stage 具體程式碼如下:
splits = [(0, len(states))]
i = len(all_As) - 1
while i >= 0:
new_splits = []
stage_id = 0
for (start, end) in splits:
partial_splits = \
analyze_partitioning(all_As[i], states, start, end,
network_bandwidths[i], all_num_machines[i],
activation_compression_ratio,
print_configuration, verbose)
start_point = start
for split in partial_splits: # 遍歷這個偏序列表
new_splits.append((start_point, split))
if i == 0: # 最終的while
# 針對每個節點,找到每個節點的所有反鏈
predecessors = gr.all_predecessors(states[split-1].antichain)
for predecessor in predecessors:
if predecessor.stage_id is None:
predecessor.set_stage_id(stage_id) # 打上stage id
start_point = split
stage_id += 1
new_splits.append((start_point, end))
if i == 0: # 最終的while
predecessors = gr.all_predecessors(states[end-1].antichain)
for predecessor in predecessors:
if predecessor.stage_id is None:
predecessor.set_stage_id(stage_id) # 打上stage id
stage_id += 1
splits = new_splits
i -= 1
5.4 總結
我們總結一下計算分割槽和分析分割槽所做的工作:
-
反鏈DAG圖已經被分割成若干狀態(states),每個狀態很重要的一個屬性是其增強反鏈。states 就是對增強反鏈進行拓撲排序之後的結果,按照這個順序進行訓練是符合邏輯的。
-
compute_partitioning 是使用動態規劃演算法對於這些 states 狀態得出一個最優化結果,但是這個計算分割槽只是得到了一個動態規劃優化結果,需要在analyze_partitioning之中進行分析劃分之後,賦予到各個層(stage)。
-
analyze_partitioning 是利用動態規劃演算法的最優化結果來做具體分割槽,排序後得到了一個偏序結果,就是理想分割序列。
-
依據 analyze_partitioning 的結果,給模型的每一個子層分配一個stage,如果某些子層屬於同一個stage,這些子層最終就被分配到同一個worker(節點)上執行。
0x06 輸出
輸出檔案如下(摘錄部分),可以看到,關鍵之處在於給每一個節點加上了stage,具體如何使用我們將在下一篇進行分析。比如:
stage_id=0 對應的是 node4。
stage_id=1 對應的是 node5,node6。
stage_id=2 對應的是 node7。
stage_id=3 對應的是 node8,node10,node11,node12。
......
具體如下:
node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000 -- stage_id=0
node5 -- EmuBidirLSTM( (bidir): LSTM(1024, 1024, bidirectional=True) (layer1): LSTM(1024, 1024) (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000 -- stage_id=1
node6 -- Dropout(p=0.2) -- forward_compute_time=0.077, backward_compute_time=0.196, activation_size=12582912.0, parameter_size=0.000 -- stage_id=1
node7 -- LSTM(2048, 1024) -- forward_compute_time=3.190, backward_compute_time=5.348, activation_size=6553600.0, parameter_size=50364416.000 -- stage_id=2
node8 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=3
node10 -- Dropout(p=0.2) -- forward_compute_time=0.064, backward_compute_time=0.128, activation_size=6291456.0, parameter_size=0.000 -- stage_id=3
node11 -- LSTM(1024, 1024) -- forward_compute_time=2.491, backward_compute_time=4.203, activation_size=6553600.0, parameter_size=33587200.000 -- stage_id=3
node12 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=3
node14 -- Add -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=4
node15 -- Dropout(p=0.2) -- forward_compute_time=0.059, backward_compute_time=0.121, activation_size=6291456.0, parameter_size=0.000 -- stage_id=4
node16 -- LSTM(1024, 1024) -- forward_compute_time=2.492, backward_compute_time=4.201, activation_size=6553600.0, parameter_size=33587200.000 -- stage_id=4
node17 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=5
node19 -- Add -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=5
node1 -- node4
node4 -- node5
node2 -- node5
node5 -- node6
node6 -- node7
node7 -- node8
node8 -- node10
node10 -- node11
node11 -- node12
node12 -- node14
node8 -- node14
node14 -- node15
node15 -- node16
node16 -- node17
node17 -- node19