[原始碼解析] 快手八卦 --- 機器學習分散式訓練新思路(3)

羅西的思考發表於2022-01-06

[原始碼解析] 快手八卦 --- 機器學習分散式訓練新思路(3)

0x00 摘要

“Bagua“ 是快手和蘇黎世理工(ETH Zürich)聯合開發的分散式訓練框架。其專門針對分散式的場景設計特定的優化演算法,實現演算法和系統層面的聯合優化,力圖極致化分散式訓練的效率。其特點是:

  • 並行效能顯著提高;

  • 對網路環境更魯棒;

  • “一鍵式”使用;

  • 分散式通訊演算法易擴充性;

  • 可用於工業級場景大規模使用;

  • 安全、故障易排查;

本文以:

為基礎來分析學習。本文介紹去中心化和非同步通訊。

本系列前兩篇連結為:

[原始碼解析] 快手八卦 --- 機器學習分散式訓練新思路(1)

[原始碼解析] 快手八卦 --- 機器學習分散式訓練新思路(2)

0x02 去中心化

官方文章中是這樣介紹其設計思路的:

  • 中心化或是去中心化(Centralized or Decentralized):在中心化的通訊模式中,梯度或模型的同步過程需要所有的工作節點進行參與,因此,較高的網路延時往往會導致訓練效率的降低。去中心化的通訊模式 往往可以有效的解決這一問題:在該模式下,工作節點可以被連線成特定的拓撲結構(例如環),在通訊過程中,每一個工作節點只與和它相鄰的節點進行通訊。

以下結合 https://tutorials.baguasys.com/algorithms/decentralized 來學習。

2.1 示例用法

使用者可以在原始碼之中找到執行去中心化 SGD 的完整示例,這裡只是簡單介紹。

您需要初始化八卦演算法:

from bagua.torch_api.algorithms import decentralized
algorithm = decentralized.DecentralizedAlgorithm()

然後用以下方法裝飾您的模型:

model = model.with_bagua([optimizer], algorithm)

2.2 去中心化培訓概述

Decentralized SGD 是一種資料並行的分散式學習演算法,它消除了所有 worker 之間必有存在一個集中式全域性模型的需求,這使得它在通訊模式上與基於 Allreduce 或基於引數伺服器的演算法有很大不同。使用去中心化 SGD,每個 worker 只需要與一個或幾個特定的 worker 交換資料,而不是全域性聚合資料。因此,去中心化通訊的通訊連線數比 Allreduce 少得多,通訊開銷比 Parameter Server 更均衡。儘管去中心化 SGD 可能會導致每個 worker 的模型不同,但理論上已經證明,去中心化 SGD 演算法的收斂速度與其對應中心化版本相同。

2.3 去中心化訓練演算法

目前,不時有許多去中心化訓練演算法被提出。這些令人驚歎的工作集中在去中心化訓練的不同方面,如對等選擇(peer selection)、資料壓縮、非同步等,並提供了許多遠見。到目前為止,八卦已經結合了兩種基本的去中心化演算法,即去中心化 SGD和 低精度去中心化 SGD。憑藉八卦對去中心化的自動系統支援,我們預計在不久的將來會實現越來越多的去中心化演算法。

2.4 Decentralized SGD

現在我們將描述在八卦中實現的 Decentralized SGD 演算法。讓我們假設worker 的數量是 n,worker上的模型引數 是:

\[x^{(i)} ,i∈ \{0,...,n−1\} \]

每個工作人員都能夠直接從任何其他工作人員傳送或接收資料。在每次迭代 t 中,演算法重複以下步驟:

  1. 迭代t 之中,每個worker 計算本地梯度 \(g^{(t)}_t\)

  2. 將本地模型與其選定的對等模型做平均:

    \[x_{t+\frac{1}{2}}^{(i)} = \frac{x^{(i)}_{t} + x_t^{(j)}}{2} \]

  3. 用區域性梯度更新平均模型

    \[X^{(i)}_{t+1} = X^{(i)}_{t+\frac{1}{2}} - γg_t^{(i)} \]

在第 2 步中,我們採用一種策略為每次迭代中的每個 worker 選擇一個 peer,這樣所有 worker 都正確配對並且資料交換是有效的,因為每個 worker 可以在迭代之間與不同的 peer 交換資料。簡而言之,我們的策略將工作人員平均分成兩組,並在兩組之間動態配對 worker,每次迭代都不同。

2.5 通訊開銷

去中心化 SGD 的通訊開銷與網路程度(degree of network)高度相關,即一個 worker 與其他 worker 的連線數。不同的拓撲或策略會導致不同程度的網路。很明顯,我們之前描述的Decentralized SGD演算法的網路度為1。因此,在每次迭代中,一個worker只需要與一個worker建立一個連線來交換模型大小1倍的資料。我們比較了不同通訊模式在最繁忙節點延遲和頻寬方面的通訊複雜性。

演算法 延遲複雜度 頻寬複雜度
Allreduce(環) O(n) O(1)
引數伺服器 O(1) O(n)
八卦的Decentralized SGD O(1) O(1)

2.6 分析

前面官方教程之中,這部分是關鍵:

在第 2 步中,我們採用一種策略為每次迭代中的每個 worker 選擇一個 peer,這樣所有 worker 都正確配對並且資料交換是有效的,因為每個 worker 可以在迭代之間與不同的 peer 交換資料。簡而言之,我們的策略將工作人員平均分成兩組,並在兩組之間動態配對 worker,每次迭代都不同。

我們就以此出發來進行分析學習。

2.6.1 DecentralizedAlgorithmImpl

2.6.1.1 定義

引數 peer_selection_mode 可以有兩種選擇:

  • all表示在每個通訊步驟中平均所有worker的權重。
  • shift_one 是指每個 worker 在每個通訊步驟中選擇一個不同的對等點進行權重平均。
class DecentralizedAlgorithmImpl(AlgorithmImpl):
    def __init__(
        self,
        process_group: BaguaProcessGroup,
        hierarchical: bool = True,
        peer_selection_mode: str = "all",
        communication_interval: int = 1,
    ):
        """
        Implementation of the
        `Decentralized SGD <https://tutorials.baguasys.com/algorithms/decentralized>`_
        algorithm.

        Args:
            process_group (BaguaProcessGroup): The process group to work on.
            hierarchical (bool): Enable hierarchical communication.
            peer_selection_mode (str): Can be ``"all"`` or ``"shift_one"``. ``"all"`` means all workers'
                weights are averaged in each communication step. ``"shift_one"`` means each worker
                selects a different peer to do weights average in each communication step.
            communication_interval (int): Number of iterations between two communication steps.

        """
        super(DecentralizedAlgorithmImpl, self).__init__(process_group)
        self.hierarchical = hierarchical
        self.peer_selection_mode = peer_selection_mode
        self.communication_interval = communication_interval
        self.cuda_event = torch.cuda.Event()
2.6.1.2 初始化狀態

_init_states 方法把權重張量初始化到 bucket._peer_weight。

提一下,LowPrecisionDecentralizedAlgorithmImpl 是初始化了左右兩個 peer_weight,因為精力所限,本文不對其進行分析,有興趣的讀者可以自行深入。

def _init_states(self, bucket: BaguaBucket):
    weight_tensor = bucket.flattened_tensor()
    bucket._peer_weight = weight_tensor.to_bagua_tensor("peer_weight")
2.6.1.3 初始化操作

init_operations 使用 append_decentralized_synchronous_op 配置了 bucket 的 _decentralized_op 成員變數。

def init_operations(
    self,
    bagua_module: BaguaModule,
    bucket: BaguaBucket,
):
    self._init_states(bucket)
    torch.cuda.synchronize()
    bucket.clear_ops()
    decentralized_op = bucket.append_decentralized_synchronous_op( # 配置成員變數
        peer_weight=bucket._peer_weight,
        hierarchical=self.hierarchical,
        peer_selection_mode=self.peer_selection_mode,
        group=self.process_group,
    )
    bucket._decentralized_op = decentralized_op
2.6.1.4 Post操作

init_post_backward_hook 註冊了 post hook 操作,會把去中心化平均的結果拷貝回來,後面會在進行細化分析。

def init_post_backward_hook(self, bagua_module: BaguaModule):
    def hook():
        if self._should_communicate(bagua_module):
            bagua_module._bagua_backend.wait_pending_comm_ops()

            torch.cuda.current_stream().record_event(self.cuda_event)
            self.cuda_event.synchronize()
            for bucket in bagua_module.bagua_buckets:
                bucket._decentralized_op.copy_back_peer_weight( # 拷貝回來
                    bucket.backend_bucket
                )

    return hook

演算法如下,append_decentralized_synchronous_op 用來通訊,init_post_backward_hook 把去中心化平均的結果拷貝回來。

+--------------------------------------------------------------------+
|DecentralizedAlgorithmImpl                                          |
|                                                                    |
|     process_group                                                  |
|                                                                    |
|     decentralized_op = bucket.append_decentralized_synchronous_op  |
|                                                                    |
|     peer_selection_mode                                            |
|                                                                    |
|     init_post_backward_hook                                        |
|                                                                    |
+--------------------------------------------------------------------+

2.6.2 BaguaBucket

我們接下來進入 BaguaBucket,其是聚集了一系列 Bagua 張量,其最終呼叫 backend_bucket 進行處理,就是 rust 的 BaguaBucketPy。

class BaguaBucket:
    def __init__(
        self, tensors: List[BaguaTensor], name: str, flatten: bool, alignment: int = 1
    ) -> None:
        """
        Create a Bagua bucket with a list of Bagua tensors.
        """
        self.tensors = tensors
        """
        The tensors contained within the bucket.
        """
        self.bagua_module_name = tensors[0].bagua_module_name
        self._bagua_backend = get_backend(self.bagua_module_name)
        self.name = name
        """
        The bucket's name.
        """
        self.padding_tensor = None

        if alignment > 1:
            padding = sum(tensor.numel() for tensor in self.tensors) % alignment
            if padding > 0:
                padding = alignment - padding
                self.padding_tensor = torch.zeros(
                    padding, dtype=self.tensors[0].dtype, device=self.tensors[0].device
                ).to_bagua_tensor("bagua_padding_tensor_bucket_" + name)

        self._all_tensors = (
            self.tensors + [self.padding_tensor]
            if self.padding_tensor is not None
            else self.tensors
        )

        self.backend_tensor = None
        self.flatten = flatten
        if self.flatten:
            self._flatten_()
            torch.cuda.empty_cache()

        self.backend_bucket = B.BaguaBucketPy( # 底層實現
            name, [tensor._bagua_backend_tensor for tensor in self._all_tensors]
        )

        for tensor in self._all_tensors:
            tensor._bagua_bucket = self
2.6.2.1 append_decentralized_synchronous_op

append_decentralized_synchronous_op 是往桶新增了操作,當bucket中的所有張量都標記為ready時,該操作將由Bagua後端按照附加順序執行。引數 peer_weight 的意義是用於與對等模型求平均值的張量,應與桶張量的總大小相同。

append_decentralized_synchronous_op 不是 inplace 操作,這意味著桶權重首先複製到peer_weight,去中心化平均的結果放置在 peer_weight,然後使用op.copy_back_peer_weight(self) 將結果再拷貝回來。具體在前面 init_post_backward_hook 之中有拷貝回來的操作。

我們還可以注意到,如果採取了 hierarchical 模式,則傳入了 inter, intra 兩種communicator。

def append_decentralized_synchronous_op(
    self,
    peer_weight: BaguaTensor,
    hierarchical: bool = True,
    peer_selection_mode: str = "all",
    group: Optional[BaguaProcessGroup] = None,
):
    """
    Append a decentralized synchronous operation to a bucket. It will do gossipy style model averaging among workers.
    """
    if group is None:
        group = _get_default_group()

    if hierarchical:
        return self.backend_bucket.append_decentralized_synchronous_op(
            _bagua_backend_comm(group.get_inter_node_communicator()),
            _bagua_backend_comm(group.get_intra_node_communicator()),
            hierarchical=hierarchical,
            peer_selection_mode=peer_selection_mode,
            peer_weight=peer_weight._bagua_backend_tensor,
        )
    else:
        return self.backend_bucket.append_decentralized_synchronous_op(
            _bagua_backend_comm(group.get_global_communicator()),
            None,
            hierarchical=hierarchical,
            peer_selection_mode=peer_selection_mode,
            peer_weight=peer_weight._bagua_backend_tensor,
        )
2.6.2.2 BaguaBucket

我們來到了 Rust 世界,BaguaBucket 的 append_decentralized_synchronous_op 操作之中,如果是 "all" 或者 "shift_one",則會呼叫 DecentralizedFullPrecisionSynchronous。

pub fn append_decentralized_synchronous_op(
    &mut self,
    communicator_internode: Option<&BaguaSingleCommunicator>,
    communicator_intranode: Option<&BaguaSingleCommunicator>,
    hierarchical: bool,
    peer_selection_mode: String,
    peer_weight: BaguaTensor,
) -> Arc<DecentralizedFullPrecisionSynchronous> {
    let communicator =
        BaguaCommunicator::new(communicator_internode, communicator_intranode, hierarchical)
            .expect("cannot create communicator");
    let comm_op = Arc::new(DecentralizedFullPrecisionSynchronous {
        communicator,
        peer_selection_mode: match peer_selection_mode.as_str() {
            "all" => PeerSelectionMode::All,
            "shift_one" => PeerSelectionMode::ShiftOne,
            &_ => {
                unimplemented!("unsupported peer_selection_mode for decentralized algorithm (should be `all` or `shift_one`)")
            }
        },
        step: Default::default(),
        peer_weight,
    });

    self.inner
        .lock()
        .comm_ops
        .push(comm_op.clone() as Arc<dyn CommOpTrait + Send + Sync>);
    comm_op
}
2.6.2.3 DecentralizedFullPrecisionSynchronous

DecentralizedFullPrecisionSynchronous 位於 rust/bagua-core/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs 之中。

其定義如下:

pub struct DecentralizedFullPrecisionSynchronous {
    pub communicator: BaguaCommunicator,
    pub peer_selection_mode: PeerSelectionMode,
    pub step: Mutex<usize>,
    pub peer_weight: BaguaTensor,
}
2.6.2.3.1 傳送

再回憶一下官方思路。

在第 2 步中,我們採用一種策略為每次迭代中的每個 worker 選擇一個 peer,這樣所有 worker 都正確配對並且資料交換是有效的,因為每個 worker 可以在迭代之間與不同的 peer 交換資料。簡而言之,我們的策略將工作人員平均分成兩組,並在兩組之間動態配對 worker,每次迭代都不同。

具體就是通過下面程式碼實現的。關鍵點在函式的最後一句,通過調整step, 計算出下一個peer,這樣每次peer都不同

                    // 計算出下一個peer,關鍵點在函式的最後一句,通過調整step,每次peer都不同
                    let peer_rank = if c.rank < c.nranks / 2 {
                        ((step + rank) % ((nranks + 1) / 2)) + (nranks / 2)
                    } else {
                        (rank - (nranks / 2) - step).rem_euclid(nranks / 2)
                    } 
                    
										......
                            c.send(&t.raw, peer_rank); // 傳送
                            c.recv(peer_tensor, peer_rank); // 接受
                    ......
                    
                    *self.step.lock() += 1; // 這裡是關鍵點!遞增到下一個peer

全部程式碼如下:

impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
    fn execute_background_communication(
        &self,
        bucket: Arc<BaguaBucket>,
        comm_op_channels: &BaguaCommOpChannels,
    ) {
        let bucket_guard = bucket.inner.lock();
        let stream_ptr = self.communicator.stream_ptr();

        // 獲取不同的communicator
        let mut communication_tensor = match &self.communicator {
            BaguaCommunicator::SingleCommunicator(_) => {
                bucket_guard.get_communication_tensor(stream_ptr, false, false)
            }
            BaguaCommunicator::HierarchicalCommunicator(x) => match x {
                BaguaHierarchicalCommunicator::Leader(_) => {
                    bucket_guard.get_communication_tensor(stream_ptr, true, true)
                }
                BaguaHierarchicalCommunicator::Worker(_) => {
                    bucket_guard.get_communication_tensor(stream_ptr, false, false)
                }
            },
        };

        let peer_mode = &self.peer_selection_mode;
        let mut peer_guard = self.peer_weight.inner.write();
        let mut peer_tensor = peer_guard.raw.as_mut();
        let step = { *self.step.lock() } as i64;

        self.communicator.execute_communication( // 執行通訊
            &mut communication_tensor,
            true,
            true,
            false,
            &mut |c, t| {
                match peer_mode {
                    PeerSelectionMode::All => {
                        // 做普通 allreduce
                        {
                            peer_tensor.clone_from(&t.raw, c.stream_ptr);
                            let _guard = NCCLGroupGuard::new();
                            c.allreduce_inplace(peer_tensor, BaguaReductionOp::AVG);
                        }
                    }
                    PeerSelectionMode::ShiftOne => { // shift_one 
                        let rank = c.rank as i64;
                        let nranks = c.nranks as i64;
                        // 計算出下一個peer,關鍵點在函式的最後一句,通過調整step,每次peer都不同
                        let peer_rank = if c.rank < c.nranks / 2 {
                            ((step + rank) % ((nranks + 1) / 2)) + (nranks / 2)
                        } else {
                            (rank - (nranks / 2) - step).rem_euclid(nranks / 2)
                        } as i32;
                        {
                            let _guard = NCCLGroupGuard::new();
                            c.send(&t.raw, peer_rank); // 傳送
                            c.recv(peer_tensor, peer_rank); // 接受
                        }
                        peer_tensor.average_inplace(&t.raw, c.stream_ptr);
                    },
                    PeerSelectionMode::Ring => {
                        unimplemented!() // 沒有實現
                    },
                }
            },
        );

        *self.step.lock() += 1; // 這裡是關鍵點!遞增到下一個pee
    }
}

沒有精力去研究rust,所以使用原始碼中的測試程式碼 tests/torch_api/test_decentralized.py 來看看,八卦在這方面真心做的不錯。

def get_peer_rank(peer_selection_mode, rank, nranks, step, communication_interval):
    comm_step = step // communication_interval
    if peer_selection_mode == "shift_one":
        if rank < nranks // 2:
            return ((comm_step + rank) % ((nranks + 1) // 2)) + (nranks // 2)
        else:
            return (rank - (nranks // 2) - comm_step) % (nranks // 2)
    else:
        ValueError("Unsupported `peer_selection_mode`")

step = 1
for i in range(6):
    print("iteration : ", i)
    print("peer is : ", get_peer_rank("shift_one", 1, 5, step, 1))
    step += 1
    
"""
iteration :  0
peer is :  4
iteration :  1
peer is :  2
iteration :  2
peer is :  3
iteration :  3
peer is :  4
iteration :  4
peer is :  2
iteration :  5
peer is :  3
"""

整理出圖如下,worker 1 每次分別和 worker 4, worker 2,worker 3 進行交換。

                              +--------------+
                              |              |
                              |   Worker 0   |
                              |              |
                              |              |
                              +--------------+

                              +--------------+
                              |              |
                   +------->  |   Worker 2   |
+--------------+   | peer 2   |              |
|              |   |          |              |
|   Worker 1   |   |          +--------------+
|              +---+
|              |   |          +--------------+
+--------------+   |          |              |
                   |          |   Worker 3   |
                   +------->  |              |
                   | peer 3   |              |
                   |          +--------------+
                   |
                   |          +--------------+
                   |          |              |
                   +--------> |   Worker 4   |
                     peer 1   |              |
                              |              |
                              +--------------+
2.6.2.3.2 拷貝回來

copy_back_peer_weight 就是前面提到的回拷貝操作。

impl DecentralizedFullPrecisionSynchronous {
  
    pub fn copy_back_peer_weight(&self, bucket: Arc<BaguaBucket>) { // 拷貝回去
        let bucket_guard = bucket.inner.lock();
        let stream_ptr = self.communicator.stream_ptr();

        let mut communication_tensor =
            bucket_guard.get_communication_tensor(stream_ptr, false, false);

        self.communicator.execute_communication(
            &mut communication_tensor,
            false,
            false,
            true,
            &mut |c, t| {
                t.raw
                    .clone_from(self.peer_weight.inner.read().raw.as_ref(), c.stream_ptr);
            },
        );
    }
}

我們再給出一個示意圖。

+---------------------------------------------------------------------+
|DecentralizedAlgorithmImpl                                           |
|                                                                     |
|     process_group                                                   |
|                                                                     |
|     decentralized_op = bucket.append_decentralized_synchronous_op   |
|                                                 +                   |
|     peer_selection_mode                         |                   |
|                                                 |                   |
|     init_post_backward_hook                     |                   |
|              ^                                  |                   |
|              |                                  |                   |
|              |                                  |                   |
+---------------------------------------------------------------------+
               |                                  |
               |                                  |
+-----------------------------------------------------------+         +----------+
| BaguaBucket  |                                  |         |         | Worker 0 |
|              |                                  |         |         +----------+
|              |                                  v         |
|              |                                            |         +----------+
|              |    DecentralizedFullPrecisionSynchronous { |         | Worker 1 |
|              |                                            |         +----------+
|              |         PeerSelectionMode::ShiftOne {      |
|              |                                            |   peer2 +----------+
|              |            c.send(&t.raw, peer_rank);+--------+----> | Worker 2 |
|              |            c.recv(peer_tensor, peer_rank)  |  |      +----------+
|              |         }                                  |  |
|              |    }                                       |  |peer3 +----------+
|              |                                            |  +----> | Worker 3 |
|              |                                            |  |      +----------+
|              |                                            |  |
|              +--+ copy_back_peer_weight                   |  |peer4 +----------+
|                                                           |  +----> | Worker 4 |
+-----------------------------------------------------------+         +----------+

0x03 非同步

關於非同步通訊,官方文件思路如下:

  • 同步或是非同步(Synchronous or Asynchronous):同步模式中,在每一次迭代過程中,所有工作節點都需要進行通訊,並且下一步迭代必須等待當前迭代的通訊完成才能開始。反之,非同步式分佈演算法 [2] 則不需要等待時間:當某個節點完成計算後就可直接傳遞本地梯度,進行模型更新。

我們接下來用 https://tutorials.baguasys.com/algorithms/async-model-average 結合程式碼來分析學習。

3.1 示例用法

首先初始化八卦演算法:

from bagua.torch_api.algorithms import async_model_average
algorithm = async_model_average.AsyncModelAverageAlgorithm()

然後對模型使用演算法

model = model.with_bagua([optimizer], algorithm)

與執行同步演算法不同,您需要在訓練過程完成時(例如,當您要執行測試時)明確停止通訊執行緒:

model.bagua_algorithm.abort(model)

要在再次開始訓練時恢復通訊執行緒,請執行以下操作:

model.bagua_algorithm.resume(model)

3.2 非同步模型平均

在Gradient AllReduce 等同步通訊演算法中,同一迭代中每個 worker 都需要以鎖步(lock-step)方式運作。當系統中沒有落後者(straggler)時,這種同步演算法相當有效,並可以提供更容易推理的確定性訓練結果。然而,當系統中存在落後者時,使用同步演算法時,更快的 worker 必須在每次迭代中等待最慢的 worker,這會極大地損害整個系統的效能。為了處理掉隊者,我們可以使用非同步演算法,其中 worker 不需要同步。八卦提供的非同步模型平均演算法就是這樣的非同步演算法。

3.3 演算法

非同步模式平均演算法可以被描述為如下:

每個 worker 都維護一個本地模型 X. 第 i 個 worker 維護 $ x^{(i)}$ ,每個 worker 並行執行兩個執行緒。第一個執行緒進行梯度計算(稱為計算執行緒),另一個執行緒進行通訊(稱為通訊執行緒)。對於每個 worker i, 有一個鎖 \(m_i\),控制對其模型的訪問。

第 i 個 worker 上的計算執行緒重複以下步驟:

  1. 獲取鎖 \(m_i\)
  2. 在一批輸入資料上計算區域性梯度 $∇ F (x^{(i)}) $。
  3. 釋放鎖 \(m_i\).
  4. 用區域性梯度更新模型,$x^{(i)} = x^{(i)} - γ∇ F (x^{(i)}) $。

第 i 個 worker 上的通訊執行緒重複以下步驟::

  1. 獲取鎖 \(m_i\)
  2. 與所有其他 worker 的模型通訊以平均本地模型\(X^{(i)}\)\(X^{(i)} = \frac{1}{n} \sum^n_{j=1}X^{(j)}\)
  3. 釋放鎖 \(m_i\).

每個 worker 獨立併發地執行這兩個執行緒。

3.4 分析

大家可以看到,本質上就是計算執行緒和通訊執行緒都是自己操作,但是依賴鎖進行彼此協調,達到了非同步的目的。

3.4.1 非同步通訊實現

AsyncModelAverageAlgorithmImpl 是非同步通訊的實現。

class AsyncModelAverageAlgorithmImpl(AlgorithmImpl):
    def __init__(
        self,
        process_group: BaguaProcessGroup,
        peer_selection_mode: str = "all",
        sync_interval_ms: int = 500,
        warmup_steps: int = 0,
    ):
        """
        Implementation of the
        `AsyncModelAverage <https://tutorials.baguasys.com/algorithms/async-model-average.html>`_
        algorithm.

        The asynchronous implementation is experimental, and imposes some restrictions.
        With such asynchronous algorithm, the number of iterations on each worker are different. Therefore
        the current implementation assumes that the dataset is an endless stream, and all workers continuously
        synchronize between each other.

        Users should call :meth:`abort` to manually stop the algorithm's continuous synchronization process.
        For example, for a model wrapped with `.with_bagua(...)`, you can abort with `model.bagua_algorithm.abort(model)`,
        and resume with `model.bagua_algorithm.resume(model)`.

        Args:
            process_group (BaguaProcessGroup): The process group to work on.
            peer_selection_mode (str): The way how workers communicate with each other. Currently ``"all"`` is supported.
                ``"all"`` means all workers' weights are synchronized during each communication.
            sync_interval_ms (int): Number of milliseconds between model synchronizations.
            warmup_steps (int): Number of steps to warm up by doing gradient allreduce before doing asynchronous
                model averaging. Use 0 to disable.
        """

        super(AsyncModelAverageAlgorithmImpl, self).__init__(process_group)
        self.peer_selection_mode = peer_selection_mode
        self.sync_interval_ms = sync_interval_ms
        self.step_id = 0
        self.warmup_steps = warmup_steps
        self.cuda_event = torch.cuda.Event()
        self.abort_event = threading.Event()
        self.dummy_tensor = torch.Tensor([0]).byte().cuda()

        # 執行緒池
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.scheduled = False

        process_ranks = list(_pg_group_ranks[self.process_group])
        self.thread_group = new_group(
            process_ranks, stream=torch.cuda.Stream(priority=-1)
        )

3.4.2 初始化操作

init_operations 的 這部分呼叫是在 _bagua_reset_algorithm_buckets 之中,每個 BaguaModule 都會做設定,主要是設定:熱身時期是同步操作/其他時間是非同步操作,這裡忽略了大部分程式碼。

def _bagua_reset_algorithm_buckets(self):
    self._bagua_cleanup_algorithm()
    raw_buckets = self._bagua_autotune_get_buckets()
    self.bagua_buckets.extend(self.bagua_algorithm.tensors_to_buckets(raw_buckets))

    for name, param in self.named_parameters():
        # 忽略 real_hook_factory 定義
        if param.requires_grad:
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
            hook = grad_acc.register_hook(real_hook_factory(name, param))
            hook.grad_acc = grad_acc
            self._bagua_algorithm_hooks.append(hook)

    optimizer_hook = self.bagua_algorithm.init_post_optimizer_step_hook(self)

    for optimizer in self.bagua_optimizers:
        if not hasattr(optimizer, "_bagua_original_step"):
            optimizer._bagua_original_step = optimizer.step
        # 忽略 new_step_factory 定義
        optimizer.step = new_step_factory(optimizer)

    for bucket in self.bagua_buckets:
        self.bagua_algorithm.init_operations( # 這裡呼叫對演算法的初始化操作
            self,
            bucket,
        )
    self._bagua_backend.register_ordered_buckets(
        [bucket.backend_bucket for bucket in self.bagua_buckets]
    )

就是對於除了熱身期間之外,每個桶都設定了非同步通訊

def init_operations(
    self,
    bagua_module: BaguaModule,
    bucket: BaguaBucket,
):
    bagua_module._bagua_backend.wait_pending_comm_ops()
    bucket.clear_ops()

    if self.step_id < self.warmup_steps:
        bucket.append_centralized_synchronous_op( # 熱身時期是同步操作
            hierarchical=False,
            average=True,
            group=self.process_group,
        )
    else:
        # 其他時間是非同步操作
        async_op = bucket.append_asynchronous_model_average_op(
            peer_selection_mode=self.peer_selection_mode, group=self.thread_group
        )
        bucket._async_op = async_op

3.4.3 加鎖解鎖

我們接下來看看加鎖釋放鎖的基礎操作。bagua/torch_api/algorithms/async_model_average.py 之中有:

def _lock_model(self, bagua_module: BaguaModule):
    torch.cuda.current_stream().record_event(self.cuda_event)
    self.cuda_event.synchronize() # CUDA同步操作

    for bucket in bagua_module.bagua_buckets:
        bucket._async_op.lock_weight() # 加鎖操作

def _unlock_model(self, bagua_module: BaguaModule):
    torch.cuda.current_stream().record_event(self.cuda_event)
    self.cuda_event.synchronize() # CUDA同步操作

    for bucket in bagua_module.bagua_buckets:
        bucket._async_op.unlock_weight() # 釋放鎖

lock_weight 和 unlock_weight 的實現在 rust 程式碼之中。

impl DecentralizedFullPrecisionAsynchronous {
    pub fn lock_weight(&self) {
        let raw_mutex = unsafe { self.weight_mutex.raw() };
        raw_mutex.lock();
    }

    pub fn unlock_weight(&self) {
        unsafe {
            let raw_mutex = self.weight_mutex.raw();
            raw_mutex.unlock();
        };
    }
}

3.4.4 計算執行緒

計算執行緒之中,和加鎖解鎖關鍵步驟如下:

3.4.4.1 前向傳播

前向傳播時候,先進行加鎖,如果非同步迴圈通訊執行緒沒有啟動,則會進行啟動。

def init_forward_pre_hook(self, bagua_module: BaguaModule):
    def hook(input):
        if (
            self.step_id > self.warmup_steps
            and self.sync_interval_ms > 0  # noqa: W503
        ):
            self._lock_model(bagua_module) # 枷鎖

            if not hasattr(self, "future"):
                self.future = self.executor.submit(
                    self._run_async_loop, bagua_module # 啟動非同步迴圈通訊執行緒
                )
                self.scheduled = True

    return hook
3.4.4.2 後向傳播

後向傳播結束之後,會對鎖進行釋放,就是說,前向傳播時候加鎖啟動執行緒,後向傳播時候解鎖,這期間進行計算

def init_backward_hook(self, bagua_module: BaguaModule):
    def hook(parameter_name, parameter):
        if self.step_id <= self.warmup_steps:
            parameter._bagua_grad.bagua_mark_communication_ready() # 通知後端可以通訊

    return hook

def init_post_backward_hook(self, bagua_module: BaguaModule):
    def hook():
        if self.step_id <= self.warmup_steps:
            bagua_module._bagua_backend.wait_pending_comm_ops() # 等待
        else:
            self._unlock_model(bagua_module) # 解鎖

    return hook

此時邏輯如下:

+---------------------------------------------------------------------------+
| AsyncModelAverageAlgorithmImpl                                            |
|                                                                           |
|  +-----------------------------+                 +----------------------+ |
|  | Computation thread          |                 | BaguaBucket          | |
|  |                             | set async_op    |  +----------------+  | |
|  |    init_operations   +----------------------> |  | _async_op      |  | |
|  |                             |                 |  |                |  | |
|  |                             | lock_weight()   |  |                |  | |
|  |    init_forward_pre_hook +------------------> |  |                |  | |
|  |                             | unlock_weight() |  |                |  | |
|  |    init_post_backward_hook+-----------------> |  |                |  | |
|  |                             |                 |  |                |  | |
|  |                             |                 |  +----------------+  | |
|  +-----------------------------+                 +----------------------+ |
|                                                                           |
|  +-----------------------------+                                          |
|  | Communation thread          |                                          |
|  |                             |                                          |
|  | _run_async_loop             |                                          |
|  |                             |                                          |
|  |                             |                                          |
|  +-----------------------------+                                          |
|                                                                           |
+---------------------------------------------------------------------------+

3.4.5 通訊執行緒

通訊執行緒主迴圈如下,主要是通知後端,進行通訊

def _run_async_loop(self, bagua_module: BaguaModule):
    comm_step = 0
    while True:
        state = self._negotiate()
        if state == _AsyncInternalState.ABORT:
            break

        start_time = time.time()
        for bucket in bagua_module.bagua_buckets: # 遍歷桶
            for tensor in bucket.tensors: # 遍歷張量
                # 通知後端,進行通訊
                tensor.bagua_mark_communication_ready_without_synchronization() 

        bagua_module._bagua_backend.wait_pending_comm_ops()
        duration = (time.time() - start_time) * 1000

        comm_step += 1
        time.sleep(self.sync_interval_ms / 1000)
3.4.5.1通知後端
Python

bagua_mark_communication_ready_without_synchronization 的實現如下,呼叫後端的 mark_communication_ready。

def bagua_mark_communication_ready_without_synchronization(self):
    """
    Mark a Bagua tensor ready immediately, without `CUDA event <https://pytorch.org/docs/stable/generated/torch.cuda.Event.html?highlight=event#torch.cuda.Event>`_ synchronization.
    """
    self.bagua_backend.mark_communication_ready(
        self._bagua_backend_tensor,
        0,
    )
Rust

mark_communication_ready 的實現在 rust 之中。位置是 rust/bagua-core/bagua-core-py/src/lib.rs。

pub fn mark_communication_ready(
    &mut self,
    tensor: PyRef<BaguaTensorPy>,
    ready_cuda_event_ptr: u64,
    py: Python,
) -> PyResult<()> {
    let inner = &tensor.inner;
    py.allow_threads(|| {
        self.inner
            .mark_communication_ready(inner, ready_cuda_event_ptr)
    })
    .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))
}

rust/bagua-core/bagua-core-internal/src/lib.rs 之中有:

pub fn mark_communication_ready(
    &mut self,
    tensor: &BaguaTensor,
    ready_cuda_event_ptr: u64,
) -> Result<(), BaguaCoreError> {
    let tracer = global::tracer("bagua-core");
    let mut span = tracer.start("tensor_ready");
    span.set_attribute(KeyValue::new("tensor_name", tensor.name()));

    tensor.mark_comm_ready(ready_cuda_event_ptr);
    while self.should_schedule()? {
        let bucket = self.ordered_buckets.pop_front().unwrap();
        bucket.reset_comm_ready();
        let bucket_clone = bucket.clone();
        self.ordered_buckets.push_back(bucket);
        self.schedule_comm(bucket_clone)?;
    }
    Ok(())
}

schedule_comm 在 rust/bagua-core/bagua-core-internal/src/lib.rs 之中。

pub fn schedule_comm(&self, bucket: Arc<BaguaBucket>) -> Result<(), BaguaCoreError> {
    let event_channel = BaguaEventChannel::new("comm_op");
    self.channels
        .schedule_channel_sender
        .send(BaguaScheduledCommOp {
            name: format!("comm op for bucket {}", bucket.name),
            ops: {
                let guard = bucket.inner.lock();
                guard.comm_ops.clone() // 獲取bucket的op,進行呼叫
            },
            bucket,
            event_channel: event_channel.clone(),
        })
        .map_err(|e| BaguaCoreError::InternalChannelError(format!("{:?}", e)))?;
    Ok(self
        .channels
        .not_waited_events_sender
        .send(event_channel)
        .map_err(|e| BaguaCoreError::InternalChannelError(format!("{:?}", e)))?)
}

傳送了一個 BaguaScheduledCommOp。

pub struct BaguaScheduledCommOp {
    pub name: String,
    pub bucket: Arc<BaguaBucket>,
    pub ops: Vec<Arc<dyn CommOpTrait + Send + Sync>>,
    pub event_channel: BaguaEventChannel,
}

邏輯如下:

+---------------------------------------------------+    +----------------------------+
| AsyncModelAverageAlgorithmImpl                    |    | BaguaBucket                |
|                                                   |    | +------------------------+ |
|  +-----------------------------+                  |    | | _async_op              | |
|  | Computation thread          |                  |    | |                        | |
|  |                             |    set async_op  |    | |                        | |
|  |    init_operations   +----------------------------> | |                        | |
|  |                             |                  |    | |                        | |
|  |                             |    lock_weight() |    | |                        | |
|  |    init_forward_pre_hook +------------------------> | |                        | |
|  |                             |   unlock_weight()|    | |                        | |
|  |    init_post_backward_hook+-----------------------> | |                        | |
|  |                             |                  |    | +------------------------+ |
|  |                             |                  |    +----------------------------+
|  +-----------------------------+                  |
|  +---------------------------------+              |
|  | Communation thread              |              |    +----------------------------+
|  | +-----------------------------+ |              |    | BaguaCommBackendPy         |
|  | |                             | |              |    |                            |
|  | | _run_async_loop    +----------------------------> |   mark_communication_ready |
|  | |                             | |              |    |            +               |
|  | +-----------------------------+ |              |    |            |               |
|  +---------------------------------+              |    |            v               |
+---------------------------------------------------+    |      schedule_comm         |
                                                         |                            |
                                                         +----------------------------+
3.4.5.2 歸併

schedule_comm 最終會呼叫到 bucket.comm_ops,該變數在初始化時候被配置為 DecentralizedFullPrecisionAsynchronous,所以我們需要回頭來一步一步看看如何歸併。

前面初始化操作時候有使用 bucket.append_asynchronous_model_average_op 進行配置。

def init_operations(
    self,
    bagua_module: BaguaModule,
    bucket: BaguaBucket,
):
    bagua_module._bagua_backend.wait_pending_comm_ops()
    bucket.clear_ops()

    if self.step_id < self.warmup_steps:
        bucket.append_centralized_synchronous_op( # 熱身時期是同步操作
            hierarchical=False,
            average=True,
            group=self.process_group,
        )
    else:
        # 其他時間是非同步操作
        async_op = bucket.append_asynchronous_model_average_op( # 進行歸併配置
            peer_selection_mode=self.peer_selection_mode, group=self.thread_group
        )
        bucket._async_op = async_op
Python

append_asynchronous_model_average_op 程式碼在 bagua/torch_api/bucket.py。其作用是:

  • 將非同步模型歸併操作附加到bucket。此操作將在訓練模型時啟用 worker 之間的連續模型平均。當bucket中的所有張量都標記為ready時,操作將由Bagua後端按照追加的順序執行。

  • 此操作旨在與計算過程並行執行。它返回對op的引用。op具有獨佔訪問模型的鎖。呼叫op.lock_weight()獲取鎖,呼叫op.unlock_weight()釋放鎖。

  • 重點在於,張量 ready 之後進行操作。

def append_asynchronous_model_average_op(
    self, peer_selection_mode: str, group: Optional[BaguaProcessGroup] = None
):

    """
    Append an asynchronous model average operation to a bucket. This operation will enable continuous
    model averaging between workers while training a model.

    The operations will be executed by the Bagua backend in the order they are appended
    when all the tensors within the bucket are marked ready.

    This operation is intended to run in parallel with the computation process. It returns a reference
    to the op. The op features a lock to exclusively access the model. Call ``op.lock_weight()`` to
    acquire the lock and ``op.unlock_weight()`` to release it.

    Args:
        peer_selection_mode (str): The way how workers communicate with each otehr. Currently ``"all"`` is supported.
            ``"all"`` means all workers' weights are averaged during each communication.
        group: The process group to work on. If ``None``, the default process group will be used.
    Returns:
        The asynchronous model average operation itself.
    """
    if group is None:
        group = _get_default_group()

    return self.backend_bucket.append_decentralized_asynchronous_op(
        _bagua_backend_comm(group.get_global_communicator()),
        None,
        peer_selection_mode=peer_selection_mode,
        torch_stream=torch.cuda.current_stream().cuda_stream,
    )
Rust

append_decentralized_asynchronous_op 函式在 rust 之中,其呼叫了 DecentralizedFullPrecisionAsynchronous,就是往 bucket.comm_ops 之上新增了一個 DecentralizedFullPrecisionAsynchronous。

    pub fn append_decentralized_asynchronous_op(
        &mut self,
        communicator_internode: Option<&BaguaSingleCommunicator>,
        communicator_intranode: Option<&BaguaSingleCommunicator>,
        peer_selection_mode: String,
        torch_stream: u64,
    ) -> Arc<DecentralizedFullPrecisionAsynchronous> {
        let communicator =
            BaguaCommunicator::new(communicator_internode, communicator_intranode, false)
                .expect("cannot create communicator");

        let comm_op = Arc::new(DecentralizedFullPrecisionAsynchronous {
            communicator,
            peer_selection_mode: match peer_selection_mode.as_str() {
                "all" => PeerSelectionMode::All,
                &_ => {
                    unimplemented!("unsupported peer_selection_mode for decentralized asynchronous algorithm (should be `all`)")
                }
            },
            torch_stream,
            weight_mutex: Arc::new(Mutex::new(true)),
        });

        self.inner
            .lock()
            .comm_ops // 插入到 bucket 的 comm_ops
            .push(comm_op.clone() as Arc<dyn CommOpTrait + Send + Sync>);

        comm_op
    }

DecentralizedFullPrecisionAsynchronous 裡面有加鎖,釋放鎖,CUDA 同步操作等等,恰好與前面提到的前向傳播/後向傳播對應。

impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
    fn execute_background_communication(
        &self,
        bucket: Arc<BaguaBucket>,
        comm_op_channels: &BaguaCommOpChannels,
    ) {
        let bucket_guard = bucket.inner.lock();

        let comm_stream = self.communicator.stream_ptr();

        let mut communication_tensor = match &self.communicator {
            BaguaCommunicator::SingleCommunicator(_) => {
                bucket_guard.get_communication_tensor(comm_stream, false, false)
            }
            BaguaCommunicator::HierarchicalCommunicator(x) => {
                panic!("asynchronous op only accepts non-hierarchical communicator");
            }
        };

        let peer_mode = &self.peer_selection_mode;

        let torch_stream = self.torch_stream;

        self.communicator.execute_communication(
            &mut communication_tensor,
            false,
            false,
            false,
            &mut |c, t| {
                let start_time = std::time::Instant::now();
   
                let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id()]
                    .try_pull(t.raw.num_elements_allocated() * t.raw.dtype().bytes())
                    .expect("cannot allocate cuda memory");

                let mut temp_tensor = BaguaTensorRaw {
                    ptr: temp_buf.ptr,
                    num_elem_allocated: t.raw.num_elements_allocated(),
                    dtype: t.raw.dtype().clone(),
                    num_elem: t.raw.num_elements(),
                    device_id: t.raw.device_id(),
                    pool_allocations: vec![Arc::new(temp_buf)],
                };

                let reduced_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id()]
                    .try_pull(t.raw.num_elements_allocated() * t.raw.dtype().bytes())
                    .expect("cannot allocate cuda memory");

                let mut reduced_tensor = BaguaTensorRaw {
                    ptr: reduced_buf.ptr,
                    num_elem_allocated: t.raw.num_elements_allocated(),
                    dtype: t.raw.dtype().clone(),
                    num_elem: t.raw.num_elements(),
                    device_id: t.raw.device_id(),
                    pool_allocations: vec![Arc::new(reduced_buf)],
                };

                let src_ready_event = CUDA_EVENT_POOL.take().event;

                // use default stream to copy weights
                temp_tensor.clone_from(&t.raw, torch_stream as u64);

                unsafe {
                    cpp::cpp!([
                        src_ready_event as "cudaEvent_t",
                        comm_stream as "cudaStream_t",
                        torch_stream as "cudaStream_t"]
                    {
                        CUDACHECK(cudaEventRecord(src_ready_event, torch_stream));
                        CUDACHECK(cudaStreamWaitEvent(comm_stream, src_ready_event , 0));
                    });
                }

                match peer_mode {
                    PeerSelectionMode::All => {
                        c.allreduce(&temp_tensor, &mut reduced_tensor, BaguaReductionOp::SUM);
                    }
                    PeerSelectionMode::Ring => {
                        unimplemented!()
                    }
                    PeerSelectionMode::ShiftOne => {
                        unimplemented!()
                    }
                };

                {
                    // 獲取 ready event
                    let ready_event = CUDA_EVENT_POOL.take().event;
                    unsafe {
                        cpp::cpp!([
                            ready_event as "cudaEvent_t",
                            comm_stream as "cudaStream_t"]
                        {
                            // CUDA 同步操作
                            CUDACHECK(cudaEventRecord(ready_event, comm_stream));
                            CUDACHECK(cudaEventSynchronize(ready_event));
                        });
                    }

                    self.lock_weight(); // 加鎖
                  
                    t.raw.async_model_average(
                        &reduced_tensor,
                        &temp_tensor,
                        c.nranks as f32,
                        comm_stream,
                    );

                    unsafe {
                        cpp::cpp!([
                            ready_event as "cudaEvent_t",
                            comm_stream as "cudaStream_t"]
                        {
                            // 對CUDA進行操作
                            CUDACHECK(cudaEventRecord(ready_event, comm_stream));
                            CUDACHECK(cudaEventSynchronize(ready_event));
                        });
                    }
                    self.unlock_weight(); // 解鎖
                }

                tracing::debug!(
                    "#{} async model average update cost: {:?}",
                    c.rank,
                    start_time.elapsed()
                );
            },
        );
    }
}

在 rust/bagua-core/bagua-core-internal/kernels/bagua_kernels.cu 之中有最終操作。

__global__ void async_model_average(float *tensor, const float *reduced_tensor_copy, 
      const float *tensor_copy, const float nranks, const int N) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {  
   tensor[i] += reduced_tensor_copy[i] / nranks - tensor_copy[i];
    }
}

我們總結邏輯如下:

  • (1)init_operations 會進行一系列呼叫,生成了一個DecentralizedFullPrecisionAsynchronous,賦值在bucket 的 comm_ops 和 aysnc_op 之上。

計算執行緒之中做如下操作:

  • (2)計算執行緒之中,在前向傳播之前設定了hook,其中會 lock weight。
  • (3)計算執行緒之中,在後向傳播之前設定了hook,其中會 unlock weight。

通訊執行緒之中做如下操作:

  • (4)會呼叫 mark_communication_ready 進行通訊設定。
  • (5)mark_communication_ready 最終呼叫到 schedule_comm,其會啟動 bucket.comm_ops,bucket.comm_ops 就是 DecentralizedFullPrecisionAsynchronous。
  • DecentralizedFullPrecisionAsynchronous 之中會:
    • (6)lock weight。
    • (7)會進行非同步模型歸併。
    • (8)會 unlock weight。
  +---------------------------------------------------+   +----------------------+    +----------------------------------------+
  | AsyncModelAverageAlgorithmImpl                    |   |  BaguaBucket         |    | DecentralizedFullPrecisionAsynchronous |
  |                                                   |   |                 1    |    |                                        |
  |  +-----------------------------+                  |   |       comm_ops +--------> |  6   self.lock_weight()                |
  |  | Computation thread          |  1 set async_op  |   |                      |    |                                        |
  |  |                             |                  |   |    +--------------+  |    |                                        |
  |  |    init_operations   +---------------------------->+    | _async_op  1 |  |    |  7   t.raw.async_model_average(        |
  |  |                             |                  |   |    |           +--------> |                &reduced_tensor,        |
  |  |                             |                  |   |    |              |  |    |                &temp_tensor,           |
  |  |                             |                  |   |    |              |  |    |                c.nranks as f32,        |
  |  |                             |                  |   |    |              |  |    |                comm_stream,            |
  |  |                             |  2 lock_weight() |   |    |              |  |    |            );                          |
  |  |    init_forward_pre_hook +----------------------------> |              |  |    |                                        |
  |  |                             | 3 unlock_weight()|   |    |              |  |    |                                        |
  |  |    init_post_backward_hook+---------------------------> |              |  |    |  8   self.unlock_weight()              |
  |  |                             |                  |   |    +--------------+  |    |                                        |
  |  |                             |                  |   |                      |    +--------+-------------------------------+
  |  +-----------------------------+                  |   +----------------------+             ^
  |                                                   |                                        |
+--------------------------------------------------------------------------------------------------------------------------------+
  |                                                   |                                        |
  |  +---------------------------------+              |                                        |
  |  | Communation thread              |              |   +-----------------------------+      |
  |  | +-----------------------------+ |              |   |  BaguaCommBackendPy         |      |
  |  | |                             | |     4        |   |                             |      |
  |  | | _run_async_loop    +--------------------------------> mark_communication_ready |      |
  |  | |                             | |              |   |             +               |      | 5
  |  | +-----------------------------+ |              |   |             |               |      |
  |  +---------------------------------+              |   |             v               |      |
  +---------------------------------------------------+   |       schedule_comm         |      |
                                                          |             +               |      |
                                                          |             |               |      |
                                                          |             v               |      |
                                                          |       bucket.comm_ops  +-----------+
                                                          |                             |
                                                          +-----------------------------+

手機如下:

或者我們換一個角度來看,就是左右兩個執行緒都操作桶,通過鎖來協調競爭,特色除了鎖之外,就在DecentralizedFullPrecisionAsynchronous 之中。這裡需要注意的是,數值 1 的意義是設定,就是 bucket 的 _async_op 和 comm_ops 都配置成 DecentralizedFullPrecisionAsynchronous,最後通訊執行緒之中(4)會呼叫 mark_communication_ready 進行通訊設定。

                                                                                                                             +-------------------------+
                                                 +----------------------+                                                    | Communation thread      |
                                                 |  BaguaBucket         |                                                    | +---------------------+ |
                                                 |                      | 1                                                  | |                     | |
+---------------------------+                    |       comm_ops +--------------------------------+                         | | _run_async_loop     | |
| Computation thread        |  1 set async_op    |                      |                          |                         | |          +          | |
|                           |                    |    +--------------+  |                          |                         | |          |          | |
|  init_operations   +-------------------------->+    | _async_op    |  | 1                        |                         | +---------------------+ |
|                           |                    |    |           +------------------+             |                         +-------------------------+
|                           |                    |    |              |  |            |             |                                      |
|                           |                    |    |              |  |            |             |                                      |
|                           |                    |    |              |  |            v             v                                      v
|                           |  2 lock_weight()   |    |              |  |     +------+-------------+-------------------+    +-------------+---------------+
|  init_forward_pre_hook +--------------------------> |              |  |     | DecentralizedFullPrecisionAsynchronous |    |  BaguaCommBackendPy         |
|                           |                    |    |              |  | 6   |                                        |    |                             |
|                           |                    |    |              +<------------+ self.lock_weight()                |    |    mark_communication_ready |
|                           |                    |    |              |  |     |                                        |    |             +               |
|                           |                    |    |              |  |     |  7   t.raw.async_model_average(        |    |             |               |
|                           |                    |    |              |  |     |                &reduced_tensor,        |    |             v               |
|                           |                    |    |              |  |     |                &temp_tensor,           |    |       schedule_comm         |
|                           |                    |    |              |  |     |                c.nranks as f32,        |    |             +               |
|                           |                    |    |              |  |     |                comm_stream,            |    |             |               |
|                           |                    |    |              |  |     |            );                          |  4 |             v               |
|                           |                    |    |              |  | 8   |                                        +<--------+  bucket.comm_ops       |
|                           | 3 unlock_weight()  |    |              +<-----------+  self.unlock_weight()              |    |                             |
|  init_post_backward_hook+-------------------------> |              |  |     |                                        |    +-----------------------------+
|                           |                    |    |              |  |     +----------------------------------------+
|                           |                    |    +--------------+  |
|                           |                    |                      |
+---------------------------+                    +----------------------+

手機如下:

至此,八卦框架分析完畢,這個框架無論是論文,程式碼,文件,介紹網站,PPT都非常給力,推薦有興趣的朋友繼續深入研究。

0xFF 參考

PyTorch internals

快手八卦!突破 TensorFlow、PyTorch 並行瓶頸的開源分散式訓練框架來了!

https://arxiv.org/pdf/2107.01499.pdf

https://tutorials.baguasys.com/algorithms/decentralized

[1] Dean, Jeffrey, Greg S. Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Quoc V. Le, Mark Z. Mao et al. “Large scale distributed deep networks.” (2012).

[2] Zhengyuan Zhou, Panayotis Mertikopoulos, Nicholas Bambos, Peter Glynn, Yinyu Ye, Li-Jia Li, and Li Fei-Fei. 2018. Distributed asynchronous optimization with unbounded delays: How slow can you go?. In International Conference on Machine Learning. PMLR, 5970–5979.

[3] DanAlistarh, DemjanGrubic, JerryLi, RyotaTomioka, and MilanVojnovic. 2016. QSGD: Communication-efficient SGD via gradient quantization and encoding. arXiv preprint arXiv:1610.02132 (2016).

[4] Dan Alistarh, Torsten Hoefler, Mikael Johansson, Sarit Khirirat, Nikola Konstanti- nov, and Cédric Renggli. 2018. The convergence of sparsified gradient methods. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 5977–5987.

[5] Anastasia Koloskova, Sebastian Stich, and Martin Jaggi. 2019. Decentralized stochastic optimization and gossip algorithms with compressed communication. In International Conference on Machine Learning. PMLR, 3478–3487.

[6] Xiangru Lian, Ce Zhang, Huan Zhang, Cho-Jui Hsieh, Wei Zhang, and Ji Liu. 2017. Can decentralized algorithms outperform centralized algorithms? a case study for decentralized parallel stochastic gradient descent. In Proceedings of the 31st International Conference on Neural Information Processing Systems. 5336–5346.

[7] Christopher De Sa, Matthew Feldman, Christopher Ré, and Kunle Olukotun. 2017. Understanding and optimizing asynchronous low-precision stochastic gradient descent. In Proceedings of the 44th Annual International Symposium on Computer Architecture. 561–574.

[8] Xiangru Lian, Wei Zhang, Ce Zhang, and Ji Liu. 2018. Asynchronous decentral- ized parallel stochastic gradient descent. In International Conference on Machine Learning. PMLR, 3043–3052.

[9] Hanlin Tang, Shaoduo Gan, Ce Zhang, Tong Zhang, and Ji Liu. 2018. Com- munication compression for decentralized training. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 7663–7673.

[10] Ji Liu, Ce Zhang, et al. 2020. Distributed Learning Systems with First-Order Methods. Foundations and Trends® in Databases 9, 1 (2020), 1–100.

相關文章