[原始碼解析] PyTorch 分散式(15) --- 使用分散式 RPC 框架實現引數伺服器

羅西的思考發表於2021-12-14

[原始碼解析] PyTorch 分散式(15) --- 使用分散式 RPC 框架實現引數伺服器

0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分散式的基本模組,接下來我們通過幾篇文章來看看如何把這些模組應用到實踐之中,順便把PyTorch分散式邏輯整體梳理一下。本文介紹如何使用分散式 RPC 框架實現引數伺服器。

本文以 https://pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html 為基礎,加入了自己的理解。

PyTorch分散式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀

[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎

[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

[原始碼解析] PyTorch 分散式(1)------歷史和概述

[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)

[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)

[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念

[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用

[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store

[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組

[原始碼解析] PyTorch 分散式(8) -------- DistributedDataParallel之論文篇

[原始碼解析] PyTorch 分散式(9) ----- DistributedDataParallel 之初始化

[原始碼解析] PyTorch 分散式(10)------DistributedDataParallel 之 Reducer靜態架構

[原始碼解析] PyTorch 分散式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作

[原始碼解析] PyTorch 分散式(12) ----- DistributedDataParallel 之 前向傳播

[原始碼解析] PyTorch 分散式(13) ----- DistributedDataParallel 之 反向傳播

[原始碼解析] PyTorch 分散式 Autograd (1) ---- 設計

[原始碼解析] PyTorch 分散式 Autograd (2) ---- RPC基礎

[原始碼解析] PyTorch 分散式 Autograd (3) ---- 上下文相關

[原始碼解析] PyTorch 分散式 Autograd (4) ---- 如何切入引擎

[原始碼解析] PyTorch 分散式 Autograd (5) ---- 引擎(上)

[原始碼解析] PyTorch 分散式 Autograd (6) ---- 引擎(下)

[原始碼解析] PyTorch分散式優化器(1)----基石篇

[原始碼解析] PyTorch分散式優化器(2)----資料並行優化器

[原始碼解析] PyTorch分散式優化器(3)---- 模型並行

[原始碼解析] PyTorch 分散式(14) --使用 Distributed Autograd 和 Distributed Optimizer

注:本文沒有完全按照原文順序進行翻譯,而是按照自己理解的思路重新組織了文章。

0x01 綜述

本教程介紹了一個使用 PyTorch 的分散式 RPC 框架實現引數伺服器的簡單示例。引數伺服器框架是一種正規化,其中包括一組用來儲存引數(例如大型嵌入表)的伺服器,多個訓練器查詢引數伺服器以檢索最新的引數。這些訓練器可以在本地執行一個訓練迴圈,間或與引數伺服器同步以獲得最新的引數。有關引數伺服器方法的更多資訊,請檢視https://www.cs.cmu.edu/~muli/file/parameter_server_osdi14.pdf。

我們將使用分散式 RPC 框架構建一個示例,其中多個trainer使用 RPC 與同一個引數伺服器進行通訊,並使用RRef訪問遠端引數伺服器例項上的狀態。每個trainer將通過使用分散式 autograd 跨多個節點拼接了一個 autograd 計算圖,並且以分散式方式啟動每個trainer自己的反向傳播。

注意:本教程介紹了分散式 RPC 框架的使用,該框架可用於將模型拆分到多臺機器上,或用於實現引數伺服器訓練策略(trainer獲取託管在一個不同機器上的引數)。如果您正在尋找跨多個 GPU 複製模型進行資料並行訓練,請參閱分散式資料並行教程

0x02 基礎網路

我們首先要介紹一下基礎網路。讓我們從熟悉的開始:匯入我們所需的模組並定義一個簡單的 ConvNet,它將在 MNIST 資料集上進行訓練。下面的網路主要來自pytorch/examples repo 中定義的網路。

# --------- MNIST Network to train, from pytorch/examples -----
class Net(nn.Module):
    def __init__(self, num_gpus=0):
        super(Net, self).__init__()
        print(f"Using {num_gpus} GPUs to train")
        self.num_gpus = num_gpus
        device = torch.device(
            "cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu")
        print(f"Putting first 2 convs on {str(device)}")
        # Put conv layers on the first cuda device
        self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
        self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)
        # Put rest of the network on the 2nd cuda device, if there is one
        if "cuda" in str(device) and num_gpus > 1:
            device = torch.device("cuda:1")

        print(f"Putting rest of layers on {str(device)}")
        self.dropout1 = nn.Dropout2d(0.25).to(device)
        self.dropout2 = nn.Dropout2d(0.5).to(device)
        self.fc1 = nn.Linear(9216, 128).to(device)
        self.fc2 = nn.Linear(128, 10).to(device)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)

        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        # Move tensor to next device if necessary
        next_device = next(self.fc1.parameters()).device
        x = x.to(next_device)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

0x03 輔助函式

接下來,讓我們定義一些對我們指令碼的其餘部分有用的輔助函式。下面使用rpc_syncRRef來定義一個函式,該函式呼叫位於遠端節點上的物件上的給定方法。我們由給定的rref引數生成遠端物件的控制程式碼,這樣我們可以在擁有它的節點(rref.owner())上執行這個遠端物件。在呼叫者節點上,我們通過使用 rpc_sync來同步執行此命令,這意味著我們將阻塞直到收到響應。

# --------- Helper Methods --------------------

# On the local node, call a method with first arg as the value held by the
# RRef. Other args are passed in as arguments to the function called.
# Useful for calling instance methods.
def call_method(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)

# Given an RRef, return the result of calling the passed in method on the value
# held by the RRef. This call is done on the remote node that owns
# the RRef. args and kwargs are passed into the method.
# Example: If the value held by the RRef is of type Foo, then
# remote_method(Foo.bar, rref, arg1, arg2) is equivalent to calling
# <foo_instance>.bar(arg1, arg2) on the remote node and getting the result
# back.
def remote_method(method, rref, *args, **kwargs):
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)

0x04 啟動

4.1 啟動方式

要在本地執行此用例,需要在單獨的終端視窗執行如下命令來啟動master和worker:python rpc_parameter_server.py --world_size=WORLD_SIZE --rank=RANK

  • 對於 world size 大小為 2 的master節點,命令是:python rpc_parameter_server.py --world_size=2 --rank=0

  • 對於trainer,命令是:python rpc_parameter_server.py --world_size=2 --rank=1

請注意,本教程假設使用 0 到 2 個 GPU 進行訓練,可以通過傳遞--num_gpus=N到訓練指令碼來配置此引數。當trainer和master在不同機器上執行時,您可以傳入命令列引數--master_addr=ADDRESS--master_port=PORT來標明master worker 正在偵聽的地址和埠。

4.2 啟動指令碼

首先,我們來看看啟動引數伺服器和訓練器所需要的各種引數。

  • world_size對應於將參與訓練的節點總數,是所有訓練器和引數伺服器的總和。
  • 我們還必須為每個單獨的程式傳遞一個唯一的rank,該值從 0(在其中將執行一個引數伺服器)到world_size - 1
  • master_addrmaster_port被用來標示 rank 0 程式,其將被各個節點用於發現彼此。
  • 要在本地測試此示例,只需傳入localhost和相同的master_port到所有的例項即可。

請注意,出於演示目的,此示例僅支援 0-2 個 GPU,但可以擴充套件該模式以使用其他 GPU。

# --------- Launcher --------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Parameter-Server RPC based training")
    parser.add_argument(
        "--world_size",
        type=int,
        default=4,
        help="""Total number of participating processes. Should be the sum of
        master node and all training nodes.""")
    parser.add_argument(
        "--rank",
        type=int,
        default=None,
        help="Global rank of this process. Pass in 0 for master.")
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=0,
        help="""Number of GPUs to use for training, currently supports between 0
         and 2 GPUs. Note that this argument will be passed to the parameter servers.""")
    parser.add_argument(
        "--master_addr",
        type=str,
        default="localhost",
        help="""Address of master, will default to localhost if not provided.
        Master must be able to accept network traffic on the address + port.""")
    parser.add_argument(
        "--master_port",
        type=str,
        default="29500",
        help="""Port that master is listening on, will default to 29500 if not
        provided. Master must be able to accept network traffic on the host and port.""")

    args = parser.parse_args()
    assert args.rank is not None, "must provide rank argument."
    assert args.num_gpus <= 3, f"Only 0-2 GPUs currently supported (got {args.num_gpus})."
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

現在,我們將根據命令列引數建立一個引數伺服器程式或者或訓練器程式。如果傳入的rank為 0,我們將建立一個 ParameterServer,否則建立一個 TrainerNet

請注意,我們使用torch.multiprocessing啟動與我們要執行的函式相對應的子程式,並在主執行緒使用p.join() 來等待程式結束。我們也使用PyTorch dataloaders 來生成從MNIST資料集載入資料的訓練 data loaders 和測試 data loaders。

processes = []
world_size = args.world_size
if args.rank == 0:
    # 這裡是引數伺服器
    p = mp.Process(target=run_parameter_server, args=(0, world_size))
    p.start()
    processes.append(p)
else:
    # 這裡是trainer
    # Get data to train on
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=32, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=32, shuffle=True)
    # start training worker on this node
    p = mp.Process(
        target=run_worker, # 啟動trainer
        args=(
            args.rank,
            world_size, args.num_gpus,
            train_loader,
            test_loader))
    p.start()
    processes.append(p)

for p in processes:
    p.join()

目前邏輯如下,我們假設一個master,一個 worker,其將會在自己程式之中執行不同的程式碼:

rpc_parameter_server.py      Master      +    Worker         rpc_parameter_server.py
           +                             |                          +
           |                             |                          |
           | rank == 0                   |                          |
           |                             |                          |
           v                             |                          v
                                         |
run_parameter_server                     |                    mp.Process(run_worker)
                                         |
                                         |
                                         |
                                         +

4.3 啟動引數伺服器

首先,我們將初始化我們的引數伺服器。請注意,所有程式中只有一個引數伺服器例項,所有訓練器都將與同一個引數伺服器對話並更新同一個儲存模型。如 run_parameter_server中所見,伺服器本身不執行任何獨立行動,它只是等待來自訓練器的請求,並通過執行請求的函式來響應它們。

程式碼中主要兩步:為引數伺服器初始化rpc 和 rpc.shutdown()。注意,這裡沒有真正初始化引數伺服器

注意,rpc.shutdown()不會立即關閉引數伺服器。相反,它將等待所有worker(在本例中為trainer)也呼叫rpc.shutdown()。這保證了在所有 trainer完成訓練過程之前引數伺服器不會當機。

def run_parameter_server(rank, world_size):
    # The parameter server just acts as a host for the model and responds to
    # requests from trainers, hence it does not need to run a loop.
    # rpc.shutdown() will wait for all workers to complete by default, which
    # in this case means that the parameter server will wait for all trainers
    # to complete, and then exit.
    print("PS master initializing RPC")
    rpc.init_rpc(name="parameter_server", rank=rank, world_size=world_size)
    print("RPC initialized! Running parameter server...")
    rpc.shutdown() # 保證不會當機
    print("RPC shutdown on parameter server.")

邏輯擴充為:

          rpc_parameter_server.py      Master      +    Worker    rpc_parameter_server.py
                     +                             |                     +
                     |                             |                     |
                     | rank == 0                   |                     |
                     |                             |                     |
                     v                             |                     v
                                                   |
          run_parameter_server                     |               mp.Process(run_worker)
                     +                             |
                     |                             |
                     |                             |
                     v                             |
                                                   |
rpc.init_rpc("parameter_server", rank, world_size) |
                                                   |
                                                   +

4.4 啟動worker

run_worker 函式其內部邏輯也是啟動 rpc,然後進入了主迴圈。

# Main loop for trainers.
def run_worker(rank, world_size, num_gpus, train_loader, test_loader):
    print(f"Worker rank {rank} initializing RPC")
    rpc.init_rpc(
        name=f"trainer_{rank}",
        rank=rank,
        world_size=world_size)

    print(f"Worker {rank} done initializing RPC")

    run_training_loop(rank, num_gpus, train_loader, test_loader)
    rpc.shutdown()

邏輯擴充為:

 rpc_parameter_server.py      Master      +    Worker    rpc_parameter_server.py
            +                             |                     +
            |                             |                     |
            | rank == 0                   |                     |
            |                             |                     |
            v                             |                     v
                                          |
 run_parameter_server                     |               mp.Process(run_worker)
            +                             |                     +
            |                             |                     |
            |                             |                     |
            v                             |                     |
                                          |                     v
rpc.init_rpc("parameter_server"           |               run_worker
             ,rank,world_size)            |                     +
                                          |                     |
                                          |                     |
                                          |                     |
                                          |                     v
                                          |       rpc.init_rpc(f"trainer_{rank}"
                                          |                    ,rank,world_size)
                                          |                     +
                                          |                     |
                                          |                     |
                                          |                     v
                                          +             run_training_loop

4.5 建立引數伺服器

但是目前程式碼沒有建立引數伺服器,get_parameter_server 這裡才是構建引數伺服器的內容。

param_server = None
global_lock = Lock()

def get_parameter_server(num_gpus=0):
    global param_server
    # Ensure that we get only one handle to the ParameterServer.
    with global_lock:
        if not param_server:
            # construct it once
            param_server = ParameterServer(num_gpus=num_gpus)
        return param_server

究竟什麼地方建立了引數伺服器?其實是在 worker 主迴圈之中,TrainerNet 初始化生成的,後續會講到。

0x05 TrainerNet

接下來,我們將定義我們的TrainerNet類。這也是nn.Module 的子類,我們的__init__方法將使用rpc.remoteAPI 來獲取到我們的引數伺服器的 RRef 或遠端引用。請注意,這裡我們沒有將引數伺服器複製到我們的本地程式,相反,我們可以將self.param_server_rref視為指向位於另一個獨立程式上的引數伺服器的分散式共享指標

注意,TrainerNet是為了訓練用,不是業務函式,和前面的 class Net(nn.Module) 一定要區分清楚。TrainerNet 只是為了整體程式碼邏輯需要而實現的一個中轉站或者adapter,TrainNet 擁有一個指向 ParameterServer 的 param_server_rref,所以通過 TrainerNet 就能使用 ParameterServer。

5.1 總體程式碼

# --------- Trainers --------------------

# nn.Module corresponding to the network trained by this trainer. The
# forward() method simply invokes the network on the given parameter
# server.
class TrainerNet(nn.Module):
    def __init__(self, num_gpus=0):
        super().__init__()
        self.num_gpus = num_gpus
        self.param_server_rref = rpc.remote(
            "parameter_server", get_parameter_server, args=(num_gpus,)) # 生成引數伺服器

    def get_global_param_rrefs(self): # 此函式後續會用到
        remote_params = remote_method(
            ParameterServer.get_param_rrefs,
            self.param_server_rref)
        return remote_params

    def forward(self, x):
        model_output = remote_method(
            ParameterServer.forward, self.param_server_rref, x)
        return model_output

5.2 生成引數伺服器

初始化方法之中通過如下程式碼生成引數伺服器。

def __init__(self, num_gpus=0):
    super().__init__()
    self.num_gpus = num_gpus
    self.param_server_rref = rpc.remote(
        "parameter_server", get_parameter_server, args=(num_gpus,))

就呼叫到了之前提到的 get_parameter_server,正式構建了引數伺服器,注意,這裡是在 worker 呼叫 get_parameter_server,但是 get_parameter_server 在 master 之上執行,在 master 之上建立引數伺服器。

此時邏輯擴充如下:

 rpc_parameter_server.py      Master      +    Worker    rpc_parameter_server.py
            +                             |                     +
            |                             |                     |
            | rank == 0                   |                     |
            |                             |                     |
            v                             |                     v
                                          |
 run_parameter_server                     |               mp.Process(run_worker)
            +                             |                     +
            |                             |                     |
            |                             |                     |
            v                             |                     |
                                          |                     v
rpc.init_rpc("parameter_server"           |               run_worker
             ,rank,world_size)            |                     +
            +                             |                     |
            |                             |                     |
            |                             |                     |
            |                             |                     v
            |                             |       rpc.init_rpc(f"trainer_{rank}"
            |                             |                    ,rank,world_size)
            |                             |                     +
            |                             |                     |
            |                             |                     |
            |                             |                     v
            |                             |             run_training_loop
            |                             |                     +
            |                             |                     |
            |                             |                     |
            |                             |                     v
            |                             |     net = TrainerNet(num_gpus=num_gpus)
            |                             |                     +
            |                             |                     |
            |                             |                     |
            v                             |                     v
    get_parameter_server  <-------------------+ self.param_server_rref = rpc.remote(
            +                             |                 "parameter_server",
            |                             |                 get_parameter_server,
            |                             |                 args=(num_gpus,))
            v                             |
      ParameterServer                     |
                                          +

到目前為止,我們已經完成了初始化階段,下面就看看具體執行階段。

5.3 建立rref

在 TrainerNet 之中,還有一個get_global_param_rrefs 方法, 我們隨後會介紹如何使用。但是這裡會先分析一下這個方法。

為何要提供這個方法?我們有必要通讀DistributedOptimizer的文件,特別是 API 簽名。我們優化器需要優化一些遠端引數,但是如何在本地構建優化器時候傳入這些引數?這些遠端引數在本地就是 RRef,所以在本地構建優化器時候,我們就傳遞這個代表遠端引數的RRef列表。

由於與給定TrainerNet互動的唯一遠端節點是 ParameterServer,所以我們只需要在 ParameterServer上呼叫 一個 remote_method 。於是,我們使用在ParameterServer類中定義的get_param_rrefs方法。這個方法會返回一個RRefs的列表,這個列表指向需要優化的引數

請注意,在這種情況下,我們TrainerNet沒有定義自己的引數;如果定義了它自己的引數(需要優化),我們也需要將每個引數包裝在一個RRef之中,並將其包含在 DistributedOptimizer 輸入中。

class TrainerNet(nn.Module):
...
    def get_global_param_rrefs(self):
        remote_params = remote_method(
            ParameterServer.get_param_rrefs,
            self.param_server_rref)
        return remote_params

5.4 前向函式

我們再看看forward方法,該方法將呼叫(同步)RPC 來執行定義在ParameterServer之上的模型網路的前向傳播。我們傳入了self.param_server_rref,它是ParameterServer的一個遠端handle,用來執行RPC。呼叫forward將向正在執行ParameterServer的節點傳送一個 RPC ,以呼叫引數伺服器的forward函式,並返回對應於模型輸出的結果Tensor

class TrainerNet(nn.Module):
...
    def forward(self, x):
        model_output = remote_method(
            ParameterServer.forward, self.param_server_rref, x)
        return model_output

0x06 引數伺服器

我們接下來看看引數伺服器。

6.1 總體程式碼

引數伺服器總體程式碼如下:

# --------- Parameter Server --------------------
class ParameterServer(nn.Module):
    def __init__(self, num_gpus=0):
        super().__init__()
        model = Net(num_gpus=num_gpus)
        self.model = model
        self.input_device = torch.device(
            "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")

    def forward(self, inp):
        inp = inp.to(self.input_device)
        out = self.model(inp)
        # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
        # Tensors must be moved in and out of GPU memory due to this.
        out = out.to("cpu")
        return out

    # Use dist autograd to retrieve gradients accumulated for this model.
    # Primarily used for verification.
    def get_dist_gradients(self, cid):
        grads = dist_autograd.get_gradients(cid)
        # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
        # Tensors must be moved in and out of GPU memory due to this.
        cpu_grads = {}
        for k, v in grads.items():
            k_cpu, v_cpu = k.to("cpu"), v.to("cpu")
            cpu_grads[k_cpu] = v_cpu
        return cpu_grads

    # Wrap local parameters in a RRef. Needed for building the
    # DistributedOptimizer which optimizes parameters remotely.
    def get_param_rrefs(self):
        param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
        return param_rrefs

6.2 初始化

前面提到了,引數伺服器的正式初始化是在 TrainerNet 之中完成的,我們接下來看看如何初始化。

引數伺服器是nn.Module的派生類,其儲存上面定義的模型網路的控制程式碼。這裡就是用前面介紹的業務模型網路 class Net(nn.Module) 來生成了內部的 model 成員變數。 引數伺服器後續就是用這個來進行具體引數處理。我們還將儲存一個輸入裝置,在呼叫模型之前,我們的輸入需要傳輸到這個裝置之上。

# --------- Parameter Server --------------------
class ParameterServer(nn.Module):
    def __init__(self, num_gpus=0):
        super().__init__()
        model = Net(num_gpus=num_gpus) # 我們的業務網路模型
        self.model = model # self.model就是handle
        self.input_device = torch.device( # 輸入裝置
            "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")

6.3 前向函式

接下來,我們將定義前向傳播函式。請注意,無論模型輸出的裝置如何,我們都將輸出移動到 CPU,因為分散式 RPC 框架目前僅支援通過 RPC 傳送 CPU 張量。我們刻意禁止通過 RPC 傳送 CUDA 張量,因為呼叫方/被呼叫方可能使用不同的裝置(CPU/GPU),可能會在未來版本中支援 RPC 傳送 CUDA 張量。

class ParameterServer(nn.Module):
...
    def forward(self, inp):
        inp = inp.to(self.input_device)
        out = self.model(inp)
        # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
        # Tensors must be moved in and out of GPU memory due to this.
        out = out.to("cpu")
        return out

6.4 雜項函式

接下來,我們將定義一些對訓練和驗證有用的雜項函式。

  • get_dist_gradients 將接收一個分散式 Autograd 上下文 ID 並呼叫dist_autograd.get_gradients以檢索由分散式 autograd 計算的梯度。更多資訊可以在分散式 autograd 文件中找到。請注意,我們還遍歷生成的字典並將每個張量轉換為 CPU 張量,因為該框架目前僅支援通過 RPC 傳送張量。
  • get_param_rrefs將遍歷我們的模型引數並將它們包裝為(本地)RRef。該方法將由 trainer 節點通過 RPC 呼叫,並將返回要分散式優化的引數列表。這引數列表將作為分散式優化器的輸入,因此,引數伺服器必須把必須優化的所有引數轉換為RRefs列表。對應程式碼就是獲取 Net 的引數,最終返回給 worker 端的 DistributedOptimizer。
# Use dist autograd to retrieve gradients accumulated for this model.
# Primarily used for verification.
def get_dist_gradients(self, cid):
    grads = dist_autograd.get_gradients(cid)
    # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
    # Tensors must be moved in and out of GPU memory due to this.
    cpu_grads = {}
    for k, v in grads.items():
        k_cpu, v_cpu = k.to("cpu"), v.to("cpu")
        cpu_grads[k_cpu] = v_cpu
    return cpu_grads

# Wrap local parameters in a RRef. Needed for building the
# DistributedOptimizer which optimizes paramters remotely.
def get_param_rrefs(self):
    param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
    return param_rrefs

6.5 邏輯關係

我們需要一個邏輯關係圖來梳理一下:

  1. 生成 DistributedOptimizer 的時候,呼叫 TrainerNet 的 get_global_param_rrefs 方法來獲取需要分散式優化的引數。
  2. TrainerNet 呼叫 ParameterServer 的 get_param_rrefs 方法來取引數伺服器獲取。
  3. ParameterServer 呼叫 Net 的 parameters() 方法獲取最終需要的引數。
  4. 這些引數原路返回,最終給了 DistributedOptimizer,DistributedOptimizer 以後就是優化這些引數。
                        Master             +      Worker
                                           |
                                           |
+--------------------+                     |   +----------------------------------------+
| ParameterServer    |                     |   | run_training_loop                      |
|                    |           4         |   |     +-------------------------+        |
|                    | +-------------------------->  | TrainerNet              |        |
|                    |                     |   |     |                         |        |
|                    | <---------------------------+ |                         |        |
|        model       |   2 get_param_rrefs |   |     |                         |        |
|        ^   +       |                     |   |     |                         |        |
|        |   |       |                     |   |     |                         |        |
|        |   |       |                     |   |     +---+----+----------------+        |
+--------------------+                     |   |         |    ^                         |
         |   |                             |   |         |    |                         |
         |   |                             |   |       4 |  1 | get_global_param_rrefs  |
       4 | 3 | model.parameters()          |   |         |    |                         |
         |   |                             |   |         v    |                         |
         |   v                             |   |     +---+----+----------------+        |
+--------+---+-------+                     |   |     | DistributedOptimizer    |        |
| Net                |                     |   |     |                         |        |
|                    |                     |   |     |                         |        |
|                    |                     |   |     |                         |        |
|                    |                     |   |     +-------------------------+        |
+--------------------+                     |   +----------------------------------------+
                                           +

0x07 worker 主迴圈

現在,初始化完畢,引數伺服器也分析完畢,我們接下來看看 worker 主迴圈,它將建立我們的網路和優化器,通過網路執行一些輸入並計算損失。訓練迴圈看起來很像本地訓練程式,但由於我們的模型網路是分散式的,所以做了一些修改。

7.1 總體程式碼

在主迴圈之中,我們初始化"TrainerNet"並構建"DistributedOptimizer"。

請注意,如上所述,我們必須傳入所有的需要優化的全域性(跨參與分散式訓練的所有節點)引數。此外,我們傳入要使用的本地優化器,在本例中為SGD。另外,我們可以使用與建立本地優化器相同的方式來配置底層優化器演算法。例如,我們可以傳入一個自定義學習率,其將用作所有本地優化器的學習率。

worker 主迴圈 run_training_loop 程式碼如下,其中 model_output = net(data) 會呼叫 TrainerNet 的forward 方法。

def run_training_loop(rank, num_gpus, train_loader, test_loader):
    # Runs the typical neural network forward + backward + optimizer step, but
    # in a distributed fashion.
    net = TrainerNet(num_gpus=num_gpus)
    # Build DistributedOptimizer.
    param_rrefs = net.get_global_param_rrefs()
    opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
    for i, (data, target) in enumerate(train_loader):
        with dist_autograd.context() as cid:
            # 1. 呼叫 TrainerNet 的 forward
            model_output = net(data)
            target = target.to(model_output.device)
            # 2. 計算損失
            loss = F.nll_loss(model_output, target)
            if i % 5 == 0:
                print(f"Rank {rank} training batch {i} loss {loss.item()}")
            # 3. 反向傳播    
            dist_autograd.backward(cid, [loss])
            # Ensure that dist autograd ran successfully and gradients were
            # returned.
            # 4. 進行驗證
            assert remote_method(
                ParameterServer.get_dist_gradients,
                net.param_server_rref,
                cid) != {}
            opt.step(cid)

    print("Training complete!")
    print("Getting accuracy....")
    get_accuracy(test_loader, net)

7.2 訓練

我們需要對 run_training_loop 之中具體的訓練程式碼再做一下分析,就是主訓練迴圈。

我們通過PyTorch的DataLoader給出的iterables進行迴圈。在編寫典型的前向/後向/優化器迴圈之前,我們首先將邏輯包裝在Distributed Autograd context之中, 請注意,需要記錄在模型的前向傳播中呼叫的RPC,以便可以構造一個適當的計算圖,其包括後向傳播中所有參與的分散式工作程式。Distributed Autograd context 將返回一個"context_id",該id是一個識別符號,用來標示與特定迭代對應的梯度累積和優化。

與呼叫典型的loss.backward()來啟動本地工作程式的向後傳播不同,我們在上下文中呼叫dist_autograd.backward() ,呼叫時傳入 loss 和 context_id,這是我們希望向後傳播開始的根。此外,我們將這個"context_id"傳遞到優化器呼叫中,因為優化器呼叫需要能夠在所有節點上查詢這個特定向後傳遞計算的相應梯度。

for i, (data, target) in enumerate(train_loader):
    with dist_autograd.context() as cid:
        # 1. 呼叫 TrainerNet 的 forward
        model_output = net(data)
        target = target.to(model_output.device)
        # 2. 計算損失
        loss = F.nll_loss(model_output, target)
        if i % 5 == 0:
            print(f"Rank {rank} training batch {i} loss {loss.item()}")
        # 3. 反向傳播    
        dist_autograd.backward(cid, [loss])
        # Ensure that dist autograd ran successfully and gradients were
        # returned.
        # 4. 進行驗證
        assert remote_method(
            ParameterServer.get_dist_gradients,
            net.param_server_rref,
            cid) != {}
        # 5. 更新
        opt.step(cid)

我們擴充前面的邏輯(忽略初始化部分),下面圖上的數值對應訓練程式碼註釋中的數值

目前總體思路如下:

  • Master 之上執行了 ParameterServer,其包含了Net 這個業務模型。
  • Worker 之上執行了 trainer,其包含了 TrainerNet,TrainerNet 只是為了整體程式碼邏輯需要而實現的一箇中轉站或者adapter,TrainNet 擁有一個指向 ParameterServer 的 param_server_rref,所以通過 TrainerNet 就能使用 ParameterServer。
  • 具體 Forward 操作流程是: TrainerNet.forward --> ParameterServer.forward ---> Net.forward()。
  • 具體 Backward 是被dist_autograd.backward 呼叫 dist.autograd 引擎自動完成的。
  • 優化器更新則是DistributedOptimizer自動完成。

我們擴充前面的邏輯(忽略初始化部分)圖如下,下面圖上的數值對應訓練程式碼註釋中的數值

             Master      +    Worker
                         |
  ParameterServer        |        run_training_loop                  TrainerNet
        +                |                +                              +
        |                |                |                              |
        |                |                v                              |
        |                |          net = TrainerNet                     |
        |                |                +                              |
        |                |                |                              |
        +                |                |                              |
model = Net(num_gpus)    |                v                              |
        +                |     param_server_rref = rpc.remote(           |
        |                |          "parameter_server",                  |
        |                |         get_parameter_server,)                |
        |                |                +                              |
        |                |                |                              |
        |                |                |                              |
        |                |    opt = DistributedOptimizer(param_rrefs)    |
        |                |                +                              |
        |                |                |                              |
        |                |                |                              |
        |                |                v                              |
        |                |     model_output = net(data)                  |
        |                |                +                              |
        |                |                |            1                 |
        |                |                +----------------------------> |
        |                |                |                              +
        |                |                |                   ParameterServer.forward
        |                |                |                              +
        |                |                |            2                 |
        | <--------------------------------------------------------------+
        +                |                |                              |
    forward              |                |                              |
        +                |                |                              |
        |                |                |                              |
        +                |                |                              |
 out = self.model(inp)   |                |                              |
        +                |                |                              |
        |     return out |      3         |                              |
        +------------------------------>  |                              |
        |                |                |                              |
        |                |                +                              |
        |                |          F.nll_loss                           |
        |                |                +                              |
        |                |         dist_autograd.backward                |
        |         4      |                +                              |
        | <----------------------get_dist_gradients                      |
        |                |                +                              |
        |                |                |                              |
        |         5      |                +                              |
        | <-----------------------+ opt.step                             |
        |                |                |                              |
        |                |                |                              |
        v                +                v                              v

7.3 準確性

下面是在我們完成訓練後計算模型的精度,方法很像傳統的區域性模型。但是請注意,我們傳遞到這個函式的網路是TraineNet的一個例項,因此前向傳播將以透明的方式呼叫RPC。

def get_accuracy(test_loader, model):
    model.eval()
    correct_sum = 0
    # Use GPU to evaluate if possible
    device = torch.device("cuda:0" if model.num_gpus > 0
        and torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            out = model(data, -1)
            pred = out.argmax(dim=1, keepdim=True)
            pred, target = pred.to(device), target.to(device)
            correct = pred.eq(target.view_as(pred)).sum().item()
            correct_sum += correct

    print(f"Accuracy {correct_sum / len(test_loader.dataset)}")

0x08 總結

我們之前介紹過引數伺服器的經典實現 ps-lite,現在通過前文和本文的學習,大家可以看到不同於ps-lite的思路。

  • ps-lite 是類似傳統伺服器實現,有自己主動的業務迴圈,可以響應使用者的顯式請求,也有自己明確的邏輯,本地也有自己的KV儲存。
  • PyTorch 這兩篇官方文件之中,引數伺服器則是另外一種思路,其上沒有主動的迴圈,沒有KV儲存,沒有伺服器邏輯,而是可以直接儲存業務模型,業務驅動由trainer完成。

具體如何實現就要看使用者自己業務需求了。

ps-lite 文章連結如下,大家有興趣可以對比一下。

[原始碼解析] 機器學習引數伺服器ps-lite (1) ----- PostOffice

[原始碼解析] 機器學習引數伺服器ps-lite(2) ----- 通訊模組Van

[原始碼解析] 機器學習引數伺服器ps-lite 之(3) ----- 代理人Customer

[原始碼解析]機器學習引數伺服器ps-lite(4) ----- 應用節點實現

0xFF 參考

https://pytorch.apachecn.org/docs/1.7/65.html

https://pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html

相關文章