[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

羅西的思考 發表於 2021-08-30
深度學習

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

0x00 摘要

GPipe是一個基於 Lingvo (Lingvo 是 Google 基於 TensorFlow 二次開發的重點針對序列模型的框架)開發的,支援超大規模模型的神經網路訓練並行庫,本文介紹其重計算功能,同時可以和其他實現一起印證。

本系列其他文章如下:

[原始碼解析] 深度學習流水線並行Gpipe(1)---流水線基本實現

[原始碼解析] 深度學習流水線並行GPipe (2) ----- 梯度累積

0x01 概述

1.1 前文回顧

前文提到,目前分散式模型訓練有幾個必要並行技術:

  • 流水並行,尤其是如何自動設定流水;
  • 梯度累加(Gradient Accumulation);
  • 後向重計算;
  • 1F1B 策略(我們將採用PipeDream分析);

在前文中,我們介紹了Gpipe如何實施流水線並行技術,以及梯度累積。

流水並行存在一個問題:視訊記憶體佔用太大。如果每個 micro-batch 前向計算的中間結果(activation)被後向計算所消費,則需要在視訊記憶體中快取 n 份(梯度累加的次數)完整的前向 activation。這時就不得不用另一項重要的技術:重計算(Checkpointing)。

本文以論文"Training deep nets with sublinear memory cost"為基礎,對於 pytorch 和 Gpipe 原始碼 進行分析,期望可以對 “Gradient checkpointing”技術有一個具體的理解。

1.2 Gradient checkpointing

2016年,陳天奇團隊提出了亞線性記憶體優化相關的 "gradient/activation checkpointing(後向重計算)"等技術,旨在降低深度學習訓練過程中的中間啟用(activation)帶來的視訊記憶體佔用。Checkpointing技術屬於亞線性記憶體優化的一種,除此之外還有CPU offload等技術(CPU offload在微軟Deepspeed框架中被廣泛使用)。

梯度檢查點是一種減少深度神經網路訓練時記憶體消耗的系統性方法,具體是在反向傳播中,針對每個設定為檢查點的段,通過重新執行前向傳播段來實現的:

  • 梯度檢查點方法集中在減少降低儲存中間結果(特徵圖)和梯度的記憶體開銷,因為在許多常見的深度網路之中,與模型引數相比,中間結果要大得多。
  • 梯度檢查點是一種以時間(算力)換空間(視訊記憶體)的方法,通過減少儲存的啟用值壓縮模型佔用空間,但是在計算梯度時必須重新計算沒有儲存的啟用值,即需要花兩倍的前向傳播計算時間。
  • 具體來說,就是設定一些梯度檢查點,檢查點之外的中間結果先釋放掉,將來在反向傳播的過程中如果發現前向結果不在視訊記憶體中,就找到最近的梯度檢查點再進行前向計算,恢復出被釋放的張量。

0x02 背景知識

2.1 求導如何工作

此處借鑑了 訓練時視訊記憶體優化技術——OP合併與gradient checkpoint 的思路。

DNN模型由一系列不同型別的層組成(例如卷積層,全連線層,池化層)。

反向傳播的關鍵是“自動鏈式求導”,但實際上BP在這個基礎上也加入了一點動態規劃機制。一般的BP包含以下兩個步驟:

  • 前向傳導。以影像分類為例,當前模型首先對一小部分訓練樣本(也稱為minibatch)進行預測。這個過程被稱為向前傳導。
    • 為了進行預測,來自小批量的輸入資料被輸入到模型的第一層。
    • 然後,每一層在其輸入上計算一個函式,為下一層生成輸出。前向傳導記錄以下兩個值:中間結點的輸出值,輸出值關於輸入值的梯度。
    • 最後一層的輸出是類預測。基於模型的預測標籤和每個影像的實際標籤,輸出層計算損失(或錯誤)。
  • 反向傳播梯度計算。反向傳播就是一個計算網路最終輸出值關於本層輸出的梯度的過程。即,從輸出開始,反向傳播梯度值,計算輸出值對於每一箇中間變數的梯度,並儲存。每層計算 前一層的誤差,和 所有相關層的權重更新(損失梯度),這將使模型的預測朝著所需的輸出移動。

在梯度回傳的過程中需要用到節點的輸出值,但是在反向傳播進行梯度計算的時候,BP不會進行重複計算,其原因就是在前向傳導的時候,進行了中間變數的儲存,也就是每個中間節點的輸出值。BP不斷地反向傳播梯度,並儲存中間梯度,直到計算圖的所有中間值以及初始值的梯度被求解完畢。

我們看看反向傳播是如何工作的。

所謂自動求導框架實際上是“半自動”的:它並非直接求出一個複雜函式導數的解析形式,而是通過構建計算圖和預先寫好的基礎函式的求導規則,結合鏈式求導法則實現的自動求導

我們假設一個函式為例進行說明,其表示式如下:

f(x) = x * (x + 1)

通過簡單的數學推導得到其梯度的解析式為f'(x) = x + 1 + x;先把這個結果放一邊,看看自動求導框架是如何一步步求出這個結果的,畫出計算圖如下:

                       +---------+
                       |         |
               +------>+  x + 1  +----+
               |       |         |    | 3
             2 |       +---------+    |
               |                      |
               |                      v
         +-----+--+                  ++------+
         |        |                  |       |
+------> |    x   +----------------> |   +   +---------->
         |        |         1        |       |
         +--------+                  +-------+

在計算圖上,反向傳播先經過乘法運算,根據上面的求導規則:

  • 路徑1上的梯度為 x + 1
  • 路徑3上的梯度為 x
  • 路徑3再反向傳播要經過路徑2,除了其梯度為 x + 1 之外,還要乘上 路徑2的梯度 1
  • 路徑2和路徑1匯聚到一起,所以最終的梯度為 x + 1(路徑1)+ 1 * x(路徑2)= x + 1 + x,剛好等於我們用數學公式計算出來的結果;

自動求導框架正是依靠這些基礎的規則和鏈式求導法則在高效準確的運作。

在絕大多數神經網路的訓練過程中,在計算反向傳播時,前向傳播過程中得到的一些中間變數非常有用(為了方便求導)。在實際操作中,最好程式碼實現對於這些中間變數的快取,這樣在反向傳播的時候也能用上它們。於是視訊記憶體佔用的大頭就是中間結果,也就是所謂的“特徵圖”。對於本文,x 就是前一層輸出的中間結果(特徵圖)。

在適用乘法的求導規則時,要求我們要事先保留下中間結果 x 和 x+1。注意框架定義的乘法及其求導規則是通用規則,乘法的左右兩邊完全可能是不相關的兩個值,所以必須同時保留下來。就是說,x + 1 在其他函式中,可能是 ( x + y ) + z ....,也可能包含其他輸入變數,所以無法通過 + 1 這樣簡單的算式由一個輸入 x 計算出來

在不考慮框架自身優化的情況下,視訊記憶體佔用就包括了一個 x 和 一個 x + 1,注意x可不是一個單獨的數值,而是類似32x32x128這樣大小的特徵圖。

2.2 梯度Checkpoint

如前一節所述,神經網路的原始方式中:

  • 在forward函式中,每層的啟用函式值計算之後需要儲存下來,因為它們需要在後向傳播的計算中被消費。
  • 在backward時,根據損失函式值和該層對應的啟用函式值計算梯度。
  • 因此,我們需要在視訊記憶體中快取 n 份(梯度累加的次數)完整的前向 activation。也就是說,這種情況下視訊記憶體佔用與 層數成正比。

因此,目前流水並行存在一個問題:視訊記憶體佔用太大。

是否可以不儲存啟用值?比如在backward時,需要啟用函式值的時候重新進行forward就可以了。

假如我們一個都不儲存,都通過forward重新計算?那麼在大模型中這樣消耗的時間太大。所以我們可以選用折中的方式,比如只存部分層的啟用函式值。當backward需要啟用函式值的時候,取最近的啟用值就行。所以就引入了一項重要的技術:重計算(Checkpointing)。

2.3 論文內容

2.3.1 主要論文

Gpipe 的 Checkpointing 主要思路來自以下兩篇論文:

  • Andreas Griewank and Andrea Walther. Algorithm 799: revolve: an implementation of check- pointing for the reverse or adjoint mode of computational differentiation. ACM Transactions on Mathematical Software (TOMS), 26(1):19–45, 2000.
  • Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174, 2016.

主要思路是用算力換記憶體(計算換視訊記憶體,反向求導時需要的中間結果從 checkpoint 重新計算),以及用頻寬換視訊記憶體。

2.3.2 論文 Training Deep Nets with Sublinear Memory Cost

2.3.2.1 主要思路

我們主要來看這篇論文。

Checkpointing 是陳天奇在2016年發表的論文 Training Deep Nets with Sublinear Memory Cost 中提到的,也稱之為亞線性記憶體優化。亞線性記憶體優化有兩種思路,Checkpointing 和 CPU offload:

  • Checkpointing 的核心思想 是在前向網路中標記少量的 Tensor (被 Checkpointing 的 Tensor ),前向計算就只會保留這些被標記的 Tensor, 其餘的前向的 activation,會通過在反向傳播中根據 Checkpointing 的 Tensor 臨時重新計算一遍前向得到。這樣就使得大量的 activation 不需要一直儲存到後向計算,有效減少了大量 Tensor 的生命週期,使得記憶體複用效率大幅提升。
  • CPU offload 的思路類比於計算機作業系統中的“虛擬記憶體”技術(將不常用的記憶體臨時換入換出到磁碟上,從而增加記憶體總量),在深度學習中,GPU 視訊記憶體(Device Memory)的特點是昂貴、高速且容量小,而 CPU 主存(Host Memory)的特點是便宜、相對低速和大容量;那麼將前向計算中的一些暫時用不到的 activation 臨時換出到 CPU 主存上,等到反向計算需要時再換入到 GPU 視訊記憶體裡,通過這種方式也可以節省視訊記憶體。

兩種亞線性記憶體優化通過不同的方式達到了視訊記憶體優化:Checkpointing 是通過額外的計算開銷換視訊記憶體, CPU offload 通過額外的傳輸開銷換視訊記憶體。

2.3.2.2 Checkpointing 優化

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

上圖展示了做 Checkpointing 之前和之後的計算圖對比。

左面灰色的是網路配置。

中間 Normal Gradient Graph 是普通網路的前向後向傳播流程。

右面 Memory Optimized Gradient Graph 就是應用了 gradient-checkpoint 的結果。為了進一步減少記憶體,會刪除一些中間結果,並在需要時從額外的前向計算中恢復它們。

  • 首先,神經網路分為幾個部分(右面圖中就分成了三段),該演算法只記住每一段的輸出,並在每一段中刪除所有中間結果。
  • 其次,在反向傳播階段,我們可以通過從最近的記錄結果向前執行來重新計算丟棄的中間結果。
  • 因此,我們只需支付儲存每個段的輸出的記憶體成本加上在每個段上進行反向傳播的最大記憶體成本。

所以gradient-checkpoint 就是並非是不需要中間結果,而是有辦法在求導過程中實時的計算出之前被捨棄掉的中間結果

重計算並不是單獨為流水並行設計的,並且之前大多使用在單卡或者資料並行場景下。但這個優化在流水並行下就非常關鍵,因為它使得前向不需要快取所有的 activation,而只需要快取非常少個數的(比如一層 Transformer Layer 只會快取一個 )、被 checkpoint 的特定 Tensor ,從而大大節省了流水並行下的視訊記憶體開銷。

0x03 OpenAI

在OpenAI 提出的gradient-checkpoint 就是論文Training Deep Nets with Sublinear Memory Cost思路的實現,因為其文件比較齊全(https://github.com/openai/gradient-checkpointing),我們可以學習借鑑下。

總體思路是:在神經網路中間設定若干個檢查點(checkpoint),對於中間結果feature map,每隔 sqrt(n)保留一個檢查點。檢查點以外的中間結果全部捨棄,反向傳播求導數的時間,需要某個中間結果時,從最近的檢查點開始計算,這樣既節省了視訊記憶體,又避免了從頭計算的繁瑣過程。

3.1 計算圖

對一個簡單的 n 層前饋神經網路,獲取梯度的計算圖如下所示:

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

具體如下:

  • 神經網路的層級啟用值對應於 f 標記的節點,且在正向傳播過程中,所有這些節點需要按順序計算。
  • 損失函式對啟用值和這些層級引數的梯度使用 b 節點標記,且在反向傳播過程中,所有這些節點需要按逆序計算。
  • 計算 f 節點的啟用值是進一步計算 b 節點梯度的前提要求,因此 f 節點在前向傳播後會保留在記憶體中。
  • 只有當反向傳播執行地足夠遠以令計算對應的梯度不再需要使用後面層級的啟用值或 f 的子節點時,這些啟用值才能從記憶體中清除。這意味著簡單的反向傳播要求記憶體與神經網路的層級數成線性增長關係。

3.2 重計算

簡單的反向傳播已經是計算最優的了,因為每個節點只需要計算一次。然而,如果我們願意重新計算節點,那麼我們可以節省大量的記憶體。當我們需要節點的啟用值時,我們可以簡單地重計算前向傳播的節點啟用值。我們可以按順序執行計算,直到計算出需要使用啟用值進行反向傳播的節點。

使用這一策略,需要令計算梯度的記憶體在神經網路層的數量 n 上是穩定的,且 n 在記憶體方面是最優的。但是要注意,節點的計算數量現在擴充套件了 n^2,相比於之前的 n。n 個節點中的每一個被再計算 n 次。因此計算圖變得很慢以計算深度網路,使得這一方法不適用於深度學習。

3.3 策略

為了在記憶體與計算之間取得平衡,我們需要一個策略允許節點被再計算,但是這種再計算不會發生很頻繁。這裡我們使用的策略是把神經網路啟用的一個子集標記為一個節點。紫色的節點表示在給定的時間內需要儲存在記憶體中。

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

這些檢查點節點在前向傳播後保留在記憶體中,而其餘節點最多隻會重新計算一次。在重新計算後,非檢查點節點將保留在記憶體中,直到不再需要它們來執行反向傳播。對於簡單的前饋神經網路,所有神經元的啟用節點都是由正向傳播定義的連線點或圖的分離點。這意味著我們在反向傳播過程中只需要重計算 b 節點和最後檢查點之間的節點,當反向傳播達到了我們儲存的檢查點節點,那麼所有從該節點開始重計算的節點在記憶體中都能夠移除。

3.4 過程

首先,我們設定了兩個checkpoint,圖上第一行左面兩個紫色,注意,右面第一個紫色是輸入。

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

其次,正向傳播已經完成,開始反向傳播,就是從下面一行紫色1號開始反向傳播。

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

第三,來到了下面一行的紫色2號,它依賴於上面的紫色3號來計算(回憶一下,後向傳播計算需要前向計算的輸出),此紫色3號是checkpoint,在記憶體中存在,所以正常執行反向傳播

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

第四,來到了下面一行的白色 4 號,它依賴於上面的紫色 5 號來計算,5 號不是一個checkpoint,不在記憶體之中,需要重它前面的checkpoint開始計算,即從紫色 7 號開始計算。計算出來一個新的checkpoint,同時可以刪除上面一行原有紫色 5 號,因為不需要了。

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

第五,計算出下面的新紫色 4 號,從而繼續後向計算。

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

因為涉及到自動生成checkpoint,OpenAI這部分程式碼比較晦澀鬼畜,所以這裡不進行分析,如果有興趣的同學可以自行學習。

0x04 Pytorch 實現

我們接下來用Pyorch來看看。

4.1 基礎知識

4.1.1 Variable & Function

在PyTorch中,autograd是所有神經網路的核心內容,為Tensor所有操作提供自動求導方法。它是一個按執行方式定義的框架,這意味著backprop是由程式碼的執行方式定義的。

autograd.Variable 是autograd中最核心的類。 它包裝了一個Tensor,並且幾乎支援所有在其上定義的操作。一旦完成了你的運算,你可以呼叫 .backward()來自動計算出所有的梯度。

另一個對autograd的實現非常重要的類是Function,Function簡單說就是對Variable的運算,如加減乘除,relu,pool等。但它不僅僅是簡單的運算。與普通Python或者numpy的運算不同,Function是針對計算圖,需要計算反向傳播的梯度。因此他不僅需要進行該運算(forward過程),還需要利用cache保留前向傳播的輸入(為計算梯度),並支援反向傳播計算梯度。

Pytorch是利用Variable與Function來構建計算圖的。回顧下Variable,Variable就像是計算圖中的節點,儲存計算結果(包括前向傳播的啟用值,反向傳播的梯度),而Function就像計算圖中的邊,實現Variable的計算,並輸出新的Variable。

總結,Function與Variable構成了pytorch的自動求導機制,它定義的是各個Variable之間的計算關係。

備註:最新 PyTorch 程式碼之中,已經用把 Function 修改為 Node 類,應該是為了更好的表示計算圖中節點的概念。

4.1.2 Function進一步理解

我們可以使用autograd.Function類來自定義一個模型、一個層、一個啟用函式、一個損失函式,就更加好理解了,實際上本質上來說都是一個函式,只分這個函式是簡單還是複雜。

4.2 普通模式

這部分程式碼位於torch/utils/checkpoint.py。pytorch是需要使用者指定checkpoint,因此實現相對簡單很多。

4.2.1 封裝

在 torch/utils/checkpoint.py 之中,對checkpoint有了一個封裝,該註釋非常值得我們閱讀,我們深入學習一下。

  • Checkpointing 本質就是用計算換記憶體。

  • Checkpointing 儲存用於後向計算所需要的整個計算圖的全部中間啟用值,而是在反向傳播中重新計算它們。

  • 在前向傳播過程中,Checkpointing 引數 function 是執行在 torch.no_grad 模式,這樣就不會計算中間啟用值了。相反,向前傳遞儲存輸入元組和function引數。

  • 在向後傳遞中,儲存的輸入和function被取出,function將再次被計算,這次會跟蹤中間啟用值,然後

    使用這些啟用值計算梯度。

def checkpoint(function, *args, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in backward pass. It can be applied on any part
    of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` is retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    return CheckpointFunction.apply(function, preserve, *args)

4.2.2 處理裝置

因為pytorch無法知道向前傳播函式是否會把一些引數移動到不同的裝置上,這就需要一些邏輯來儲存為這些裝置儲存RNG狀態。雖然可以為所有可見裝置儲存/恢復所有的RNG狀態,但是這樣在大多數情況下是一種浪費,因此作為折中,pytorch只是針對所有的張量引數的裝置進行儲存RNG狀態。

def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)

4.2.3 核心邏輯

CheckpointFunction 繼承了torch.autograd.Function。

我們可以對Function進行擴充,使其滿足我們自己的需要,而擴充就需要自定義Function的forward運算,以及對應的backward運算,同時在forward中需要通過儲存輸入值用於backward。

  • forward函輸入tensor,計算輸出tensor。

    • 在前向傳播過程中,Checkpointing 引數 function 是執行在 torch.no_grad 模式,這樣就不會計算中間啟用值了。
    • 向前傳遞儲存輸入元組和function引數。
    • 對於CheckpointFunction來說,還是需要在forward之中儲存一些另外的資訊(就是上面說的 rng 資訊),以供後向傳播時候計算使用。
    • 進行前向傳播返回啟用值。
  • backward函式接收相對於某個標量值的輸出張量的梯度,並且計算關於該相同標量值的輸入張量的梯度。

    • 在向後傳遞中,儲存的輸入和function被取出。
    • function將再次被計算,這次會跟蹤中間啟用值,然後使用這些啟用值計算梯度。
"""
我們可以通過建立torch.autograd的子類來實現我們自定義的autograd函式,
並完成張量的正向和反向傳播。
"""
class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """
        在forward函式中,接收包含輸入的Tensor並返回包含輸出的Tensor。
        ctx是環境變數,用於提供反向傳播是需要的資訊。我們可以使用上下文物件來快取物件,以便在反向傳播中使用。可通過ctx.save_for_backward方法快取資料,save_for_backward只能傳入Variable或是Tensor的變數。
        """
        check_backward_validity(args)
        # 儲存前向傳播函式
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            # 儲存前向傳播時候的裝置狀態
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = [] 
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args): # 儲存輸入數值
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        # `saved_for_backward`是會保留此input的全部資訊, 並避免in-place操作導致的input在backward被修改的情況. 它是將函式的輸入引數儲存起來以便後面在求導時候再使用,起前向反向傳播中協調作用。      
        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args) # 進行前向傳播
        return outputs

"""
在反向傳播中,我們接收到上下文物件和一個張量,
其包含了相對於正向傳播過程中產生的輸出的損失的梯度。
我們可以從上下文物件中檢索快取的資料,
並且必須計算並返回與正向傳播的輸入相關的損失的梯度。
"""      
    # 自動求導是根據每個op的backward建立的graph來進行的
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors # 獲取前面儲存的引數,也可以使用self.saved_variables

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices): # 利用儲存的張量重新設定input
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        # 儲存目前rng狀態,模擬前向傳播狀態,最後恢復目前狀態
        rng_devices = [] 
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state) # 恢復前向傳播時候的裝置狀態
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
                # 利用前向傳播函式再次計算
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = [] # 啟用值
        args_with_grad = [] # 梯度
        # 從前向傳播計算的結果中篩選需要傳播的張量
        for i in range(len(outputs)): 
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
            
        # 開始後向傳播    
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads

4.3 Pipeline模式

我們接下來看看 流水線模式如何進行 Checkpoint。

Pytorch 流水型並行模式是受到了GPipe的啟發,在其註釋之中有提到。

通過CheckpointFunction,pytorch可以做到把重計算和遞迴反向傳播合併到一個自動求導函式中,因此當梯度到達時,重計算就會開始。但是在流水線模式中,為了縮減GPU idle時間,重計算需要發生在梯度到達之前進行(因為重計算其實和梯度無關,重計算可以在梯度到來之前進行以獲得啟用值,等後向傳播的梯度來了之後,再集合啟用值進行自己的梯度計算)。

為了解決這個問題,pytorch引入了兩個自動求導函式:class:Recompute and class:Checkpoint,分別代表重計算和遞迴反向傳播就是把普通模式下的 CheckpointFunction 分離成兩個階段,這樣用這兩個函式就可以控制自動求導引擎和CUDA。具體說就是在class:Recompute and class:Checkpoint之間插入CUDA同步,這樣把class:Checkpoint 推遲到梯度完全拷貝結束。

分開段,就可以多個流水線stage並行了。

4.3.1 樣例

我們可以先看看 test/distributed/pipeline/sync/test_checkpoint.py 這個程式碼。

其通過log的巧妙列印,可以讓我們看出來執行時候,checkpoint在前向後向傳播之中的使用。

timeline 最後結果是 ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"],

其中兩兩一組,分別對應了 forward pass ,Checkpoint(Log[b]),Checkpoint(Log[a])。

@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []

    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    _ = a + 1 + 2 + 3 + 4 + 5

    # 這裡意味著最後 backward 實際會執行"a:forward", "a:backward"
    a = checkpoint(partial(Log.apply, "a"), a)

    a, phony = fork(a)
    b = join(b, phony)

    # 這裡意味著最後 backward 實際會執行"b:forward", "b:backward"
    b = checkpoint(partial(Log.apply, "b"), b)

    c = torch.cat((a, b))

    out = c.sum()

    #                        +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                        +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
    #    |----------------------|  |-----------------------|  |-----------------------|
    #          forward pass            Checkpoint(Log[b])         Checkpoint(Log[a])

4.3.2 共享變數

class:Recompute and class:Checkpoint之間具體是通過Context這個上下文來進行共享變數的儲存。

# Types for shared memory between Checkpoint and Recompute.

Recomputed = Tuple[TensorOrTensors, Tensors]  # (output, input_leaf)
RNGStates = Tuple[Tensor, Optional[Tensor]]  # (cpu_rng_state, gpu_rng_state)

class Context:
    """The common interface between the :class:`Checkpoint` and
    :class:`Recompute` context.
    """

    recomputed: Deque[Recomputed]
    rng_states: Deque[RNGStates]
    function: Function
    input_atomic: bool

    saved_tensors: Tuple[Tensor, ...]

    def save_for_backward(self, *tensors: Tensor) -> None:  # pragma: no cover
        pass

4.3.3 rng state

根據執行時的不同,RNG狀態可能會產生不同的效能影響,所以需要在每個檢查點期間儲存當前裝置的RNG狀態,在重計算之前恢復當前裝置的RNG狀態。

save_rng_states 和 restore_rng_states 兩個方法分別用來存取 RNG 狀態。

def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
    """:meth:`Checkpoint.forward` captures the current PyTorch's random number
    generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_rng_state = torch.get_rng_state()

    gpu_rng_state: Optional[Tensor]
    if device.type == "cuda":
        gpu_rng_state = torch.cuda.get_rng_state(device)
    else:
        gpu_rng_state = None

    rng_states.append((cpu_rng_state, gpu_rng_state))


@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
    """:meth:`Recompute.backward` restores the random number generator states
    captured by :func:`save_rng_states` within its context.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_rng_state, gpu_rng_state = rng_states.pop()

    gpu_devices: List[torch.device] = []
    if device.type == "cuda":
        gpu_devices.append(device)

    with torch.random.fork_rng(gpu_devices):
        torch.set_rng_state(cpu_rng_state)
        if gpu_rng_state is not None:
            torch.cuda.set_rng_state(gpu_rng_state, device)
        yield

4.3.4 Checkpoint

Checkpoint 和下面的 Recompute 就是把普通模式下的 checkpoint 程式碼分離成兩個階段(forward函式被分成兩段,backward 函式也被分成兩段),從而可以更好的利用流水線。

class Checkpoint(torch.autograd.Function):
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> TensorOrTensors:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        # 存RNG狀態
        save_rng_states(input[0].device, ctx.rng_states)

        ctx.function = function
        ctx.input_atomic = input_atomic
        # 為BP做準備,其實目前沒有實現
        ctx.save_for_backward(*input)

        # 進行前向計算
        with torch.no_grad(), enable_checkpointing():
            output = function(input[0] if input_atomic else input)

        return output

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
        # 從儲存的重計算變數中彈出所需變數
        output, input_leaf = ctx.recomputed.pop() 

        if isinstance(output, tuple):
            tensors = output
        else:
            tensors = (output,)
            
        if any(y.requires_grad for y in tensors):
            tensors = tuple([x for x in tensors if x.requires_grad])
            # 進行自動微分
            torch.autograd.backward(tensors, grad_output)

        grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
        grad_input.extend(x.grad for x in input_leaf)
        return tuple(grad_input)

4.3.5 Recompute

Recompute 就是依據儲存的資訊,重新計算中間變數。

class Recompute(torch.autograd.Function):
  
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> Tensor:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        ctx.function = function
        ctx.input_atomic = input_atomic
        ctx.save_for_backward(*input)

        return phony

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:  
        input = ctx.saved_tensors
        input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)

        # 取出儲存的RNG狀態,進行前向計算,得到中間變數
        with restore_rng_states(input[0].device, ctx.rng_states):
            with torch.enable_grad(), enable_recomputing():
                output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)

        # 儲存變數,為Checkpoint使用
        ctx.recomputed.append((output, input_leaf))

        grad_input: List[None] = [None, None, None, None, None]
        grad_input.extend(None for _ in ctx.saved_tensors)
        return tuple(grad_input)

4.3.6 Pipeline

4.3.6.1 Task

我們首先要看看 Task 類。程式碼位於:torch/distributed/pipeline/sync/worker.py。

由註釋可知,Task 就是用來在一個分割槽上計算一個micro-batch。

compute可以在worker執行緒內被並行執行。

finalize 應該在compute結束之後被執行。

class Task:
    """A task represents how to compute a micro-batch on a partition.

    It consists of two parts: :meth:`compute` and :meth:`finalize`.
    :meth:`compute` should be executed in worker threads concurrently.
    :meth:`finalize` should be executed after when worker threads complete to
    execute :meth:`compute`.

    :meth:`compute` might be boosted by worker threads. Because it produces
    several CUDA API calls by user code. In PyTorch, parallel CUDA API calls
    are not serialized through GIL. So more than one CUDA API call can be
    produced at the same time.

    """

    def __init__(
        self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
    ) -> None:
        self.stream = stream
        self._compute = compute
        self._finalize = finalize
        self._grad_enabled = torch.is_grad_enabled()

    def compute(self) -> Batch:
        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
            return self._compute()

    def finalize(self, batch: Batch) -> None:
        if self._finalize is None:
            return
        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
            self._finalize(batch)
4.3.6.2 compute

這裡說的是 Pipeline 類的 compute 函式。

Pipeline 的邏輯如其註釋所示(PyTorch的註釋真的很翔實)。重點是 Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) 這裡設定瞭如何進行checkpoint。

可以看到,這裡會將 recompute 方法設定為 Task 的 finalize 方法,然後會計劃重計算。

class Pipeline:
    """The pipeline parallelism for Pipe."""
    
    def compute(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Runs tasks with synchronization to copy streams."""
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        # Disable checkpointing if in eval mode.
        if not self.partitions[0].training:
            checkpoint_stop = 0

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                _wait(batch, copy_streams[j][i], streams[j])

            # Determine whether checkpointing or not.
            checkpoint = i < checkpoint_stop
            if checkpoint:

                def function(
                    input: TensorOrTensors,
                    partition: nn.Sequential = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return partition(input)

                # 這裡進行處理
                chk = Checkpointing(function, batch)
                # 分別設定了chk.checkpoint 和 chk.recompute
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk

            else:

                def compute(
                    batch: Batch = batch,
                    partition: nn.Sequential = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> Batch:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return batch.call(partition)

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            self.in_queues[j].put(task) # 將task插入到 pipeline的queue,這樣可以並行。

        for i, j in schedule: 
            ok, payload = self.out_queues[j].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            # 取出 task    
            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n - 1:
                _wait(batch, streams[j], copy_streams[j][i])

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch) # 計劃進行重計算

            batches[i] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

關於 PyTorch 的 Pipeline,後續會有專門系列進行分析。

0x05 Gpipe實現

Gpipe 在反向傳播的時候,可以在第 k-th 個 accelerator 上重新計算前向傳播函式 F_k。

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

5.1 API函式 _Rematerialize

首先,我們看看API方法。

在 builder.py 之中有 _Rematerialize 函式,可以用來包裝一個需要重新計算的層。

  def _Rematerialize(self, name, body):
    """Forces rematerialization on FProp of the body layer."""
    return builder_layers.RematerializationLayer.Params().Set(
        name=name, body=body)

5.2 包裝層 RematerializationLayer

RematerializationLayer 是包裝層,其中有:

FProp 就是把被封裝層 包裝為一個函式 Fn,然後呼叫 py_utils.RematerializeFn 把 Fn 與 輸入變數一起傳入。

class RematerializationLayer(base_layer.BaseLayer):
  """A wrapper layer with rematerialization."""

  @classmethod
  def Params(cls):
    p = super().Params()
    p.Define('body', None,
             'The main layer whose FProp will be wrapped by RematerializeFn.')
    return p

  def __init__(self, params):
    super().__init__(params)
    self.CreateChild('body', self.params.body)

  def FProp(self, theta, *xs):
    input_list = theta.body.Flatten() # 得到theta
    theta_len = len(input_list)
    input_list += list(xs) # 得到輸入引數
    input_len = len(input_list)

    def Fn(*args): # 包裝函式,會呼叫被封裝層的 FProp
      body_theta = theta.body.Pack(args[:theta_len])
      return self.body.FProp(body_theta, *args[theta_len:input_len])

    return py_utils.RematerializeFn(Fn, *input_list) # 呼叫,執行FProp,並且做Gradient checking

  @classmethod
  def FPropMeta(cls, p, *args): # 就是傳播被封裝層的資訊
    py_utils.CheckShapes(args)
    return p.body.cls.FPropMeta(p.body, *args)

3.2.3 tensorflow gradients 函式

RematerializeFn 呼叫了 tensorflow gradients 函式 來計算梯度,所以我們需要解釋下。

在tensorflow中,gradients 函式可以自動計算函式的梯度。我們只需要設計我們的函式,然後去呼叫 tf.gradients 函式就可以了。

tf.gradients()的引數如下,其中

  • tf.gradients()實現ysxs求導
  • grad_ys也是一個list,其長度等於len(ys)。這個引數的意義在於對xs中的每個元素的求導權重。
tf.gradients(ys, xs, 
			 grad_ys=None, 
			 name='gradients',
			 colocate_gradients_with_ops=False,
			 gate_gradients=False,
			 aggregation_method=None,
			 stop_gradients=None)

5.4 功能函式 RematerializeFn

RematerializeFn 是最終功能函式,就是呼叫 fn,並且在反向傳播過程中進行rematerializes fn。

def RematerializeFn(fn, *xs):
  """Calls fn and rematerializes fn in the backward pass.

  `fn(*xs) -> ys`, where xs and ys can be a single tensor or a tuple of tensors.

  Args:
    fn: A python function to be rematerialized in the backprop pass.
    *xs: A single tensor or a list/tuple of tensors. `xs` are input args to the
      fn function.

  Returns:
    `fn(*xs)`
  """
  initial_step_seed = GetStepSeed()
  final_step_seed = MaybeGenerateSeedFromScope()

  def Backward(fwd_xs, fwd_ys, d_fwd_ys):
    """The backward function that rematerializes forward outputs."""
    del fwd_ys # 去掉傳入的引數,因為在內部需要用備份的Checkpoint來處理
    always_true = tf.random.uniform([]) < 2.0
    # Alternatively, can do this:
    # tf.where(tf.math.is_nan(x),
    #          tf.constant(float('nan'), dtype=x.dtype) * tf.ones_like(x),
    #          x)
    bak_xs = [tf.where(always_true, x, tf.zeros_like(x)) for x in fwd_xs.xs] # 依據Checkpoint來生成 bak_xs
    for dst, src in zip(bak_xs, xs):
      dst.set_shape(src.shape)
    ResetStepSeed(initial_step_seed)
    ys = fn(*bak_xs) # 依據Checkpoint來重新生成ys
    MaybeResetStepSeed(final_step_seed)
    dxs = tf.gradients(ys, bak_xs, grad_ys=d_fwd_ys) # ys 對 bak_xs 求導
    dxs_final = [] # 聚合
    for dx, x in zip(dxs, bak_xs):
      if dx is None:
        dxs_final.append(tf.zeros_like(x))
      else:
        dxs_final.append(dx)
    assert len(dxs_final) == len(bak_xs)
    return NestedMap(
        initial_step_seed=tf.zeros_like(initial_step_seed), xs=dxs_final)

  ys_shapes = []

  # TODO(huangyp, yonghui): Check Forward doesn't use any stateful random ops.
  def Forward(fwd_xs):
    """Forward function plus sanity checks."""
    for dst, src in zip(fwd_xs.xs, xs):
      dst.set_shape(src.shape)
    ResetStepSeed(fwd_xs.initial_step_seed)
    ys = fn(*fwd_xs.xs) # 正常計算
    # Some sanity check.
    assert not GetExtraInputs()
    assert not GetExtraArgs()
    assert not GetExtraVars()
    if isinstance(ys, tuple):
      for y in ys:
        assert isinstance(y, tf.Tensor)
        ys_shapes.append(y.shape)
    else:
      assert isinstance(ys, tf.Tensor)
      ys_shapes.append(ys.shape)
    return ys

  ys = CallDefun(
      Forward,
      NestedMap(initial_step_seed=initial_step_seed, xs=xs),
      bak=Backward)
  if isinstance(ys, tuple):
    for y, s in zip(ys, ys_shapes):
      y.set_shape(s)
  else:
    ys.set_shape(ys_shapes[0])
  # TODO(b/129159299): The ResetStepSeed below is needed to work around this
  # bug, which is a problem with global tensors being shared by different
  # inference graphs. It should be replaced with the new step seed value
  # returned from the Forward function when the bug is fixed.
  MaybeResetStepSeed(final_step_seed)
  return ys

CallDefun定義如下,就是把fwd, back封裝起來進行呼叫。其中,Function 的作用是依據一個callable 構建一個TensorFlow graph function

def CallDefun(fwd, args=None, bak=None, bak_as_function=False, device=None):
  """Wraps fwd in a defun with custom gradient bak and calls it with args.

  Args:
    fwd: A callable xs: Nested Structure -> ys: Nested Structure.
    args: A Nested Structure of tf.Tensor or None.
    bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
      Structure. The custom backprop function for fwd. bak needs to return
      dcapture if fwd uses any implicitly captured tensors, whose gradients are
      dcapture.
    bak_as_function: Whether to create a TF graph function for bak.
    device: the device on which to run fwd and bak.

  Returns:
    A Nested Structure equivalent to what fwd(args) computes.
  """
  if args is not None:
    args = Transform(tf.convert_to_tensor, args)
  sigs = Function(
      fwd_sig=TensorSpecs(args),
      bak=bak,
      bak_as_function=bak_as_function,
      device=device)(
          fwd=fwd)
  if args is None:
    return sigs()
  else:
    return sigs(args)

至此,GPipe 分析完畢,下一篇開始分析 PipeDream,敬請期待。

0xFF 參考

lingvo框架走讀筆記

Tensorflow實現先累加多個minibatch計算的梯度,再反向傳播

用tensorflow2實現梯度累積

十倍模型計算時間僅增20%:OpenAI開源梯度替換外掛

PipeDream: Fast and Efficient Pipeline Parallel DNN Training

論文解讀系列第五篇:微軟史丹佛等PipeDream快速訓練大規模神經網路

https://cs231n.github.io/neural-networks-3/#gradcheck

https://www.cnblogs.com/geekfx/p/14182048.html

訓練時視訊記憶體優化技術——OP合併與gradient checkpoint

Pytorch筆記04-自定義torch.autograd.Function

PyTorch教程之Autograd

pytorch的自定義擴充之(三)——torch.autograd.Function的簡單定義與案例

pytorch的自定義擴充之(二)——torch.autograd.Function完成自定義層

PyTorch 原始碼解讀之 torch.autograd:梯度計算詳解

再談反向傳播(Back Propagation)

CS231n課程筆記翻譯:反向傳播筆記