[原始碼解析] PyTorch 分散式(16) --- 使用非同步執行實現批處理 RPC

羅西的思考發表於2021-12-15

[原始碼解析] PyTorch 分散式(16) --- 使用非同步執行實現批處理 RPC

0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分散式的基本模組,接下來我們通過幾篇文章來看看如何把這些模組應用到實踐之中,順便把PyTorch分散式邏輯整體梳理一下。本文介紹如何使用非同步執行操作來實現批處理 RPC,大家可以學習到PyTorch對引數伺服器一個新的實現方式。

本文以IMPLEMENTING BATCH RPC PROCESSING USING ASYNCHRONOUS EXECUTIONS的翻譯為基礎,加入了自己的理解。

PyTorch分散式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀

[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎

[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

[原始碼解析] PyTorch 分散式(1)------歷史和概述

[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)

[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)

[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念

[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用

[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store

[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組

[原始碼解析] PyTorch 分散式(8) -------- DistributedDataParallel之論文篇

[原始碼解析] PyTorch 分散式(9) ----- DistributedDataParallel 之初始化

[原始碼解析] PyTorch 分散式(10)------DistributedDataParallel 之 Reducer靜態架構

[原始碼解析] PyTorch 分散式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作

[原始碼解析] PyTorch 分散式(12) ----- DistributedDataParallel 之 前向傳播

[原始碼解析] PyTorch 分散式(13) ----- DistributedDataParallel 之 反向傳播

[原始碼解析] PyTorch 分散式 Autograd (1) ---- 設計

[原始碼解析] PyTorch 分散式 Autograd (2) ---- RPC基礎

[原始碼解析] PyTorch 分散式 Autograd (3) ---- 上下文相關

[原始碼解析] PyTorch 分散式 Autograd (4) ---- 如何切入引擎

[原始碼解析] PyTorch 分散式 Autograd (5) ---- 引擎(上)

[原始碼解析] PyTorch 分散式 Autograd (6) ---- 引擎(下)

[原始碼解析] PyTorch分散式優化器(1)----基石篇

[原始碼解析] PyTorch分散式優化器(2)----資料並行優化器

[原始碼解析] PyTorch分散式優化器(3)---- 模型並行

[原始碼解析] PyTorch 分散式(14) --使用 Distributed Autograd 和 Distributed Optimizer

[原始碼解析] PyTorch 分散式(15) --- 使用分散式 RPC 框架實現引數伺服器

注:本文沒有完全按照原文順序進行翻譯,而是按照自己理解的思路重新組織了文章。

0x01 前言

1.1 先決條件

本文的先決條件如下:

本教程演示瞭如何使用@rpc.functions.async_execution 裝飾器構建批處理 RPC 應用程式,這有助於通過減少被阻塞的 RPC 執行緒的數量,並且在被呼叫方整合 CUDA 操作來加快訓練速度。這與使用 TorchServer 進行批量推理的想法相同。Batch RPC 有助於將動作整合到較少的 CUDA 操作中,從而攤銷開銷。

注意:本教程需要 PyTorch v1.6.0 或更高版本。

1.2 基礎知識

之前的教程已經展示了使用torch.distributed.rpc構建分散式訓練應用程式的步驟,但他們沒有詳細說明在處理 RPC 請求時被呼叫方會發生什麼。從 PyTorch v1.5 開始,針對每個 RPC 請求,被呼叫者都會啟動一個執行緒來執行該請求中的函式,該執行緒會阻塞直到該函式返回。這適用於許多用例,但有一個問題:如果使用者函式在 IO 上阻塞,例如使用巢狀的 RPC 呼叫或訊號(例如等待不同的 RPC 請求來解除阻塞),則被呼叫者上的 RPC 執行緒將不得不空閒等待,直到 IO 完成或訊號(signal)事件發生。因此,RPC 被呼叫者使用的執行緒可能會使用比實際需要更多。造成這個問題的原因是RPC把使用者函式當成黑盒,對函式中發生的事情知之甚少。為了讓使用者函式能夠讓出和釋放 RPC 執行緒,需要向 RPC 系統提供更多的提示。

從 v1.6.0 開始,PyTorch 通過引入兩個新概念來解決這個問題:

  • torch.futures.Future 封裝了一個非同步執行,同時也支援安裝回撥函式。
  • @rpc.functions.async_execution 裝飾器,它允許應用程式告訴被呼叫者,本目標函式將返回一個future,並且可以在執行過程中多次暫停和yield。

使用這兩個工具,應用程式程式碼可以將使用者函式分解為多個較小的函式,將它們連結在一起作為Future 物件的回撥方法,並返回包含最終結果的 Future給呼叫者。在被呼叫方,在獲取Future物件時,它也會安裝後續的 RPC 響應處理作為回撥方法,這些回撥會在最終結果準備好時被觸發。這樣,被呼叫者不再需要阻塞一個執行緒,只是等待最終返回值準備好就行。 簡單的例子請參考@rpc.functions.async_execution的API文件 。

除了減少被呼叫者的空閒執行緒數量外,這些工具還使批處理 RPC 處理更容易、更快。本教程演示瞭如何使用@rpc.functions.async_execution 裝飾器構建分散式批量更新引數伺服器和批量處理強化學習應用程式 。

注:我們不考慮強化學習的領域,那樣會影響我們的思路,牽扯精力

1.3 程式碼

因為原文主要是強化學習程式碼講解,而我們只關注普通分散式批量更新引數伺服器,所以需要看原始程式碼。

程式碼位於 https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py。先全部摘錄如下:

import os
import threading
from datetime import datetime

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch import optim

import torchvision


batch_size = 20
image_w = 64
image_h = 64
num_classes = 30
batch_update_size = 5
num_batches = 6

def timed_log(text):
    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")

class BatchUpdateParameterServer(object):

    def __init__(self, batch_update_size=batch_update_size):
        self.model = torchvision.models.resnet50(num_classes=num_classes)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution
    def update_and_fetch_model(ps_rref, grads):
        self = ps_rref.local_value()
        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")
        for p, g in zip(self.model.parameters(), grads):
            p.grad += g
        with self.lock:
            self.curr_update_size += 1
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step()
                self.optimizer.zero_grad()
                fut.set_result(self.model)
                timed_log("PS updated model")
                self.future_model = torch.futures.Future()

        return fut


class Trainer(object):

    def __init__(self, ps_rref):
        self.ps_rref = ps_rref
        self.loss_fn = nn.MSELoss()
        self.one_hot_indices = torch.LongTensor(batch_size) \
                                    .random_(0, num_classes) \
                                    .view(batch_size, 1)

    def get_next_batch(self):
        for _ in range(num_batches):
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                        .scatter_(1, self.one_hot_indices, 1)
            yield inputs.cuda(), labels.cuda()

    def train(self):
        name = rpc.get_worker_info().name
        m = self.ps_rref.rpc_sync().get_model().cuda()
        for inputs, labels in self.get_next_batch():
            timed_log(f"{name} processing one batch")
            self.loss_fn(m(inputs), labels).backward()
            timed_log(f"{name} reporting grads")
            m = rpc.rpc_sync(
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            ).cuda()
            timed_log(f"{name} got updated model")


def run_trainer(ps_rref):
    trainer = Trainer(ps_rref)
    trainer.train()


def run_ps(trainers):
    timed_log("Start training")
    ps_rref = rpc.RRef(BatchUpdateParameterServer())
    futs = []
    for trainer in trainers:
        futs.append(
            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,))
        )

    torch.futures.wait_all(futs)
    timed_log("Finish training")


def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options=rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        rpc_timeout=0  # infinite timeout
     )
    if rank != 0:
        rpc.init_rpc(
            f"trainer{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # trainer passively waiting for ps to kick off training iterations
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_ps([f"trainer{r}" for r in range(1, world_size)])

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = batch_update_size + 1
    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)

0x02 啟動

我們首先看看如何啟動。

2.1 總體啟動

我們假設有一個master(rank 0),一個worker。Master 之上執行的是引數伺服器,worker 之上是訓練程式碼。

def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options=rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        rpc_timeout=0  # infinite timeout
     )
    if rank != 0:
        rpc.init_rpc( # 訓練程式碼
            f"trainer{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # trainer passively waiting for ps to kick off training iterations
    else:
        rpc.init_rpc( # 引數伺服器
            "ps", 
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_ps([f"trainer{r}" for r in range(1, world_size)])

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = batch_update_size + 1
    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)

邏輯如下圖:

             torch.multiprocessing.spawn
                        +
                        |
                        |
           +------------+-------------------------------------------------
           |                                                             |
           |                                                             |
           v                                                             v
+----------+----------------------------------------------+ +------------+----------------+
| "ps"                                           rank = 0 | | f"trainer{rank}"   rank = 1 |
|                                                         | |                             |
|                                                         | |                             |
|                     rpc.init_rpc                        | |         rpc.init_rpc        |
|                                                         | |                             |
|                                                         | |                             |
|  run_ps([f"trainer{r}" for r in range(1, world_size)])  | |                             |
|                                                         | |                             |
|                                                         | |                             |
+---------------------------------------------------------+ +-----------------------------+

2.2 啟動引數伺服器

run_ps 啟動了引數伺服器和trainer。注意,這裡在引數伺服器之中啟動 trainer,即,master 不僅僅有一個引數伺服器,還負責通過 rpc 來驅動trainer上的訓練迴圈。

def run_ps(trainers):
    timed_log("Start training")
    ps_rref = rpc.RRef(BatchUpdateParameterServer())
    futs = []
    for trainer in trainers: # trainer 是字串,比如"trainer1"
        futs.append(
            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) # 執行run_trainer
        )

    torch.futures.wait_all(futs)
    timed_log("Finish training")
    
def run_trainer(ps_rref):
    trainer = Trainer(ps_rref)
    trainer.train() # 呼叫 Trainer 的方法   

具體擴充如下:

這裡沒有給出引數伺服器和trainer的邏輯,我們會在後續分析之後陸續給出。trainer 也只給出了一個。

0x03 引數伺服器

上面圖中沒有給出具體引數伺服器程式碼,我們接下來就分析一下。

這裡考慮具有一個引數伺服器 (PS) 和多個trainer的同步訓練應用程式。在這個應用中,PS 持有引數並等待所有訓練器報告梯度。在每次迭代中,它等待直到從所有訓練器接收梯度,然後一次性更新所有引數。

下面的程式碼顯示了 PS 類的實現。

  • PS初始化時候生成了常規SGB優化器,不是分散式優化器,而且優化器是在PS之上
  • update_and_fetch_model方法被 @rpc.functions.async_execution所裝飾,將由trainer呼叫。
  • 每次呼叫都會返回一個Future物件,該物件將被用來處理更新後的模型。
  • 大多數訓練器發起的呼叫只是累積梯度到 .grad成員變數 ,然後立即返回,並在 PS 上產生 RPC 執行緒。
  • 最後到達的訓練器將觸發優化器步驟並消耗所有先前上報的梯度。然後它使用更新後的模型來設定future_model,這是依靠通過Future物件來依次通知來自其他訓練者的先前請求,並將更新後的模型傳送給所有訓練者。

具體程式碼如下:

batch_size = 20
image_w = 64
image_h = 64
num_classes = 30
batch_update_size = 5
num_batches = 6

def timed_log(text):
    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")

class BatchUpdateParameterServer(object):

    def __init__(self, batch_update_size=batch_update_size):
        self.model = torchvision.models.resnet50(num_classes=num_classes)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        # 重點:這裡是常規SGB優化器,不是分散式優化器
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution # trainer會直接呼叫
    def update_and_fetch_model(ps_rref, grads):
        self = ps_rref.local_value()
        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")
        for p, g in zip(self.model.parameters(), grads): # 得到
            p.grad += g # 累積梯度
        with self.lock:
            self.curr_update_size += 1
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                # 最後到達的訓練器將觸發優化器步驟並消耗所有先前上報的梯度。
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step() # 更新模型
                self.optimizer.zero_grad()
                fut.set_result(self.model) # 將更新後的模型傳送給所有訓練者
                timed_log("PS updated model")
                self.future_model = torch.futures.Future() # 使用更新後的模型來設定future_model

        return fut # 該物件將被用來處理更新後的模型

邏輯擴充如下,這裡省略了引數伺服器生成trainer的步驟:

手機如下:

0x04 Trainer

對於訓練器,它們都使用來自 PS 的相同引數集進行初始化。在每次迭代中執行如下操作:

  • 每個訓練器首先執行前向和後向傳播以在本地生成梯度。
  • 然後,每個訓練器使用 RPC 向 PS 報告其梯度,並通過同一 RPC 請求的返回值取回更新後的引數。

在訓練器的實現中,目標函式是否被標記 @rpc.functions.async_execution是沒有區別的。訓練器只需使用 rpc_sync 呼叫update_and_fetch_model,其將阻塞訓練器,直到返回更新的模型。

可以看到,引數伺服器儲存模型,模型可以返回到trainer。

class Trainer(object):

    def __init__(self, ps_rref):
        self.ps_rref = ps_rref
        self.loss_fn = nn.MSELoss()
        self.one_hot_indices = torch.LongTensor(batch_size) \
                                    .random_(0, num_classes) \
                                    .view(batch_size, 1)

    def get_next_batch(self):
        for _ in range(num_batches):
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                        .scatter_(1, self.one_hot_indices, 1)
            yield inputs.cuda(), labels.cuda()

    def train(self):
        name = rpc.get_worker_info().name
        # 從引數伺服器獲取模型
        m = self.ps_rref.rpc_sync().get_model().cuda()
        for inputs, labels in self.get_next_batch():
            timed_log(f"{name} processing one batch")
            # 利用模型來前向傳播/反向傳播
            self.loss_fn(m(inputs), labels).backward()
            timed_log(f"{name} reporting grads")
            # 呼叫引數伺服器的函式來提交梯度
            m = rpc.rpc_sync( # rpc_sync 操作完成之後,m就是最新模型了
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            ).cuda()
            timed_log(f"{name} got updated model")

擴充邏輯如下:

  1. 引數伺服器的run_trainer 方法會直接呼叫 trainer.train() 方法來執行一步step。
  2. train 方法之中,會呼叫 self.ps_rref.rpc_sync().get_model().cuda() 從引數伺服器獲得模型,放到本地裝置之上(圖上是雙向箭頭,表示這是一個get/return動作,需要把模型儲存在worker本地)。
  3. 呼叫 self.loss_fn(m(inputs), labels).backward() 來進行前向傳播/反向傳播。
  4. 呼叫引數伺服器的 update_and_fetch_model 函式來提交梯度,這裡使用了非同步RPC
  5. 引數伺服器的 update_and_fetch_model 之中,進行梯度累積,模型更新是通過PS之上常規SGD優化器完成,最後呼叫 fut.set_result(self.model) 來發布新模型給trainer。在trainer 之中,就是 m = rpc.rpc_sync(...) 這個賦值之後,m 是最新模型了。

0x05 對比

前文結尾,我們對比引數伺服器的經典實現 ps-lite 和 前兩篇實現的引數伺服器。

  • ps-lite 是類似傳統伺服器實現,有自己主動的業務迴圈,可以響應使用者的顯式請求,也有自己明確的邏輯,本地也有自己的KV儲存。
  • PyTorch 前兩篇官方文件(本系列前兩篇文章)之中,引數伺服器則是另外一種思路:
    • 引數伺服器上沒有主動的迴圈,沒有KV儲存,沒有伺服器邏輯,而是可以直接儲存業務模型,ps 會把業務模型需要優化的引數返回給trainer 之上的 DistributedOptimizer。
    • 業務驅動由trainer完成:train loop程式碼在trainer 之中,DistributedOptimizer 在trainer 之中,DistributedOptimizer 負責進行分散式優化。
  • 本文又與上面不同,看起來更像是ps-lite,但是又糅合了RPC實現:
    • ps程式會啟動trainer的訓練迴圈
    • 每個迭代之中,trainer 會從引數伺服器獲取最新模型,前向操作/後向傳播都在trainer 完成。
    • trainer 會通過非同步RPC把梯度提交給引數伺服器。
    • 模型更新是通過PS之上常規SGD優化器完成
    • 模型更新之後通過非同步RPC把模型再次分發給trainer。

不得不說,官方這幾篇文章快把各種實現方式玩出花來了,大家可以依據自己業務特點來參考實現。

0xFF 參考

IMPLEMENTING BATCH RPC PROCESSING USING ASYNCHRONOUS EXECUTIONS

相關文章