[原始碼解析] PyTorch 分散式(14) --使用 Distributed Autograd 和 Distributed Optimizer

羅西的思考發表於2021-12-13

[原始碼解析] PyTorch 分散式(14) --使用 Distributed Autograd 和 Distributed Optimizer

0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分散式的基本模組,接下來我們通過幾篇文章來看看如何把這些模組應用到實踐之中,順便把PyTorch分散式邏輯整體梳理一下。本文介紹如何把分散式自動微分和分散式優化器結合起來訓練一個模型。

本文以 https://pytorch.org/tutorials/intermediate/rpc_tutorial.html 的部分翻譯為基礎,加入了自己的理解。

PyTorch分散式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀

[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎

[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

[原始碼解析] PyTorch 分散式(1)------歷史和概述

[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)

[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)

[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念

[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用

[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store

[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組

[原始碼解析] PyTorch 分散式(8) -------- DistributedDataParallel之論文篇

[原始碼解析] PyTorch 分散式(9) ----- DistributedDataParallel 之初始化

[原始碼解析] PyTorch 分散式(10)------DistributedDataParallel 之 Reducer靜態架構

[原始碼解析] PyTorch 分散式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作

[原始碼解析] PyTorch 分散式(12) ----- DistributedDataParallel 之 前向傳播

[原始碼解析] PyTorch 分散式(13) ----- DistributedDataParallel 之 反向傳播

[原始碼解析] PyTorch 分散式 Autograd (1) ---- 設計

[原始碼解析] PyTorch 分散式 Autograd (2) ---- RPC基礎

[原始碼解析] PyTorch 分散式 Autograd (3) ---- 上下文相關

[原始碼解析] PyTorch 分散式 Autograd (4) ---- 如何切入引擎

[原始碼解析] PyTorch 分散式 Autograd (5) ---- 引擎(上)

[原始碼解析] PyTorch 分散式 Autograd (6) ---- 引擎(下)

[原始碼解析] PyTorch分散式優化器(1)----基石篇

[原始碼解析] PyTorch分散式優化器(2)----資料並行優化器

[原始碼解析] PyTorch分散式優化器(3)---- 模型並行

0x01 說明

首先要做一下說明,原文有兩部分:強化學習和RNN,本文只是翻譯了RNN部分。而且本文沒有完全按照原文順序進行翻譯,而是按照自己理解的思路重新組織了文章,用一種從上至下的角度來看這個系統。

本文使用RNN模型來展示如何使用RPC API構建分散式模型並行訓練。示例RNN模型非常小,可以很容易地放入單個GPU中,但我們仍然將它的層分在兩個不同worker來之上來演示如何分散式訓練。開發人員可以應用類似的技術在多個裝置和機器上分發更大的模型。

注:在官方這些分散式文章中,worker 有時指代分散式系統之中所有程式,而實際訓練程式往往叫做 trainer,本文的worker 就包括一個 trainer 和 一個引數伺服器。

0x02 啟動

在啟動階段,run_worker 方法會啟動一個 trainer 和 一個引數伺服器,引數伺服器在程式碼之中沒有任何行為。

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        # 啟動了trainer
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        # trainer 業務邏輯
        _run_trainer()
    else:
        # 啟動了引數伺服器
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # parameter server do nothing
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)

具體如下圖:

           torch.multiprocessing.spawn
                      +
                      |
                      |
    +-----------------+--------------------+
    |                                      |
    |                                      |
    v                                      v
+---+---------------------+   +------------+-------------+
| "ps"          rank = 0  |   | "trainer"      rank = 1  |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
+-------------------------+   +--------------------------+

0x03 Trainer

我們接下來看看訓練迴圈。初始化模型引數後,我們建立"RNNModel"和"DistributedOptimizer"。分散式優化器將獲取引數"RRefs"的列表,查詢這些引數所有的不同的 owner workers,並使用給定引數(即"lr=0.05")在每個owner worker上建立給定的本地優化器(在本例中即"SGD",您也可以使用其他本地優化器)。

在訓練迴圈中,它做如下操作:

  • 首先建立分散式autograd context,這將幫助分散式autograd引擎查詢梯度和涉及的RPC send/recv 函式。
  • 然後,它像本地模型一樣開始向前傳播,並且執行分散式向後傳播。對於分散式後向傳播,您只需要指定根的列表(list of roots),在本例中,它是loss 張量。分散式autograd引擎將自動遍歷分散式計算圖並正確寫入梯度。
  • 接下來,它在分散式優化器上執行'step'函式,該函式將與所有相關的本地優化器聯絡以更新模型引數。與本地訓練相比,一個區別是使用者不需要執行 zero_grad() ,因為每個autograd context 都有專用的空間來儲存梯度,這樣每次迭代建立一個上下文時,來自不同迭代的梯度不會累積到同一組張量之上。

具體程式碼如下:

def run_trainer():
    batch = 5
    ntoken = 10
    ninp = 2
    nhid = 3
    nindices = 3
    nlayers = 4
    hidden = (
        torch.randn(nlayers, nindices, nhid),
        torch.randn(nlayers, nindices, nhid)
    )

    model = rnn.RNNModel('ps', ntoken, ninp, nhid, nlayers)

    # setup distributed optimizer
    opt = DistributedOptimizer( # 建立分散式優化器
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch():
        for _ in range(5):
            data = torch.LongTensor(batch, nindices) % ntoken
            target = torch.LongTensor(batch, ntoken) % nindices
            yield data, target

    # train for 10 iterations
    for epoch in range(10):
        for data, target in get_next_batch():
            # create distributed autograd context
            with dist_autograd.context() as context_id: # 建立分散式上下文
                hidden[0].detach_()
                hidden[1].detach_()
                output, hidden = model(data, hidden)
                loss = criterion(output, target)
                # run distributed backward pass
                dist_autograd.backward(context_id, [loss]) # 執行分散式後向傳播
                # run distributed optimizer
                opt.step(context_id) # 分散式優化器進行更新
                # not necessary to zero grads since they are
                # accumulated into the distributed autograd context
                # which is reset every iteration.
        print("Training epoch {}".format(epoch))

邏輯擴充套件為:

           torch.multiprocessing.spawn
                      +
                      |
                      |
    +-----------------+--------------------+
    |                                      |
    |                                      |
    v                                      v
+---+---------------------+   +------------+-----------------------------------+
| "ps"          rank = 0  |   | "trainer"      rank = 1                        |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |    model = rnn.RNNModel('ps')                  |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |    dist_autograd.backward(context_id, [loss])  |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |    DistributedOptimizer.step(context_id)       |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |                                                |
+-------------------------+   +------------------------------------------------+

0x04 模型

我們接下來看看具體模型。

4.1 元件

RNN模型設計借鑑了PyTorch示例庫 example中的word語言模型,該模型包含三個主要元件:嵌入表、LSTM層和解碼器。

4.1.1 參考程式碼

我們有必要貼出原始參考程式碼來比對,可以看到,Embedding 和 Linear 都是作為 RNNModel 的成員變數存在,整個 RNNModel 耦合的非常緊密。

class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(RNNModel, self).__init__()
        self.ntoken = ntoken
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp) # 嵌入表成員變數
        if rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken) # 解碼器成員變數

			  # 省略後部分程式碼

4.1.2 分散式修改

我們看看如何依據分散式的特點來對上面模型進行修改。

下面的程式碼將嵌入表(embedding table)和解碼器包裝到子模組(sub-modules)中,以便將它們的建構函式傳遞給RPC API。在EmbeddingTable子模組中,我們有意將嵌入層放在GPU上以做演示。在v1.4中,RPC總是在目標工作程式上建立CPU張量引數或返回值。如果函式採用GPU張量,則需要顯式地將其移動到適當的裝置。

class EmbeddingTable(nn.Module):
    r"""
    Encoding layers of the RNNModel
    """
    def __init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp).cuda()
        self.encoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        return self.drop(self.encoder(input.cuda()).cpu()


class Decoder(nn.Module):
    def __init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, output):
        return self.decoder(self.drop(output))

4.2 RNN 模型

前面提到,為了實現分散式模型並行訓練,開發人員可以將模型劃分為子模組。有了上面的子模組,我們現在可以使用RPC將它們組合在一起,建立一個RNN模型。我們將呼叫RPC遠端建立子模組例項,並在必要時使用RRef查詢它們。正如您在下面的程式碼中所看到的,它看起來非常類似於單機模型並行訓練。主要區別在於用RPC函式替換 Tensor.to(device)

ps表示一個引數伺服器,它承載嵌入表和解碼器的引數。建構函式使用remote API在引數伺服器上建立EmbeddingTable物件和解碼器物件,並在本地建立LSTM子模組。

在向前傳播過程中,trainer使用EmbeddingTable RRef查詢遠端子模組,並使用RPC將輸入資料傳遞給EmbeddingTable並獲取查詢結果。然後,它通過本地LSTM層執行嵌入,最後使用另一個RPC將輸出傳送到解碼器子模組。

class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()

        # setup embedding table remotely
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        # setup LSTM locally
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        # setup decoder remotely
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))

    def forward(self, input, hidden):
        # pass input to the remote embedding table and fetch emb tensor back
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        # pass output to the rremote decoder and get the decoded output back
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden

因此,邏輯圖擴充如下:

                 torch.multiprocessing.spawn
                            +
                            |
                            |
          +-----------------+--------------------+
          |                                      |
          |                                      |
          v                                      v
+---------+------------+   +---------------------+-------------------------------------+
|"ps"        rank = 0  |   | "trainer"      rank = 1                                   |
|                      |   |                                                           |
|                      |   |   model = rnn.RNNModel('ps')                              |
|                      |   |                                                           |
| +---------------+    |   |   +---------------------------------------+               |
| |EmbeddingTable |    |   |   | RNNModel                              |               |
| |               |    |   |   |                                       |               |
| |               | <--------------+ self.emb_table_rref               |               |
| +---------------+    |   |   |                                       |               |
| +---------------+    |   |   |                                       |               |
| |Decoder        | <--------------+ self.decoder_rref                 |               |
| |               |    |   |   |                                       |               |
| |               |    |   |   |     self.rnn = LSTM                   |               |
| |               |    |   |   |                                       |               |
| +---------------+    |   |   +---------------------------------------+               |
|                      |   |                                                           |
|                      |   |                                                           |
|                      |   |   forward() {                                             |
|                      |   |       emb = _remote_method(EmbeddingTable.forward, input) |
|                      |   |       output, hidden = self.rnn(emb, hidden)              |
+----------------------+   |       decoded = _remote_method(Decoder.forward, output)   |
                           |   }                                                       |
                           |                                                           |
                           |                                                           |
                           |   dist_autograd.backward(context_id, [loss])              |
                           |                                                           |
                           |                                                           |
                           |   DistributedOptimizer.step(context_id)                   |
                           |                                                           |
                           +-----------------------------------------------------------+


4.3 分散式優化器

在介紹分散式優化器之前,讓我們新增一個helper函式,此函式用來生成模型引數的RRefs列表,分散式優化器將使用該列表。在本地訓練中,應用程式可以呼叫 Module.parameters()來獲取對所有引數張量的引用,並將其傳遞給本地優化器進行後續更新。但是,由於某些引數存在於遠端機器上,因此同一API在分散式訓練場景中不起作用。因此,分散式優化器不採用引數"張量"列表,而是採用"RRef"列表,本地和遠端模型引數的每個模型引數都有一個"RRef"。helper函式非常簡單,只需呼叫Module.parameters() 並在每個引數上建立一個本地'RRef'。

def _parameter_rrefs(module):
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs

然後,由於RNNModel包含三個子模組,我們需要呼叫 _parameter_rrefs 三次,並將其封裝到另一個helper函式中。

class RNNModel(nn.Module):
    ...
    def parameter_rrefs(self):
        remote_params = []
        # get RRefs of embedding table
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        # create RRefs for local parameters
        remote_params.extend(_parameter_rrefs(self.rnn))
        # get RRefs of decoder
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params

在 trainer 之中,使用如下來生成分散式優化器,這樣就把遠端的一些引數作為優化物件。

# setup distributed optimizer
opt = DistributedOptimizer(
    optim.SGD,
    model.parameter_rrefs(),
    lr=0.05,
)

我們最後擴充如下:

  • (1) RNNModel 的 emb_table_rref 成員變數指向引數伺服器上的EmbeddingTable。
  • (2) RNNModel 的 decoder_rref 成員變數指向引數伺服器上的Decoder。
  • (3) RNNModel 的 rnn 成員變數指向本地的LSTM。
  • DistributedOptimizer 內部的三個待優化變數分別指向:4) 引數伺服器上的EmbeddingTable 的 引數,5) 引數伺服器上的Decoder 的引數,6) 本地LSTM的引數。

分別對應下圖上的數字。

                 torch.multiprocessing.spawn
                            +
                            |
                            |
            +---------------+--------------------+
            |                                    |
            |                                    |
            v                                    v
  +---------+------------+ +---------------------+----------------------------------------+
  |"ps"        rank = 0  | | "trainer"                                         rank = 1   |
  |                      | |                                                              |
  |                      | |   model = rnn.RNNModel('ps')                                 |
  |                      | |                                                              |
  |  +---------------+   | |   +---------------------------------------+                  |
  |  |EmbeddingTable |   | |   | RNNModel                              |                  |
+--->+               |   | | 1 |                                       |                  |
| |  |               +<------------+ self.emb_table_rref               |    +------+      |
| |  +---------------+   | |   |                            3          |    |LSTM  |  6   |
| |                      | |   |     self.rnn +---------------------------->+      +<---+ |
| |  +---------------+   | | 2 |                                       |    |      |    | |
| |  |Decoder        +<------------+ self.decoder_rref                 |    +------+    | |
| |  |               |   | |   |                                       |                | |
| |  |               |   | |   +---------------------------------------+                | |
| |  |               |   | |                                                            | |
| |  +------+--------+   | |   forward() {                                              | |
| |         ^            | |       emb = _remote_method(EmbeddingTable.forward, input)  | |
| |         |            | |       output, hidden = self.rnn(emb, hidden)               | |
| |         |            | |       decoded = _remote_method(Decoder.forward, output)    | |
| |         |            | |   }                                                        | |
| +----------------------+ |                                                            | |
|           |              |   dist_autograd.backward(context_id, [loss])               | |
|           |              |                                                            | |
| 5         | 4            |  +------------------------------------------------------+  | |
|           |              |  | DistributedOptimizer                                 |  | |
|           |              |  |                                                      |  | |
|           |              |  |     remote_optimizers = [                            |  | |
+-------------------------------------------------------+ optim_rref1,               |  | |
            |              |  |                           optim_rref2+------------------+ |
            +-------------------------------------------+ optim_rref3                |    |
                           |  |                                                      |    |
                           |  |                          ]                           |    |
                           |  |     step(context_id)                                 |    |
                           |  +------------------------------------------------------+    |
                           +--------------------------------------------------------------+

手機如下:

4.4 比對

因為前面提到:分散式模型並行訓練看起來非常類似於單機模型並行訓練。主要區別在於用RPC函式替換 Tensor.to(device)。我們用GPU替代引數伺服器,把上圖大致修改下做一下對比,可能不是非常確切,但是大家可以看出來分散式訓練的關鍵點。

  +----------------------+ +-------------------------------------------------------------+
  | GPU                  | | CPU                                                rank = 0 |
  |                      | |                                                             |
  |                      | |   model = rnn.RNNModel()                                    |
  |                      | |                                                             |
  |  +---------------+   | |   +---------------------------------------+                 |
  |  |EmbeddingTable |   | |   | RNNModel                              |                 |
+--->+               |   | | 1 |                                       |                 |
| |  |               +<------------+ self.emb_table_rref               |   +------+      |
| |  +---------------+   | |   |                            3          |   |LSTM  |  6   |
| |                      | |   |     self.rnn +--------------------------->+      +<---+ |
| |  +---------------+   | | 2 |                                       |   |      |    | |
| |  |Decoder        +<------------+ self.decoder_rref                 |   +------+    | |
| |  |               |   | |   |                                       |               | |
| |  |               |   | |   +---------------------------------------+               | |
| |  |               |   | |                                                           | |
| |  +------+--------+   | |   forward() {                                             | |
| |         ^            | |       emb = EmbeddingTable.forward(input)                 | |
| |         |            | |       output, hidden = self.rnn(emb, hidden)              | |
| |         |            | |       decoded = Decoder.forward(output)                   | |
| |         |            | |   }                                                       | |
| +----------------------+ |                                                           | |
|           |              |   loss.backward()                                         | |
|           |              |                                                           | |
| 5         | 4            |  +----------------------------------------+               | |
|           |              |  | Optimizer                              |               | |
|           |              |  |                                        |               | |
|           |              |  |          param_groups = [              |               | |
+-------------------------------------------------------+ optim_rref1, |               | |
            |              |  |                                        |               | |
            |              |  |                           optim_rref2+-----------------+ |
            |              |  |                                        |                 |
            +-------------------------------------------+ optim_rref3  |                 |
                           |  |                          ]             |                 |
                           |  |          step()                        |                 |
                           |  |                                        |                 |
                           |  +----------------------------------------+                 |
                           +-------------------------------------------------------------+

手機如下:

0xFF 參考

GETTING STARTED WITH DISTRIBUTED RPC FRAMEWORK

相關文章