[原始碼解析] 快手八卦 --- 機器學習分散式訓練新思路(1)

羅西的思考發表於2022-01-04

[原始碼解析] 快手八卦 --- 機器學習分散式訓練新思路(1)

0x00 摘要

“Bagua“ 是快手和蘇黎世理工(ETH Zürich)聯合開發的分散式訓練框架。其專門針對分散式的場景設計特定的優化演算法,實現演算法和系統層面的聯合優化,力圖極致化分散式訓練的效率。其特點是:

  • 並行效能顯著提高;

  • 對網路環境更魯棒;

  • “一鍵式”使用;

  • 分散式通訊演算法易擴充性;

  • 可用於工業級場景大規模使用;

  • 安全、故障易排查;

本文以:

為基礎來分析學習。本文學習“bagua"總體設計思路和負載均衡資料載入器。

0x01 設計思路

以下摘錄於快手官方帖子 快手八卦!突破 TensorFlow、PyTorch 並行瓶頸的開源分散式訓練框架來了! 和 ETH PPT,按照自己理解有調整。

1.1 如何通訊

在資料並行之中,從單機單卡的訓練到多機多卡訓練的核心,是每個卡把自己的計算結果進行累加和傳播,所以一個關鍵點是兩個worker之間如何進行通訊。

這個過程好比每個人把自己知道的資訊傳遞給他人,然後又從其他人那裡獲取資訊,最後完成全域性的資訊同步。如果把計算單元之間的資訊同步類比為人與人之間的資訊同步,那麼社會實踐經驗告訴我們,“八卦”可能是訊息傳遞最高效的模式。“八卦”訊息傳播具有去中心化、非同步通訊、資訊壓縮的特點,這與 Bagua 裡面實現的通訊演算法剛好一一呼應。

1.2 通訊模式分類

針對通訊模式,有如下分類。

1.2.1 系統架構

按照系統架構來區分,是引數伺服器和Allreduce。

下圖是引數伺服器和Allreduce正規化的圖例。

  • 引數伺服器架構中,模型可以被分割成分片(shard)並分佈到多個節點(我們稱這些節點為 "引數伺服器")。在訓練階段,worker定期從引數伺服器獲取模型,利用計算單元(如GPU)進行前向和後向傳播,並將梯度推送給引數伺服器,而引數伺服器彙總梯度並更新引數。
  • Allreduce正規化之中,所有worker都與他們的鄰居合作進行模型/梯度交換。現有的系統通常採用環形拓撲結構進行兩階段的交流:首先,正規化將模型/梯度劃分為n個塊(其中n為節點數),並使用不同起點和終點的n個環來聚合n個塊;其次,位於不同節點的每個塊的聚合結果會在環內進行廣播。

1.2.2 同步角度

從通訊同步角度看可以分為同步或是非同步(Synchronous or Asynchronous):

  • 同步模式中,在每一次迭代過程中,所有工作節點都需要進行通訊,並且下一步迭代必須等待當前迭代的通訊完成才能開始。
  • 反之,非同步式分佈演算法 則不需要等待時間:當某個節點完成計算後就可直接傳遞本地梯度,進行模型更新。

1.2.3 通訊拓撲

從通訊拓撲角度看可以分成中心化或是去中心化(Centralized or Decentralized):

  • 在中心化的通訊模式中,梯度或模型的同步過程需要所有的工作節點進行參與,因此,較高的網路延時往往會導致訓練效率的降低。
  • 去中心化的通訊模式往往可以有效的解決上述問題:在該模式下,工作節點可以被連線成特定的拓撲結構(例如環),在通訊過程中,每一個工作節點只與和它相鄰的節點進行通訊。

1.2.4 壓縮

從通訊壓縮與否角度看,有完整精度模式或資訊壓縮模式(Full-Precision or Low-Precision)兩種:

  • 完整精度模式會使用與本地模型相同的 32 位浮點數(float32)進行傳輸。
  • 另一方面,在通訊存在瓶頸的情況下,基於大量已有研究通過量化 (quantization) 或稀疏化 (sparsification) 等方法壓縮梯度,再用壓縮後的梯度更新引數。在很多場景下,可以達到和完整精度相同的精度,同時提升通訊效率。

1.3 挑戰

快手在實現之中,遇到了三個挑戰:

  • 理論基礎:通訊模式需要有理論的支撐,需要嚴格在理論上證明通訊是有效的,收斂的。
  • 系統設計:現有分散式學習系統都無法滿足所有的新的通訊模式,所以需要設計新的系統結構,才能利用這種演算法帶來的優勢。
    • 引數伺服器基本操作put/get,無法實現去中心化和誤差補償。
    • Allreduce是全域性性的,無法實現去中心化或者非同步模式。
  • 評測:需要在大規模真實場景下對各種演算法進行評測。

1.4 Bagua 實現

1.4.1 分層

Bagua 具體分為三層:

  • 演算法層:在邏輯層基礎之上,實現了具體演算法,比如某一個演算法是去中心化,壓縮,非同步的。
  • 邏輯通訊層:在物理通訊層基礎之上,實現了多種通訊原語,比如去中心化,精度,同步等等,這些通訊原語不是針對某一類演算法特殊設計的,而對上層是統一的。
  • 物理通訊層:在此層整合了一些常見通訊庫,從而提供了基本的send,receive操作。

1.4.2 通訊演算法選項

針對通訊模式分類,Bagua 相應將通訊過程抽象成了如下的演算法選項:

  • 中心化或是去中心化(Centralized or Decentralized)。

  • 同步或是非同步(Synchronous or Asynchronous)。

  • 完整精度模式或資訊壓縮模式(Full-Precision or Low-Precision)。

雖然為了提升通訊效率,Bagua 沒有依照傳統的方式同步所有計算節點的結果,甚至每次同步的資訊還有偏差,但是得益於最新理論上的進展,這幾種通訊策略以及他們的組合最終收斂解的正確性和效率仍然能得到充分保證,而且計算複雜度跟同步中心化和資訊無損的方法相當,但是通訊效率更高。

img

Bagua 提供了一套詳盡的通訊模式來支援使用者在上述模式中任意選擇組合,我們將這一分散式訓練系統對於上述演算法選項的支援情況總結在下表中:

img

從表格中不難看出,現有框架的優化只是針對較為通用的演算法(中心化同步完整精度),對於其他的演算法組合,這些系統的支援非常有限。對於中心化同步進行資訊壓縮,這些系統往往只能支援較為簡單的 float32->float16 壓縮,相較而言,Bagua 則可以支援更為複雜的 ByteGrad,QAdam 等演算法。對於其他的演算法組合,現有的框架通常無法支援,而 Bagua 則可以自由支援。

1.4.3 總體

BAGUA的核心是一個訓練演算法,由開發者使用BAGUA提供的通訊原語和抽象概念來實現。演算法將終端使用者提供的神經網路作為輸入,併為其配備一個特定於演算法的通訊功能。具體來說,演算法的開發者會在執行的不同階段將這個通訊功能註冊為鉤子。

1.4.4 優化

然而,簡單地支援演算法選項並不能直接在大規模叢集上帶來效能的提升。Bagua 的核心優勢在於,為了追求極致化的效能,而實現演算法和實現的聯合優化。具體來講,基於上述的通訊層抽象,使用者既可以方便得選擇系統提供的各種演算法組合從而獲得效能提升,又能靈活得實現新的分散式 SGD 演算法 —— Bagua 將自動為這一演算法實現提供系統層優化。這些系統優化包含:

  • 將通訊時間隱藏在計算時間中。
  • 引數分桶及其記憶體管理。
  • 分層化的通訊實現。

想要強調的是,這些系統實現層面的優化是對於各種演算法組合廣泛適用,而非侷限在某一特定的演算法設定上。因此,所有的系統優化都可以被靈活的複用到各種演算法實現中去,這在保證“端到端”的效能提升的同時,也為開發新的分散式演算法提供了良好的平臺。

1.5 流程圖

我們使用官方號的圖例做一下總結

img

0x02 分析思路

通過官方文章我們可以發現對於分析學習來說有如下情況:

  • 通訊方面的優化實現是八卦專案的一大特點。
  • 底層 Rust 語言筆者不熟悉。
  • 通盤研究整體程式碼不現實。

因此我們決定以 中心化、非同步通訊,分層化的通訊實現 為中心,再結合幾個特色實現來學習分析。本文學習負載均衡資料載入器。

0x03 Load Balanced Data Loader

在某些場景下當訓練資料中樣本的計算複雜度是不同的,比如在 NLP 和語音任務中每個樣本的長度就不同。這時,使用八卦的負載均衡資料載入器可以大大提高分散式訓練吞吐量,在這種情況下,worker 的工作負載是相似的。我們接下來就從例項入手,看看如何實現資料載入的負載均衡

我們先看看負載均衡的需求,假如我們有兩個模型副本進行資料並行,有如下資料,假如這些資料代表的是資料複雜度(會影響計算時間)

[ 7,  1, 11,  5,  10,  2,  9, 4,  6,  0,  8,  3]

那麼第一個模型副本收到的資料為:[7,11,10,9,6, 8]。第二個模型副本收到的資料為:[1,5,2,4,0,3]。可以看出來兩個模型在每個batch收到資料的複雜度不同,會造成負載不均衡。

                         +  8                         + 3
                         |                            |
                         |  6                         | 0
                         |                            |
                         |  9                         | 4
                         |                            |
batch 3   +----------->  |  10                        | 2  <----------+  batch 3
                         |                            |
batch 2   +----------->  |  11                        | 5  <----------+  batch 2
                         |                            |
batch 1   +----------->  v  7                         v 1  <----------+  batch 1

                  +-------------------+        +-------------------+
                  |                   |        |                   |
                  |     worker 0      |        |     worker 1      |
                  |                   |        |                   |
                  |                   |        |                   |
                  +-------------------+        +-------------------+

理想狀態應該是兩個模型每個batch收到的資料複雜度都相仿,比如第一個模型收到 [1,3,5,7,9],第二個模型的資料是[2,4,6,8,10],在下圖的輸入下,可以看到每次batch資料複雜度相仿,從而達到負載均衡的效果:

                         +                            +
                         |  9                         | 10
                         |                            |
                         |  7                         | 8
                         |                            |
batch 3   +----------->  |  5                         | 6  <----------+  batch 3
                         |                            |
batch 2   +----------->  |  3                         | 4  <----------+  batch 2
                         |                            |
batch 1   +----------->  v  1                         v 2  <----------+  batch 1

                  +-------------------+        +-------------------+
                  |                   |        |                   |
                  |     worker 0      |        |     worker 1      |
                  |                   |        |                   |
                  |                   |        |                   |
                  +-------------------+        +-------------------+

3.1 使用

我們直接使用原始碼中的例子修改學習一下。

import torch
from load_balancing_data_loader import LoadBalancingDistributedSampler
from torch.utils.data import TensorDataset, DataLoader

def test_load_balancing_distributed_batch_sampler():
    num_replicas = 2 # 分成兩個副本
    total_batch = 3 

    n = sum([i + 1 for i in range(total_batch)]) * num_replicas
    dataset = TensorDataset(torch.randn(n, 2), torch.randperm(n))

    sampler = LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=lambda x: x[1],
        num_replicas=num_replicas,
        rank=0,
        shuffle=True, # 需要shuffle
        random_level=0.5, # 加入隨機
    )

    dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler)

    cur_idx = 0
    for i, data in enumerate(dataloader):
        batch_size = data[0].shape[0]
        cur_idx += batch_size * num_replicas
        print(cur_idx)

test_load_balancing_distributed_batch_sampler()

因為此處程式碼十分繞,所以我們逐次解析。

3.2 生成資料集

首先是生成資料集部分。torch.randn(n, 2) 生成了隨機張量,torch.randperm(n) 生成了 n 的隨機排序。這裡假定 n 是12。

# 生成了資料集
n = sum([i + 1 for i in range(total_batch)]) * num_replicas
dataset = TensorDataset(torch.randn(n, 2), torch.randperm(n))

TensorDataset 類似 zip 命令,生成了tuple列表。

dataset = {TensorDataset: 12} 
 tensors = {tuple: 2} (
   
  0 = {Tensor: 12} tensor([[-1.5556,  0.6848],\n        [ 2.0811,  1.5011],\n        [ 0.7434, -0.4990],\n        [-0.2706,  1.7227],\n        [ 0.2179,  0.0622],\n        [-0.3014, -0.6435],\n        [-0.1773, -1.3405],\n        [-1.8212,  0.3702],\n        [-0.5526, -0.2077],\n        [-1.6543,  0.3109],\n        [ 0.3265,  0.5987],\n        [-1.5566,  0.2854]])
   
   1 = {Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3])

得出目前的TensorDataset如下 ,0 是實際資料,1 是資料複雜度,後續處理的目的就是按照資料複雜度對這些張量排序。我們可以設想下,最終排序應該就是一個複雜度均勻的排序結果。

+-----------------------------------------------------------------------------+
| TensorDataset                                                               |
|                                                                             |
|   0 = {Tensor: 12} tensor([[-1.5556,  0.6848],......                        |
|                                                                             |
|   1 = {Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3]) |
|                                                                             |
+-----------------------------------------------------------------------------+

3.3 初始化

我們來到了 LoadBalancingDistributedSampler 的初始化。

def __init__(
    self,
    dataset: Dataset,
    complexity_fn: Callable[..., int],
    num_replicas: Optional[int] = None,
    rank: Optional[int] = None,
    shuffle: bool = True,
    seed: int = 0,
    drop_last: bool = False,
    random_level: float = 0,
) -> None:
    if num_replicas is None:
        num_replicas = dist.get_world_size()
    if rank is None:
        rank = dist.get_rank()

    self.dataset = dataset
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0
    self.drop_last = drop_last

    # If the dataset length is evenly divisible by # of replicas, then there
    # is no need to drop any data, since the dataset will be split equally.
    dataset_len = len(self.dataset)  # type: ignore
    if self.drop_last and dataset_len % self.num_replicas != 0:  # type: ignore
        # Split to nearest available length that is evenly divisible.
        # This is to ensure each rank receives the same amount of data when
        # using this Sampler.
        self.num_samples = math.ceil(
            # `type:ignore` is required because Dataset cannot provide a default __len__
            # see NOTE in pytorch/torch/utils/data/sampler.py
            (dataset_len - self.num_replicas)
            / self.num_replicas
        )
    else:
        self.num_samples = math.ceil(dataset_len / self.num_replicas)  # type: ignore
    self.total_size = self.num_samples * self.num_replicas
    self.shuffle = shuffle
    self.seed = seed

""" 
此時變數為
self = {LoadBalancingDistributedSampler: 6} 
 dataset = {TensorDataset: 12} <torch.utils.data.dataset.TensorDataset object at 0x7ff7385aecf8>
 drop_last = {bool} False
 epoch = {int} 0
 num_replicas = {int} 2
 num_samples = {int} 6
 rank = {int} 0
 seed = {int} 0
 shuffle = {bool} True
 total_size = {int} 12 
"""       
    
    # 以下是與PyTorch原生的主要不同之處
    self.item_complexity_map = dict()
    for item_index in range(dataset_len):
        # 每一個item都有一個complexity
        self.item_complexity_map[item_index] = complexity_fn(
            self.dataset[item_index]
        )

"""
complexity_fn 是選取 tuple 的第二個元素作為複雜度,我們回憶一下資料集的複雜度
{Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3])

所以得到了複雜度map如下:
item_complexity_map = {dict: 12} {0: tensor(7), 1: tensor(8), 2: tensor(11), 3: tensor(4), 4: tensor(5), 5: tensor(2), 6: tensor(9), 7: tensor(10), 8: tensor(0), 9: tensor(6), 10: tensor(1), 11: tensor(3)}
 0 = {Tensor} tensor(7) # 第 0 個元素複雜度是 7
 1 = {Tensor} tensor(8) # 第 1 個元素複雜度是 8
 2 = {Tensor} tensor(11)
 3 = {Tensor} tensor(4)
 4 = {Tensor} tensor(5)
 5 = {Tensor} tensor(2)
 6 = {Tensor} tensor(9)
 7 = {Tensor} tensor(10)
 8 = {Tensor} tensor(0)
 9 = {Tensor} tensor(6)
 10 = {Tensor} tensor(1)
 11 = {Tensor} tensor(3)
"""        
        
    # 按照複雜度排序    
    self.ordered_item_complexity_map = OrderedDict(
        sorted(self.item_complexity_map.items(), key=lambda t: t[1])
    )
    
"""
排序之後如下:
ordered_item_complexity_map = {OrderedDict: 12} OrderedDict([(8, tensor(0)), (10, tensor(1)), (5, tensor(2)), (11, tensor(3)), (3, tensor(4)), (4, tensor(5)), (9, tensor(6)), (0, tensor(7)), (1, tensor(8)), (6, tensor(9)), (7, tensor(10)), (2, tensor(11))])
 8 = {Tensor} tensor(0) 第8個元素複雜度最低,是0
 10 = {Tensor} tensor(1) # 第10個元素複雜度次低,是1
 5 = {Tensor} tensor(2)
 11 = {Tensor} tensor(3)
 3 = {Tensor} tensor(4)
 4 = {Tensor} tensor(5)
 9 = {Tensor} tensor(6)
 0 = {Tensor} tensor(7)
 1 = {Tensor} tensor(8)
 6 = {Tensor} tensor(9)
 7 = {Tensor} tensor(10)
 2 = {Tensor} tensor(11)
"""    
    
    max_complexity = max(self.item_complexity_map.values()) # 11
    min_complexity = min(self.item_complexity_map.values()) # 0
    self.random_number = int((max_complexity - min_complexity) * random_level + 1) # 6
    
# random_number = {int} 1
  

擴充如下:

  • TensorDataset ,0 = ... 是實際資料,1 = ... 是資料複雜度,後續就是按照複雜度排序,而且所有排序或者打亂都沒有對原始資料進行移動,而是通過額外空間完成。
  • 初始化內部會對複雜度進行排序,
    • item_complexity_map 是得到每個元素的原始複雜度,比如 0: 7 表示第 0 個元素複雜度是 7。
    • ordered_item_complexity_map 就是排序之後的結構,其中 (8, 0) 表示第8個元素複雜度最低,是0,整個map是升序排列。

TensorDataset 的邏輯圖擴充如下,現在資料集 ordered_item_complexity_map 之中按照複雜度從低到高進行排序了。

+-----------------------------------------------------------------------------+
| TensorDataset                                                               |
|                                                                             |
|   0 = {Tensor: 12} tensor([[-1.5556,  0.6848],......                        |
|                                                                             |
|   1 = {Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3]) |
|                                                                             |
+-------------------------------------------+---------------------------------+
                                            |
                                            |
                                            v
+-------------------------------------------+------------------------------------------+
| LoadBalancingDistributedSampler.__init__                                             |
|                                                                                      |
|                                                                                      |
|  item_complexity_map = {dict: 12} {0: 7, 1: 8, 2: 11, 3: 4, 4: 5, 5: 2,              |
|                                                                                      |
|                                    6: 9, 7: 10, 8: 0, 9: 6, 10: 1, 11: 3}            |
|                                           +                                          |
|                                           |                                          |
|                                           |  sorted                                  |
|                                           |                                          |
|                                           v                                          |
|  ordered_item_complexity_map = {OrderedDict: 12} [(8, 0), (10, 1), (5, 2), (11, 3),  |
|                                                                                      |
|                    (3, 4), (4, 5), (9, 6), (0, 7), (1, 8), (6, 9), (7, 10), (2, 11)] |
|                                                                                      |
+--------------------------------------------------------------------------------------+

3.4 使用

示例程式碼之中接下來是使用資料:

dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler)

cur_idx = 0
for i, data in enumerate(dataloader):
    batch_size = data[0].shape[0]
    cur_idx += batch_size * num_replicas
    print(cur_idx)

3.4.1 獲取資料

我們接下來看看如何獲取資料,就是如何從loader拿到sample。

  • 首先會呼叫 shuffle_chunks 來打亂資料。
  • 然後得到自己rank對應的index。
def __iter__(self) -> Iterator:
    index_chunks, chunk_indices = self.shuffle_chunks() # 打亂資料
    # subsample
    indices = [index_chunks[i][self.rank] for i in chunk_indices] # 用 rank來提取資料

"""
得到資料如下:
chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3] 把 index_chunks 順序打亂,chunk_indices 是打亂之後的結果
index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]] 均勻分成兩組
indices = {list: 6} [8, 7, 6, 5, 4, 0] 得到自己rank對應的index
"""    
    return iter(indices)

3.4.2 shuffle

我們看看shuffle 具體程式碼如下,這裡最終要分成 6 = 12(資料數目) / 2( num_replicas ) 組資料。

def shuffle_chunks(self):
    def chunks_wrap_padding(lst, n):
        """Yield successive n-sized chunks from lst."""
        num_chunks = max(1, self.num_samples)
        num_elements = num_chunks * n
        current_lst = []
        for i in range(num_elements):
            current_lst.append(lst[i % len(lst)])
            if len(current_lst) == n:
                yield current_lst
                current_lst = []

    if self.shuffle: # 需要再次打亂
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        if self.random_number > 0:
            # 這裡的打亂機制很巧妙,就是隨機再生成複雜度,然後加到原先複雜度map上
            item_complexity_map = self.item_complexity_map.copy() # 原來map做個拷貝
            complexity_random_ints = torch.randint( # 新生成了一些複雜度變化值
                self.random_number, (len(item_complexity_map),), generator=g
            ).tolist()
"""
complexity_random_ints = {list: 12} [2, 3, 5, 0, 1, 3, 1, 1, 1, 3, 5, 2]

item_complexity_map = {dict: 12} {0: tensor(7), 1: tensor(8), 2: tensor(11), 3: tensor(4), 4: tensor(5), 5: tensor(2), 6: tensor(9), 7: tensor(10), 8: tensor(0), 9: tensor(6), 10: tensor(1), 11: tensor(3)}
"""
            
            # 原來複雜度map + 複雜度變化值
            for k, v in zip(item_complexity_map, complexity_random_ints):
                item_complexity_map[k] += v
"""
生成新的複雜度
item_complexity_map = {0: tensor(9), 1: tensor(11), 2: tensor(16), 3: tensor(4), 4: tensor(6), 5: tensor(5), 6: tensor(10), 7: tensor(11), 8: tensor(1), 9: tensor(9), 10: tensor(6), 11: tensor(5)}
"""
        
            # 再次對新複雜度排序
            ordered_item_complexity_map = OrderedDict(
                sorted(item_complexity_map.items(), key=lambda t: t[1])
            )

"""
ordered_item_complexity_map = {OrderedDict: 12} OrderedDict([(8, tensor(1)), (3, tensor(4)), (5, tensor(5)), (11, tensor(5)), (4, tensor(6)), (10, tensor(6)), (0, tensor(9)), (9, tensor(9)), (6, tensor(10)), (1, tensor(11)), (7, tensor(11)), (2, tensor(16))])
 8 = {Tensor} tensor(1)
 3 = {Tensor} tensor(4)
 5 = {Tensor} tensor(5)
 11 = {Tensor} tensor(5)
 4 = {Tensor} tensor(6)
 10 = {Tensor} tensor(6)
 0 = {Tensor} tensor(9)
 9 = {Tensor} tensor(9)
 6 = {Tensor} tensor(10)
 1 = {Tensor} tensor(11)
 7 = {Tensor} tensor(11)
 2 = {Tensor} tensor(16)
 __len__ = {int} 12
"""
        else:
            ordered_item_complexity_map = self.ordered_item_complexity_map

        index_chunks = list( # 按照 num_replicas 進行分片
            chunks_wrap_padding(
                list(ordered_item_complexity_map.keys()), self.num_replicas
            )
        )

"""
被均勻分配成兩組,每組中兩個元素的複雜度接近
index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]
 0 = {list: 2} [8, 3]
 1 = {list: 2} [5, 11]
 2 = {list: 2} [4, 10]
 3 = {list: 2} [0, 9]
 4 = {list: 2} [6, 1]
 5 = {list: 2} [7, 2]
 __len__ = {int} 6
"""        
        # 再次打亂 index_chunks
        chunk_indices = torch.randperm(len(index_chunks), generator=g).tolist()  # type: ignore
    
"""
chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3]
"""    
    
    else:
        index_chunks = list(
            chunks_wrap_padding(
                list(self.ordered_item_complexity_map.keys()), self.num_replicas
            )
        )
        chunk_indices = list(range(len(index_chunks)))  # type: ignore

    if not self.drop_last:
        # add extra samples to make it evenly divisible
        padding_size = self.num_samples - len(chunk_indices)
        if padding_size <= len(chunk_indices):
            chunk_indices += chunk_indices[:padding_size]
        else:
            chunk_indices += (
                chunk_indices * math.ceil(padding_size / len(chunk_indices))
            )[:padding_size]
    else:
        # remove tail of data to make it evenly divisible.
        chunk_indices = chunk_indices[: self.num_samples]
    assert len(chunk_indices) == self.num_samples
    return index_chunks, chunk_indices

總體擴充如下:

  • TensorDataset ,0 = ... 是實際資料,1 = ... 是資料複雜度,後續就是按照複雜度排序:
  • LoadBalancingDistributedSampler.__init__ 初始化內部會對複雜度進行排序,
    • item_complexity_map 是得到每個元素的複雜度,比如 0: 7 表示第 0 個元素複雜度是 7。
    • ordered_item_complexity_map 就是按照複雜度排序之後的結構,其中 (8, 0) 表示第8個元素複雜度最低,是0。
  • shuffle_chunks 內部繼續處理,這裡的打亂機制很巧妙,沒有移動資料,而是隨機再生成複雜度,然後加到原先複雜度map上,這樣就打亂了
    • complexity_random_ints 新生成了一些複雜度變化值。
    • item_complexity_map 把原來map做個拷貝。
    • item_complexity_map 繼續操作,即:新複雜度 = 原來複雜度map + 複雜度變化值。
    • ordered_item_complexity_map 對新複雜度排序。
    • 對 ordered_item_complexity_map 按照 num_replicas 進行分片,得到 index_chunks,ordered_item_complexity_map 被均勻分配成六組,每組中兩個元素的複雜度接近
    • 然後再次打亂 index_chunks,得到 chunk_indices,就是為了把index順序打亂而已。
+--------------------------------------------------------------------------------------+
| TensorDataset                                                                        |
|                                                                                      |
|   0 = {Tensor: 12} tensor([[-1.5556,  0.6848],......                                 |
|                                                                                      |
|   1 = {Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3])          |
|                                                                                      |
+-------------------------------------------+------------------------------------------+
                                            |
                                            |
                                            v
+-------------------------------------------+------------------------------------------+
| LoadBalancingDistributedSampler.__init__                                             |
|                                                                                      |
|                                                                                      |
|  item_complexity_map = {dict: 12} {0: 7, 1: 8, 2: 11, 3: 4, 4: 5, 5: 2,              |
|                                                                                      |
|                                    6: 9, 7: 10, 8: 0, 9: 6, 10: 1, 11: 3}            |
|                                           +                                          |
|                                           |                                          |
|                                           |  sorted                                  |
|                                           |                                          |
|                                           v                                          |
|  ordered_item_complexity_map = {OrderedDict: 12} [(8, 0), (10, 1), (5, 2), (11, 3),  |
|                                                                                      |
|                    (3, 4), (4, 5), (9, 6), (0, 7), (1, 8), (6, 9), (7, 10), (2, 11)] |
|                                                                                      |
+-------------------------------------------+------------------------------------------+
                                            |
                                            |
                                            v
+-------------------------------------------+------------------------------------------+
| __iter__                                                                             |
|                                                                                      |
+-------------------------------------------+------------------------------------------+
                                            |
                                            |
                                            v
+-------------------------------------------+------------------------------------------+
|                                                                                      |
| shuffle_chunks()                                                                     |
|                                                                                      |
|                                                                                      |
|   complexity_random_ints = {list: 12} [2, 3, 5, 0, 1, 3, 1, 1, 1, 3, 5, 2]           |
|                                                                                      |
|                                                                                      |
|                                                                                      |
|   item_complexity_map = {0: 9, 1: 11, 2: 16, 3: 4, 4: 6, 5: 5, 6: 10, 7: 11, 8: 1,   |
|                                                                                      |
|                                                                9: 9, 10: 6, 11: 5}   |
|                                                                                      |
|                                                                                      |
|                                                                                      |
|   ordered_item_complexity_map = {OrderedDict: 12} [(8, 1), (3, 4), (5, 5), (11, 5),  |
|                                                                                      |
|                                                    (4, 6), (10, 6), (0, 9), (9, 9),  |
|                                                                                      |
|                                                (6, 10), (1, 11), (7, 11), (2, 16)])  |
|                                                                                      |
|                                           +                                          |
|                                           |                                          |
|                                           |                                          |
|                                           v                                          |
|                                                                                      |
|     index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]      |
|                                                                                      |
|                                                                                      |
|     chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3]                                     |
|                                                                                      |
|                                                                                      |
+--------------------------------------------------------------------------------------+

3.4.3 梳理

shuffle 細化

看到這裡讀者可能有點暈,所以我們需要具體梳理一下。

ordered_item_complexity_map 就是按照複雜度排序之後的結構,其中 (8, 0) 表示第8個元素複雜度最低,是0。ordered_item_complexity_map 擁有 12個元素,按照兩個副本分配,所以 ordered_item_complexity_map 應該被均勻分配成六組,每組中兩個元素的複雜度接近

index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]] 是最終的結果,這裡[8, 3]是一組,複雜度接近,[5, 11]是一組,複雜度接近,比如結合 ordered_item_complexity_map 來看:

  • (8, 1), (3, 4) 就是說,第 8 個元素複雜度是1,第3個元素複雜度是4,所以 index 8,index 3 被分成一組。

  • (5, 5), (11, 5) 就是說,第 5 個元素複雜度是5,第11個元素複雜度是5,所以 index 5,index 11 被分成一組。

shuffle_chunks 的演示如下:

+--------------------------------------------------------------------------------------+
| shuffle_chunks                                                                       |
|                                                                                      |
|                                                                                      |
|                                      +--------------+     +---------------+          |
|   ordered_item_complexity_map = [ +--+(8, 1), (3, 4)|   +-+(5, 5), (11, 5)|          |
|                                   |  +--------------+   | +---------------+          |
|                                   |                     |                            |
|                                   |  +---------------+  | +---------------+          |
|                              +-------+(4, 6), (10, 6)|  | |(0, 9), (9, 9) +-------+  |
|                              |    |  +---------------+  | +---------------+       |  |
|                              |    |                     |                         |  |
|                              |    |  +----------------+ | +----------------+      |  |
|                              |    |  |(6, 10), (1, 11)| | |(7, 11), (2, 16)|  ]   |  |
|                              |    |  +-------------+--+ | +----------+-----+      |  |
|                              |    |                |    |            |            |  |
|                              +------------------+  +-------------+   +----+       |  |
|                                   |             |       |        |        |       |  |
|                                   |        +------------+   +---------------------+  |
|                                   |        |    |           |    |        |          |
|                                   v        v    v           v    v        v          |
|     index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]      |
|                                                                                      |
|                                      +                                               |
|                                      |                                               |
|                                      |                                               |
|                                      v                                               |
|                                                                                      |
|     chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3]                                     |
|                                                                                      |
+--------------------------------------------------------------------------------------+
二次打亂

我們結合原始資料再來分析,先回頭看看 獲取資料。

def __iter__(self) -> Iterator:
    index_chunks, chunk_indices = self.shuffle_chunks()
    # subsample
    indices = [index_chunks[i][self.rank] for i in chunk_indices]

"""
得到資料如下:
chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3] 把 index_chunks 順序打亂
index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]] 均勻分成兩組
indices = {list: 6} [8, 7, 6, 5, 4, 0] 得到自己rank對應的index
"""    
    
    assert len(indices) == self.num_samples

    return iter(indices)

原始資料為 :[ 7, 8, 11, 4, 5, 2, 9, 10, 0, 6, 1, 3],後續會按照原始資料的index 來排序

按照複雜度排序/shuffle之後,rank 0 就是 [8, 5, 4, 0, 6, 7]。rank 1 就是 [3, 11, 10, 9, 1, 2]。

rank 0 和 rank 1 的batch 是 [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]] ,兩兩一組。

但是,還需要再次打亂順序,因為目前這個batch是按照複雜度從小到大排序,這樣會影響訓練效果,所以需要打亂這個順序。所以就按照 chunk_indices [0, 5, 4, 1, 2, 3] 這個順序來打亂。

打亂之後的順序是:[[8, 3], [7, 2], [6, 1], [5, 11], [4, 10], [0, 9]]。

  • 假如本worker 是 rank 0,則會獲取 index_chunks 這六組資料中和自己對應的,得到 [8, 7, 6, 5, 4, 0]。

  • 假如本worker rank 1,則是 [3,2,1,11,10,9]。注意,這些還都是原始資料的index。

具體演示如下圖(這裡只給出 rank 0 的效果):

+--------------------------------------------------------------------------------------+
| shuffle_chunks                                                                       |
|                                                                                      |
|                                      +--------------+     +---------------+          |
|   ordered_item_complexity_map = [ +--+(8, 1), (3, 4)|   +-+(5, 5), (11, 5)|          |
|                                   |  +--------------+   | +---------------+          |
|                                   |                     |                            |
|                                   |  +---------------+  | +---------------+          |
|                               +------+(4, 6), (10, 6)|  | |(0, 9), (9, 9) +------+   |
|                               |   |  +---------------+  | +---------------+      |   |
|                               |   |                     |                        |   |
|                               |   |  +----------------+ | +----------------+     |   |
|                               |   |  |(6, 10), (1, 11)| | |(7, 11), (2, 16)|  ]  |   |
|                               |   |  +-------------+--+ | +----------+-----+     |   |
|                               |   |                |    |            |           |   |
|                               +-----------------+  +-------------+   +----+      |   |
|                                   |             |       |        |        |      |   |
|                                   |        +------------+   +--------------------+   |
|                                   |        |    |           |    |        |          |
|                                   v        v    v           v    v        v          |
|     index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]      |
|                                      +                                               |
|                                      |                                               |
|                                      |                                               |
|                                      v                                               |
|     chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3]                                     |
|                                                                                      |
+--------------------------------------+-----------------------------------------------+
                                       |
                                       |
                                       v

+--------------------------------------------------------------------------------------+
| __iter__                                                                             |
|                                    0       1        2        3       4       5       |
|        index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]   |
|                                   +       +        +        +       +       +        |
|                                   |       |        |        |       |       |        |
|                                   +----+  +-----+  |  +-----+       |       |        |
|                                        |        |  |  |             |       |        |
|                                        |        |  |  |             |       |        |
|                                        v        v  v  v             |       |        |
|                   indices = {list: 6} [8, 7, 6, 5, 4, 0]            |       |        |
|                                           ^  ^                      |       |        |
|                                           |  |                      |       |        |
|                                           |  +----------------------+       |        |
|                                           |                                 |        |
|                                           +---------------------------------+        |
|                                                                                      |
+--------------------------------------------------------------------------------------+
最終效果

我們看看最終效果是什麼:

  • 原始資料為 :[ 7, 8, 11, 4, 5, 2, 9, 10, 0, 6, 1, 3]。

  • 最終shuffle/二次打亂之後的資料為:rank 0 是 [8, 7, 6, 5, 4, 0],rank 1 則是 [3,2,1,11,10,9]。這裡數值是原始資料的index。

  • 最終結果是:

    • batch如下,rank 0 和 rank 1 的batch 是 [[8, 3], [7, 2], [6, 1], [5, 11], [4, 10], [0, 9]],兩兩一組。這裡數值是原始資料的index。
    • rank 0 的資料是 [0, 10, 9, 2, 5, 7],rank 1的資料是[4, 11, 7, 3, 1, 6],這裡數值就是原始資料的數值了。

具體如下圖,可以看到,因為過程之中引入了隨機值,所以不是理想均衡狀態,但已經比較均衡了:

                         + 7                          + 6
                         |                            |
                         | 5                          | 1
                         |                            |
                         | 2                          | 3
                         |                            |
batch 3   +----------->  | 9                          | 7  <----------+  batch 3
                         |                            |
batch 2   +----------->  | 10                         | 11 <----------+  batch 2
                         |                            |
batch 1   +----------->  v 0                          v 4  <----------+  batch 1

                  +-------------------+        +-------------------+
                  |                   |        |                   |
                  |     worker 0      |        |     worker 1      |
                  |                   |        |                   |
                  |                   |        |                   |
                  +-------------------+        +-------------------+

0xFF 參考

PyTorch internals

快手八卦!突破 TensorFlow、PyTorch 並行瓶頸的開源分散式訓練框架來了!

https://arxiv.org/pdf/2107.01499.pdf

[1] Dean, Jeffrey, Greg S. Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Quoc V. Le, Mark Z. Mao et al. “Large scale distributed deep networks.” (2012).

[2] Zhengyuan Zhou, Panayotis Mertikopoulos, Nicholas Bambos, Peter Glynn, Yinyu Ye, Li-Jia Li, and Li Fei-Fei. 2018. Distributed asynchronous optimization with unbounded delays: How slow can you go?. In International Conference on Machine Learning. PMLR, 5970–5979.

[3] DanAlistarh, DemjanGrubic, JerryLi, RyotaTomioka, and MilanVojnovic. 2016. QSGD: Communication-efficient SGD via gradient quantization and encoding. arXiv preprint arXiv:1610.02132 (2016).

[4] Dan Alistarh, Torsten Hoefler, Mikael Johansson, Sarit Khirirat, Nikola Konstanti- nov, and Cédric Renggli. 2018. The convergence of sparsified gradient methods. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 5977–5987.

[5] Anastasia Koloskova, Sebastian Stich, and Martin Jaggi. 2019. Decentralized stochastic optimization and gossip algorithms with compressed communication. In International Conference on Machine Learning. PMLR, 3478–3487.

[6] Xiangru Lian, Ce Zhang, Huan Zhang, Cho-Jui Hsieh, Wei Zhang, and Ji Liu. 2017. Can decentralized algorithms outperform centralized algorithms? a case study for decentralized parallel stochastic gradient descent. In Proceedings of the 31st International Conference on Neural Information Processing Systems. 5336–5346.

[7] Christopher De Sa, Matthew Feldman, Christopher Ré, and Kunle Olukotun. 2017. Understanding and optimizing asynchronous low-precision stochastic gradient descent. In Proceedings of the 44th Annual International Symposium on Computer Architecture. 561–574.

[8] Xiangru Lian, Wei Zhang, Ce Zhang, and Ji Liu. 2018. Asynchronous decentral- ized parallel stochastic gradient descent. In International Conference on Machine Learning. PMLR, 3043–3052.

[9] Hanlin Tang, Shaoduo Gan, Ce Zhang, Tong Zhang, and Ji Liu. 2018. Com- munication compression for decentralized training. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 7663–7673.

[10] Ji Liu, Ce Zhang, et al. 2020. Distributed Learning Systems with First-Order Methods. Foundations and Trends® in Databases 9, 1 (2020), 1–100.![]

相關文章