[Source Code Analysis] PyTorch Pipeline Parallelism Implementation (6) -- Parallel Computation

Posted by 羅西的思考 on 2021-10-10

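0x00 Abstract

The previous articles covered the basics of PyTorch pipeline parallelism, the automatic balancing mechanism, and how data is split. This article draws on the paper to look at how the pipeline itself is implemented.

Other articles in the pipeline parallelism series: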

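[Source Code Analysis] Deep Learning Pipeline Parallelism GPipe (1) --- Basic Pipeline Implementation

[Source Code Analysis] Deep Learning Pipeline Parallelism GPipe (2) --- Gradient Accumulation

[Source Code Analysis] Deep Learning Pipeline Parallelism GPipe (3) --- Recomputation

[Source Code Analysis] Deep Learning Pipeline Parallelism PipeDream (1) --- The Profile Stage

[Source Code Analysis] Deep Learning Pipeline Parallelism PipeDream (2) --- Computing Partitions

[Source Code Analysis] Deep Learning Pipeline Parallelism PipeDream (3) --- Converting the Model

[Source Code Analysis] Deep Learning Pipeline Parallelism PipeDream (4) --- Runtime Engine

[Source Code Analysis] Deep Learning Pipeline Parallelism PipeDream (5) --- Communication Module

[Source Code Analysis] Deep Learning Pipeline Parallelism PipeDream (6) --- The 1F1B Strategy

[Source Code Analysis] PyTorch Pipeline Parallelism Implementation (1) -- Basics

[Source Code Analysis] PyTorch Pipeline Parallelism Implementation (2) -- How to Partition the Model

[Source Code Analysis] PyTorch Pipeline Parallelism Implementation (3) -- Splitting Data and the Runtime System

[Source Code Analysis] PyTorch Pipeline Parallelism Implementation (4) -- Forward Computation

[Source Code Analysis] PyTorch Pipeline Parallelism Implementation (5) -- Computation Dependencies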

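The figures in this article come from the paper and the GitHub source code.

0x01 Overall Architecture

We start by walking through torchgpipe from a high-level perspective.

1.1 Usage

We use a test case from the source code for the analysis. The example contains a Sequential model made of three layers; after being wrapped by GPipe, it runs forward and backward propagation.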

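import torch
from torch import nn
from torchgpipe import GPipe
from torchgpipe.skip import pop, skippable, stash

@skippable(stash=['1to3'])  # the decorator lets forward() yield stash(...)
class Layer1(nn.Module):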
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, input):
        yield stash('1to3', input)
        output = self.conv(input)
        return output

class Layer2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, input):
        output = self.conv(input)
        return output

@skippable(pop=['1to3'])  # the decorator lets forward() yield pop(...)
class Layer3(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, input):
        skip_1to3 = yield pop('1to3')
        output = self.conv(input) + skip_1to3
        return output

# balance and checkpoint are supplied from outside this snippet (test parameters),
# for example balance=[1, 1, 1] and checkpoint='never'.
model = nn.Sequential(Layer1(), Layer2(), Layer3())  # build a Sequential
model = GPipe(model, balance, chunks=3, checkpoint=checkpoint)  # wrap the Sequential in GPipe

in_device = model.devices[0]
out_device = model.devices[-1]

input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
output = model(input)  # this calls GPipe.forward
loss = output.mean()
loss.backward()  # backpropagation runs here

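1.2 Forward Propagation

GPipe's forward pass performs the following steps:

  • Use the scatter function to split the input, that is, divide the mini-batch into micro-batches.
  • Use the _ensure_copy_streams method to create new CUDA streams for each device.
  • Create a Pipeline and run it.
  • When it finishes, use the gather method to merge the micro-batches back into one mini-batch.

So on every iteration, the forward call creates a Pipeline object, runs it, and returns the result to the caller.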

def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
    """:class:`GPipe` is a fairly transparent module wrapper. It doesn't
    modify the input and output signature of the underlying module. But
    there's type restriction. Input and output have to be a
    :class:`~torch.Tensor` or a tuple of tensors. This restriction is
    applied at partition boundaries too.

    Args:
        input (torch.Tensor or tensors): input mini-batch

    Returns:
        tensor or tensors: output mini-batch

    Raises:
        TypeError: input is not a tensor or tensors.

    """
    microbatch.check(input)

    if not self.devices:
        # Empty sequential module is not illegal.
        return input

    # Divide a mini-batch into micro-batches.
    batches = microbatch.scatter(input, self.chunks)

    # Separate CUDA streams for copy.
    copy_streams = self._ensure_copy_streams()

    # The micro-batch index where the checkpointing stops.
    if self.training:
        checkpoint_stop = {
            'always': self.chunks,
            'except_last': self.chunks-1,
            'never': 0,
        }[self.checkpoint]
    else:
        checkpoint_stop = 0

    # Run pipeline parallelism.
    pipeline = Pipeline(batches,
                        self.partitions,
                        self.devices,
                        copy_streams,
                        self._skip_layout,
                        checkpoint_stop)
    pipeline.run()

    # Merge the micro-batches into one mini-batch.
    output = microbatch.gather(batches)
    return output
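
The scatter and gather steps can be pictured with plain tensor operations. The sketch below is not torchgpipe's microbatch module. It assumes a single-tensor input, uses torch.chunk and torch.cat, and the helper names scatter_sketch and gather_sketch are hypothetical.

```python
import torch

def scatter_sketch(mini_batch, chunks):
    """Split a mini-batch into at most `chunks` micro-batches along dim 0."""
    return list(mini_batch.chunk(chunks, dim=0))

def gather_sketch(micro_batches):
    """Concatenate micro-batches back into one mini-batch along dim 0."""
    return torch.cat(micro_batches, dim=0)

mini_batch = torch.rand(30, 3, 8, 8)
micro_batches = scatter_sketch(mini_batch, chunks=3)
print([tuple(mb.shape) for mb in micro_batches])  # three micro-batches of 10 samples each
restored = gather_sketch(micro_batches)
```

With 30 samples and chunks=3, each micro-batch carries 10 samples, and gathering restores the original mini-batch.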

The _ensure_copy_streams method creates new CUDA streams for each device (one per micro-batch) and caches them:

def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
    """Ensures that :class:`GPipe` caches CUDA streams for copy.

    It's worth to cache CUDA streams although PyTorch already manages a
    pool of pre-allocated CUDA streams, because it may reduce GPU memory
    fragementation when the number of micro-batches is small.

    """
    if not self._copy_streams:
        for device in self.devices:
            self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])

    return self._copy_streams

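1.3 The Pipeline Class

The run method of the Pipeline class launches computation clock cycle by clock cycle, so the forward pass spreads through the schedule like ripples on water.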

def run(self) -> None:
    """Runs pipeline parallelism.

    It modifies the given batches in place.

    """
    batches = self.batches
    partitions = self.partitions
    devices = self.devices
    skip_layout = self.skip_layout

    m = len(batches)
    n = len(partitions)

    skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

    with spawn_workers(devices) as (in_queues, out_queues):
        for schedule in clock_cycles(m, n): # yields the execution plan for each clock cycle; the steps below follow it
            self.fence(schedule, skip_trackers) # copy data and set up dependencies
            self.compute(schedule, skip_trackers, in_queues, out_queues) # launch the tasks
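
The schedule that run iterates over can be reproduced in a few lines. This is a sketch of a clock_cycles-style generator with 0-based indices (the name clock_cycles_sketch is ours). Micro-batch i reaches partition j in clock cycle i + j.

```python
def clock_cycles_sketch(m, n):
    """Yield, for each clock cycle, the (micro-batch i, partition j) pairs to run.

    m is the number of micro-batches and n the number of partitions;
    both indices are 0-based here.
    """
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]

schedules = list(clock_cycles_sketch(4, 3))
for clock, schedule in enumerate(schedules):
    print(clock, schedule)
```

Every pair in cycle k satisfies i + j == k, which is why tasks within one cycle never depend on each other and can run on different devices at the same time.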

1.3.1 Building Dependencies

In Pipeline, the fence method (some code omitted) uses depend to build the dependencies for backpropagation.

    def fence(self,
              schedule: List[Tuple[int, int]],
              skip_trackers: List[SkipTrackerThroughPotals],
              ) -> None:
        """Copies micro-batches after computation for the previous
        micro-batches.
        """
        batches = self.batches
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0:
                depend(batches[i-1], batches[i]) # the backward dependency is established here
                
            next_stream = copy_streams[j][i]

            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j-1][i]
                # copy the micro-batch from the previous partition's copy stream
                copy(batches[i], prev_stream, next_stream)

1.3.2 Queue

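Workers and the main thread communicate through Python's Queue data structure. The Queue class implements a basic first-in, first-out (FIFO) container: put() appends an element to the tail of the queue, and get() removes an element from the head.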

A multi-producer, multi-consumer queue.

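The two key functions are:

  • get([block, [timeout]]) reads from the queue; timeout is the wait time. If the queue is empty, the call blocks.
  • put(item, [block, [timeout]]) writes to the queue; timeout is the wait time. If the queue is full, the call blocks.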
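
A short standard-library example makes the blocking behavior concrete. It uses a bounded queue of size 2 (our choice for illustration) and short timeouts so that the blocked calls give up instead of waiting forever.

```python
import queue

q = queue.Queue(maxsize=2)
q.put('a')
q.put('b')

# The queue is full now, so a put with a timeout gives up with queue.Full.
try:
    q.put('c', timeout=0.05)
    put_blocked = False
except queue.Full:
    put_blocked = True

first = q.get()   # FIFO: the oldest element comes out first
second = q.get()

# The queue is empty now, so a get with a timeout gives up with queue.Empty.
try:
    q.get(timeout=0.05)
    get_blocked = False
except queue.Empty:
    get_blocked = True

print(first, second, put_blocked, get_blocked)
```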

1.3.3 Computation

The actual training is carried out by the compute function.

def compute(self,
            schedule: List[Tuple[int, int]],
            skip_trackers: List[SkipTrackerThroughPotals],
            in_queues: List[InQueue],
            out_queues: List[OutQueue],
            ) -> None:
    """Runs tasks with synchronization to copy streams."""
    batches = self.batches
    partitions = self.partitions
    devices = self.devices
    copy_streams = self.copy_streams
    checkpoint_stop = self.checkpoint_stop

    n = len(partitions)
    streams = [current_stream(d) for d in devices]
    exc_info: Optional[ExcInfo] = None

    # With checkpointing, the autograd graph looks like this diagram:
    # ┌─────┸──────┐
    # │    Copy    │
    # └─────┰──────┘   (fence)
    # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
    #       ┃          (compute)
    # ┌─────┸──────┐
    # │    Wait    │ [1] Synchronize the current stream with the copy stream.
    # └─────┰──────┘
    # ┌─────┸──────┐
    # │ Checkpoint │ [2] Compute a partition within checkpointing.
    # └─────┰──────┘
    # ┌─────┸──────┐
    # │    Wait    │ [3] Synchronize the copy stream with the current stream.
    # └─────┰──────┘
    #       ┠ ─ ─ ─ ┐
    #       ┃ ┌─────┴─────┐
    #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
    #       ┃ └─────┬─────┘
    #       ┠ ─ ─ ─ ┘
    #       ┃
    # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
    # ┌─────┸──────┐   (fence)
    # │    Copy    │
    # └─────┰──────┘
    for i, j in schedule:
        batch = batches[i]
        partition = partitions[j]

        # Synchronize with the copied input. ([1] in the diagram)
        if j != 0: # wait for the copy to finish
            wait(batch, copy_streams[j][i], streams[j])

        # Determine whether checkpointing or not.
        checkpoint = (i < checkpoint_stop)
        if checkpoint:
            def function(input: TensorOrTensors,
                         partition: nn.Sequential = partition,
                         skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                         ) -> TensorOrTensors:
                with use_skip_tracker(skip_tracker):
                    return partition(input)

            chk = Checkpointing(function, batch)
            task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
            del function, chk

        else:
            def compute(batch: Batch = batch,
                        partition: nn.Sequential = partition,
                        skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                        ) -> Batch:
                with use_skip_tracker(skip_tracker):
                    return batch.call(partition)

            task = Task(streams[j], compute=compute, finalize=None)
            del compute

        # Compute tasks in parallel. ([2] in the diagram)
        in_queues[j].put(task) # hand the task to the worker so tasks run in parallel

    for i, j in schedule:
        # wait for the result
        ok, payload = out_queues[j].get()

        # Hold the first exception.
        if exc_info is not None:
            continue
        elif not ok:
            exc_info = cast(ExcInfo, payload)
            continue

        task, batch = cast(Tuple[Task, Batch], payload)

        # The copy stream synchronizes to copy the output. ([3] in the
        # diagram)
        if j != n-1: # make the copy stream wait so the output can be copied
            wait(batch, streams[j], copy_streams[j][i])

        # Finalize tasks. If checkpointing is enabled, here the
        # recomputation is scheduled at backpropagation. ([4] in the
        # diagram)
        with use_device(devices[j]):
            task.finalize(batch)

        batches[i] = batch

    # Fail at the first exception.
    if exc_info is not None:
        raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
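
Which micro-batches are checkpointed follows from checkpoint_stop and the test i < checkpoint_stop. The sketch below mirrors the mapping shown in GPipe.forward; the helper names are hypothetical and only illustrate the arithmetic.

```python
def checkpoint_stop_sketch(mode, chunks, training=True):
    """Return the micro-batch index at which checkpointing stops."""
    if not training:
        return 0  # no checkpointing in evaluation mode
    return {'always': chunks, 'except_last': chunks - 1, 'never': 0}[mode]

def checkpointed_microbatches(mode, chunks, training=True):
    """List the micro-batch indices i for which i < checkpoint_stop holds."""
    stop = checkpoint_stop_sketch(mode, chunks, training)
    return [i for i in range(chunks) if i < stop]

for mode in ('always', 'except_last', 'never'):
    print(mode, checkpointed_microbatches(mode, chunks=4))
```

With 'except_last', the final micro-batch skips checkpointing because its backward pass starts right after its forward pass, so recomputing it would save little.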

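To summarize the overall (parallel) logic:

  1. The system calls spawn_workers to create the workers.
    1. spawn_workers creates one Thread per device, whose target function is worker. It also creates one in_queue and one out_queue per device, which guarantees that operations on each device run serially.
    2. These queues are collected into (in_queues, out_queues), which is returned to the Pipeline main thread. From then on, (in_queues, out_queues) is the context used to pass information between tasks.
  2. Once the Pipeline main thread has (in_queues, out_queues), it uses the clock_cycles algorithm to generate a series of iterations, each of which is a schedule.
  3. For each iteration (schedule), fence first copies across streams and sets up dependencies, then compute runs the training step. This launches multiple compute calls in order.
  4. Each compute call walks the schedule and, for every (i, j), creates a Task, finds the in_queue of the corresponding device, and puts the Task into it.
  5. Each worker Thread blocks on its in_queue; when something arrives, it reads the Task and executes it. Within a compute call, putting a Task into a queue returns immediately, so all Tasks of the schedule are submitted before any result is collected, and the worker Threads waiting on the queues then run them in parallel.
  6. The worker Thread puts the result into its out_queue.
  7. The compute method takes the results out of out_queue and post-processes them.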
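
The steps above can be sketched with the standard library alone. This is not torchgpipe's spawn_workers or Task. It is a minimal stand-in in which each "device" is just a thread with its own pair of queues and each task is a plain callable.

```python
import queue
import threading

def worker(in_queue, out_queue):
    """Run tasks from in_queue one at a time until a None sentinel arrives."""
    while True:
        task = in_queue.get()
        if task is None:
            break
        try:
            out_queue.put((True, task()))
        except Exception as exc:  # report failures instead of killing the thread
            out_queue.put((False, exc))

num_devices = 3
in_queues = [queue.Queue() for _ in range(num_devices)]
out_queues = [queue.Queue() for _ in range(num_devices)]
threads = [threading.Thread(target=worker, args=(in_queues[j], out_queues[j]), daemon=True)
           for j in range(num_devices)]
for t in threads:
    t.start()

schedule = [(2, 0), (1, 1), (0, 2)]  # one clock cycle: (micro-batch i, partition j)

# Submit every task of the cycle first; put() returns immediately.
for i, j in schedule:
    in_queues[j].put(lambda i=i, j=j: f'batch {i} on partition {j}')

# Then collect the results in the same order.
results = []
for i, j in schedule:
    ok, payload = out_queues[j].get()
    results.append(payload)

for in_queue in in_queues:
    in_queue.put(None)  # stop the workers
for t in threads:
    t.join()
print(results)
```

Because each device has exactly one worker thread and one in_queue, work on a single device stays serial while different devices proceed concurrently.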

The following diagram shows the details.

         +-------------------------------------------------------------------+       +-----------------------------------------+
         | Pipeline                                                          |  1    | spawn_workers                           |
         |                                     spawn_workers(devices)  +-----------> |                                         |
         |                                                                   |       | +-------------------------------------+ |
         |               for schedule in clock_cycles(m, n)                  |       | | workers                             | |
         |                     +                                             |       | |                                     | |
         |                     | 2                                           |       | |                                     | |
         |                     |                                             |       | |  device 1 : in_queue 1, out_queue 1 | |
         |                     +-----------+---------------+                 |       | |                                     | |
         |                     |           |               |                 |       | |  device 2 : in_queue 2, out_queue 2 | |
         |                     v           v               v                 |       | |                                     | |
         |  +------------------+------+        +-----------+--------------+  |       | |  device 3 : in_queue 3, out_queue 3 | |
         |  | compute                 |        | compute                  |  |       | |                                     | |
         |  |                         |  3     |                          |  |       | |                                     | |
         |  |  in_queues[j].put(task) |        |   in_queues[j].put(task) |  |       | +-------------------------------------+ |
         |  |                         | ...... |                          |  |       |                                         |
         |  |  out_queues[j].get()    |        |   out_queues[j].get()    |  |       +-----------------------------------------+
         |  |                         |        |                          |  |
         |  +----------+---+----------+        +----------------+----+----+  |
         |             |   ^                                    ^    |       |
         |             |   |                                    |    |       |
         +-------------------------------------------------------------------+
                     7 |   | 4                                7 |    | 4
                       |   |                                    |    |
                       v   |                                    |    v
                 +-----+---+------------------------------------+----+-----+
                 |                in_queues        out_queues              |
+------------>   |                                                         |  <--------------------+
|                +-----+---------------------------------------------+-----+                       |
| 6                    |                                             |                           6 |
|                    5 |                                             | 5                           |
|                      |                                             |                             |
|                      |                                             |                             |
|    +-------------------------------------+          +-------------------------------------+      |
|    | Thread 1        |        device 1   |          | Thread 2     |             device 3 |      |
|    |                 |                   |          |              |                      |      |
|    | +---------------------------------+ |          | +---------------------------------+ |      |
|    | | Worker        |                 | |          | | Worker     |                    | |      |
|    | |               v                 | |          | |            v                    | |      |
|    | |  task = in_queue.get()          | |          | |   task = in_queue.get()         | |      |
|    | |                                 | |  ......  | |                                 | |      |
|    | |  batch = task.compute()         | |          | |   batch = task.compute()        | |      |
|    | |                                 | |          | |                                 | |      |
+--------+out_queue.put((task, batch)))  | |          | |   out_queue.put((task, batch))+--------->+
     | |                                 | |          | |                                 | |
     | +---------------------------------+ |          | +---------------------------------+ |
     +-------------------------------------+          +-------------------------------------+

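0x02 Concurrent Copy and Computation

We next analyze concurrent copy and computation (Concurrent Copy and Computation: Streams).

2.1 GPU Concurrent Operations

Let us first look at the concurrency features that GPUs provide.

A CUDA stream represents a queue of GPU operations, that is, a sequence of kernels bound to a device and executed in order. A stream can be thought of as one task on the GPU. The user appends operations to the stream's queue, and the GPU executes them in the order they were added. Within a single stream all operations are serialized, so they never run concurrently. For two operations to run concurrently, they must therefore be in different streams. Kernels in different streams can interleave and may even overlap.

Almost all CUDA devices with compute capability 1.1 or higher support concurrent copy and execution, also known as device overlap, which has the following properties:

  1. Data copies and computation can run concurrently.
  2. Copies in both directions can run concurrently (GPU to CPU, and CPU to GPU).
  3. A compute kernel must not read or write data that is being copied.

CPU memory is generally larger than GPU memory, so it is often impossible to copy everything from CPU memory to the GPU at once; the data has to be transferred in chunks. Device overlap can therefore improve the efficiency of GPU programs considerably, for example:

  1. Split the data into many chunks and assign each chunk to a stream.
  2. Each stream performs the following operations:
    1. Copy the stream's data from host memory to device memory.
    2. Run the kernel on the GPU and store the result in GPU memory.
    3. Copy the stream's result from device memory back to host memory.
  3. The GPU scheduler decides how the streams run concurrently.
  4. CPU work can proceed at the same time.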
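
The chunk-per-stream pattern above can be sketched in PyTorch. This is an illustration under assumptions, not code from torchgpipe: the function name process_in_chunks and the "multiply by 2" stand-in kernel are ours, and on a machine without CUDA the function simply falls back to computing on the CPU.

```python
import torch

def process_in_chunks(data, num_chunks=4):
    """Double `data`, overlapping copies and compute across CUDA streams if available."""
    if not torch.cuda.is_available():
        return data * 2  # CPU fallback: nothing to overlap

    device = torch.device('cuda')
    host_in = data.pin_memory()                    # pinned memory allows async copies
    host_out = torch.empty_like(data).pin_memory()
    streams = [torch.cuda.Stream(device) for _ in range(num_chunks)]

    chunks_in = host_in.chunk(num_chunks)
    chunks_out = host_out.chunk(num_chunks)
    for stream, src, dst in zip(streams, chunks_in, chunks_out):
        with torch.cuda.stream(stream):
            on_gpu = src.to(device, non_blocking=True)  # host -> device copy
            on_gpu = on_gpu * 2                         # stand-in for the real kernel
            dst.copy_(on_gpu, non_blocking=True)        # device -> host copy

    torch.cuda.synchronize(device)  # wait for every stream before reading host_out
    return host_out

result = process_in_chunks(torch.arange(8.0))
print(result)
```

Whether the chunks actually overlap is up to the GPU scheduler and the hardware's copy engines; the code only makes the overlap possible by placing independent work on different streams.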

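2.2 PyTorch

Unless told otherwise, PyTorch issues every kernel bound to a device to the default stream. Because the forward pass runs on the default stream, overlapping "prefetching the next batch (copy from CPU to GPU)" with "the forward pass of the current batch" requires:

  • The batch data on the CPU must be pinned. Page-locking lets the hardware access CPU memory directly, which avoids some copies; locked pages cannot be swapped out to disk. Memory allocated on the GPU is page-locked by default.
  • The prefetch must run on a different stream.

torchgpipe registers every copy kernel on a non-default stream while keeping compute kernels on the default stream. This allows device \(j\) to process \(F_{i,j}\) while it sends \(x^j_{i-1}\) to device \(j+1\) and/or receives \(x_i^{j-1}\) from device \(j-1\).

In addition, each device uses a different stream for each micro-batch. Since there is no true dependency between different micro-batches, using streams this way is safe, and it lets copies happen as early as possible. See the figure below.

The figure shows the timeline of device j with and without non-default streams for copying:

  • Part (a): with only the default stream, a copy kernel may block compute kernels (and vice versa) until the copy has completely finished.

  • Part (b): with copy streams, computation can proceed while data is being sent to or received from other devices.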

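2.3 Stream Wrappers

Because it operates on streams, torchgpipe provides some basic wrappers around the low-level stream operations. The main stream-related code is in torchgpipe/stream.py.

2.3.1 A PyTorch Example

torchgpipe relies on wait_stream and record_stream, but there is little material about them online, and digging into the relevant parts of CUDA or PyTorch can take a lot of time. So we look at the code in torch/nn/parallel/distributed.py to see how they are used:

  • wait_stream waits: one stream waits for another stream to finish.
  • record_stream protects: it guarantees that the tensor's memory is not reused before the operations finish. Combined with other material, we can read it as ensuring that the work recorded on a given stream completes before the tensor's memory is released for reuse.

The code is as follows: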

# Perform CPU -> GPU copies in a background stream. This code is
# motivated from similar logic in torch/nn/parallel/_functions.py
stream = _get_stream(target_gpu)
with torch.cuda.stream(stream):
    output = obj.to(target_gpu) # copy
# synchronize with the copy stream
with torch.cuda.device(target_gpu):
    current_stream = torch.cuda.current_stream()
    # Sync the current stream with the copy stream
    current_stream.wait_stream(stream) # wait
    # Ensure tensor memory is not reused until work on main stream is complete
    output.record_stream(current_stream) # protect
return (output,)

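2.3.2 Creating and Retrieving Streams

The functions for creating and retrieving streams are:

  • new_stream creates a new stream.

  • current_stream returns the current stream.

  • default_stream returns the default stream.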

def new_stream(device: torch.device) -> AbstractStream:
    """Creates a new stream for either CPU or CUDA device."""
    if device.type != 'cuda':
        return CPUStream
    return torch.cuda.Stream(device)

def current_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.current_stream` for either CPU or CUDA device."""
    if device.type != 'cuda':
        return CPUStream
    return torch.cuda.current_stream(device)

def default_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.default_stream` for either CPU or CUDA device."""
    if device.type != 'cuda':
        return CPUStream
    return torch.cuda.default_stream(device)

2.3.3 Recording

The following method wraps CUDA's record_stream.

def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
    """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream."""
    if is_cuda(stream):
        # NOTE(sublee): record_stream() on a shifted view tensor throws
        # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
        # protect the tensor against unexpected reallocation, here we use a
        # temporal tensor associated with the same storage without shifting as
        # a workaround.
        #
        # Issue: https://github.com/pytorch/pytorch/issues/27366
        #
        tensor = tensor.new_empty([0]).set_(tensor.storage())

        tensor.record_stream(as_cuda(stream))

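2.3.4 Waiting

The following method wraps CUDA's wait_stream.

  • If both streams are CUDA streams, one stream waits for the other to finish.
  • Otherwise, synchronize() is used to make the CPU wait for CUDA to finish.

Stream operations are asynchronous, so when a function returns there is no way to know whether the operation has completed. The CPU (host) is therefore synchronized with the GPU, or CUDA streams are synchronized with each other, to make sure the GPU has finished the stream's work.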

def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
    """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It
    makes the source stream wait until the target stream completes work queued.
    """
    if is_cuda(target):
        if is_cuda(source):
            # A CUDA stream waits another CUDA stream.
            as_cuda(source).wait_stream(as_cuda(target))
        else:
            # CPU waits a CUDA stream.
            as_cuda(target).synchronize()

    # If the target is CPU, synchronization is not required.

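Both wait_stream and synchronize end up waiting. For example, the synchronize method of a CUDA stream ultimately calls cudaStreamSynchronize, which blocks the calling CPU thread until the GPU has finished all work previously queued on that stream (kernels, data copies, and so on).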

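Now that the stream operations have these basic wrappers, torchgpipe uses them to implement the copy and wait operations, which we look at next.

2.4 The Copy API

The API for copying between streams is shown below; it simply calls the forward method of the Copy class.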

def copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Copy.apply(prev_stream, next_stream, *batch)

Copy extends torch.autograd.Function and mainly uses record_stream to help carry out the copy.

class Copy(torch.autograd.Function):
    """Copies tensors on specific streams."""
    @staticmethod
    def forward(ctx: Context,  # type: ignore
                prev_stream: AbstractStream,
                next_stream: AbstractStream,
                *input: Tensor,
                ) -> Tensors:
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        output = []
        output_stream = current_stream(get_device(next_stream)) # current stream on the destination device

        with use_stream(prev_stream), use_stream(next_stream):
            for x in input:
                y = x.to(get_device(next_stream)) # copy the input to next_stream's device
                output.append(y)

                # 'prev_stream' is not where 'x' has been allocated.
                record_stream(x, prev_stream) # record the stream so x's memory is not reused before the copy finishes
                # 'y' has been allocated on 'next_stream'.
                # It might be used on the current stream captured as 'output_stream'.
                record_stream(y, output_stream) # record the stream so y's memory is not reused while output_stream still uses it

        return tuple(output) # return the outputs

    @staticmethod
    def backward(ctx: Context,
                 *grad_output: Tensor,
                 ) -> Tuple[Optional[Tensor], ...]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        grad_input: Deque[Tensor] = deque(maxlen=len(grad_output))
        input_stream = current_stream(get_device(prev_stream))

        with use_stream(prev_stream), use_stream(next_stream):
            for x in reversed(grad_output):
                y = x.to(get_device(prev_stream))
                grad_input.appendleft(y)

                # 'next_stream' is not where 'x' has been allocated.
                record_stream(x, next_stream)
                # 'y' has been allocated on 'prev_stream'.
                # It might be used on the current stream captured as 'input_stream'.
                record_stream(y, input_stream)

        grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
        return grad_streams + tuple(grad_input)

2.5 The Wait API

wait calls the forward method of the Wait class.

def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Wait.apply(prev_stream, next_stream, *batch)

Wait also extends torch.autograd.Function; it uses wait_stream to make one stream wait for another to finish.

class Wait(torch.autograd.Function):
    """Synchronizes a stream to another stream.

    Place it just before you want to start an operation on the next stream,
    provided that all operations on the previous stream are done.

    """
    @staticmethod
    def forward(ctx: Context,  # type: ignore
                prev_stream: AbstractStream,
                next_stream: AbstractStream,
                *input: Tensor,
                ) -> Tensors:
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        wait_stream(next_stream, prev_stream)

        return tuple(x.detach() for x in input)

    @staticmethod
    def backward(ctx: Context,
                 *grad_input: Tensor,
                 ) -> Tuple[Optional[Tensor], ...]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        wait_stream(prev_stream, next_stream)

        grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
        return grad_streams + grad_input
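
Copy and Wait are written as autograd Functions so that the same stream operation is replayed, mirrored, during backpropagation. The CPU-only sketch below shows that mechanism with a hypothetical Trace function that does no stream work at all; it only records when its forward and backward are called.

```python
import torch

events = []

class Trace(torch.autograd.Function):
    """Identity function that logs when autograd calls forward and backward."""

    @staticmethod
    def forward(ctx, name, *inputs):
        ctx.name = name
        events.append(f'forward:{name}')   # Wait.forward would call wait_stream here
        return tuple(x.detach() for x in inputs)

    @staticmethod
    def backward(ctx, *grads):
        events.append(f'backward:{ctx.name}')  # Wait.backward waits in the opposite direction
        return (None,) + grads  # no gradient for the non-tensor `name` argument

x = torch.ones(3, requires_grad=True)
(y,) = Trace.apply('a', x)
(z,) = Trace.apply('b', y * 2)
z.sum().backward()
print(events)
print(x.grad)
```

The backward hooks fire in the reverse order of the forward hooks, which is what lets torchgpipe insert synchronization points into the backward pass simply by placing Wait and Copy in the forward graph.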

2.6 Usage

2.6.1 Overview

We first show the flow diagram from the code comments to give an overall picture.

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘

2.6.2 Building Copy Streams

The GPipe class creates the dedicated copy streams.

    def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore

        ......

        # Separate CUDA streams for copy.
        copy_streams = self._ensure_copy_streams() # the dedicated copy streams are created here

        # The micro-batch index where the checkpointing stops.

        # Run pipeline parallelism.
        pipeline = Pipeline(batches,
                            self.partitions,
                            self.devices,
                            copy_streams,
                            self._skip_layout,
                            checkpoint_stop)
        pipeline.run()

        ...... 

The code of _ensure_copy_streams is shown below. It creates a dedicated stream for every micro-batch on every device:

    def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
        """Ensures that :class:`GPipe` caches CUDA streams for copy.

        It's worth to cache CUDA streams although PyTorch already manages a
        pool of pre-allocated CUDA streams, because it may reduce GPU memory
        fragementation when the number of micro-batches is small.

        """
        if not self._copy_streams:
            for device in self.devices:
                self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])

        return self._copy_streams

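Suppose there are 3 devices, the model is split into 3 sub-networks, and the mini-batch is split into 4 micro-batches. The layout is then as follows:

In _copy_streams[i][j], i is the index of the device and j is the index of the batch. This order matters and will come up again shortly.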

                  +----------------------------------+
                  | _copy_streams                    |
                  |                                  |
                  |     +----------------------+     |
                  |     |                      |     |
                  |     |  [1,1] [1,2] [1,3]+--------------------------------+
                  |     |                      |     |                       |
                  |     |  [2,1] [2,2] [2,3]+------------------------------------------+
                  |     |                      |     |                       |         |
+-------------------------+[3,1] [3,2] [3,3]   |     |                       |         |
|                 |     |                      |     |                       |         |
|                 |     +----------------------+     |                       |         |
|                 |                                  |                       |         |
|                 +----------------------------------+                       |         |
|                                                                            |         |
|                                                                            v         |
|   +------------------------------------------------------------------------+------+  |
|   | Stream of device 1, Stream of device 1, Stream of device 1, Stream of device 1|  |
|   +-------------------------------------------------------------------------------+  |
|                                                                                      |
|   +-------------------------------------------------------------------------------+  |
|   | Stream of device 2, Stream of device 2, Stream of device 2, Stream of device 2+<-+
|   +-------------------------------------------------------------------------------+
|
|   +-------------------------------------------------------------------------------+
+-->+ Stream of device 3, Stream of device 3, Stream of device 3, Stream of device 3|
    +-------------------------------------------------------------------------------+
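
The same layout can be written out as a nested list. In this sketch the streams are replaced by descriptive strings and the device names are hypothetical. The point is only the indexing: the outer index is the device and the inner index is the micro-batch, so fence looks streams up as copy_streams[j][i].

```python
devices = ['cuda:0', 'cuda:1', 'cuda:2']  # hypothetical device names
chunks = 4                                 # number of micro-batches

# Outer index: device (partition) j; inner index: micro-batch i.
copy_streams = [[f'copy stream {i} on {device}' for i in range(chunks)]
                for device in devices]

def streams_for(i, j):
    """Return the (previous, next) copy streams that fence would use for (i, j)."""
    next_stream = copy_streams[j][i]
    prev_stream = copy_streams[j - 1][i] if j != 0 else None
    return prev_stream, next_stream

print(len(copy_streams), len(copy_streams[0]))
print(streams_for(1, 2))
```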


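2.6.3 Concurrent Operations

Let us use an example to see how operations run concurrently; the key is how the streams below are used.

The run method of the Pipeline class contains the following code, which makes the concurrency possible: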

def run(self) -> None:
    with spawn_workers(devices) as (in_queues, out_queues):
        for schedule in clock_cycles(m, n):
            self.fence(schedule, skip_trackers)
            self.compute(schedule, skip_trackers, in_queues, out_queues)

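Before each computation, fence copies the data from the previous device to the next device.

2.6.4 Copying Ahead

The fence method performs the copy ahead of time. It does the following:

  • Set up dependencies, which we analyzed in a previous article.
  • Get the copy stream of the next device.
  • Get the copy stream of the previous device.
  • Copy from the previous stream to the next stream.
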
    def fence(self,
              schedule: List[Tuple[int, int]],
              skip_trackers: List[SkipTrackerThroughPotals],
              ) -> None:
        """Copies micro-batches after computation for the previous
        micro-batches.
        """
        batches = self.batches
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0:
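                depend(batches[i-1], batches[i]) # set up the dependency

            next_stream = copy_streams[j][i] # copy stream of the next device; note the index order is reversed relative to the loop's i, j

            for prev_j, ns, name in skip_layout.copy_policy(j): # not analyzed here for reasons of space
                prev_stream = copy_streams[prev_j][i] # copy stream of the device that holds the skip tensor
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j-1][i] # copy stream of the previous device
                copy(batches[i], prev_stream, next_stream) # copy from the previous stream to the next stream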

Following the example from the earlier article, here is the code that generates the schedules.

m=4 # m: number of micro-batches
n=3 # n: number of partitions
for k in range(m + n - 1):
    print( [(k - j + 1 , j +1 ) for j in range(max(1 + k - m, 0), min(1 + k, n))] )

The output is:
[(1, 1)]                  # plan & data for clock cycle 1
[(2, 1), (1, 2)]          # plan & data for clock cycle 2
[(3, 1), (2, 2), (1, 3)]  # plan & data for clock cycle 3
[(4, 1), (3, 2), (2, 3)]  # plan & data for clock cycle 4
[(4, 2), (3, 3)]          # plan & data for clock cycle 5
[(4, 3)]                  # plan & data for clock cycle 6

These 6 cycles correspond to the timeline below: (1,1) enters the system in the first clock cycle, (2,1) enters in the second, and so on.

           +          +          +          +          +          +          +
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
 cuda:0    |  (1,1)   |   (2,1)  |  (3,1)   |   (4,1)  |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
 cuda:1    |          |   (1,2)  |  (2,2)   |   (3,2)  |  (4,2)   |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
 cuda:2    |          |          |  (1,3)   |   (2,3)  |  (3,3)   |  (4,3)   |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           | clock 1  |  clock 2 |  clock 3 |  clock 4 |  clock 5 |  clock 6 |
           +          +          +          +          +          +          +

+------------------------------------------------------------------------------>  Time
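To make the diagram above concrete, here is a minimal sketch that regroups the same schedule by device. The helper names `clock_cycles_1based` and `timeline_by_device` are illustrative and not part of torchgpipe; the indices are 1-based only to match the printed output above.

```python
from typing import Dict, List, Optional, Tuple

Task = Tuple[int, int]  # (micro-batch i, partition j), both 1-based


def clock_cycles_1based(m: int, n: int) -> List[List[Task]]:
    """Return, for each clock cycle, the (micro-batch, partition) pairs that run."""
    return [
        [(k - j + 1, j + 1) for j in range(max(1 + k - m, 0), min(1 + k, n))]
        for k in range(m + n - 1)
    ]


def timeline_by_device(m: int, n: int) -> Dict[int, List[Optional[Task]]]:
    """Map each partition j to its row in the diagram; None marks an idle slot."""
    cycles = clock_cycles_1based(m, n)
    rows: Dict[int, List[Optional[Task]]] = {
        j: [None] * len(cycles) for j in range(1, n + 1)
    }
    for clock, tasks in enumerate(cycles):
        for i, j in tasks:
            rows[j][clock] = (i, j)
    return rows


rows = timeline_by_device(m=4, n=3)
for j, row in rows.items():
    # Partition j runs on device cuda:(j-1) in the diagram.
    print(f"cuda:{j - 1}", row)
```

Each row contains m busy slots and n - 1 idle ones, which is why the whole run takes m + n - 1 clock cycles.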

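Let us walk through this plan, focusing on the work done in clock cycle 3.

Clock cycle 2 completed the following:

[(2, 1), (1, 2)]         # clock cycle 2: schedule & data

The plan for clock cycle 3 is:

[(3, 1), (2, 2), (1, 3)] # clock cycle 3: schedule & data

That is, for every i, j in the schedule, data is copied from copy_streams[j-1][i] to copy_streams[j][i].

Note, as mentioned earlier, that in _copy_streams[i][j], i is the device index and j is the batch index, which is exactly the reverse of the i, j in schedule.

So in our example, the copies in clock cycle 3 are as follows (the loop unpacks the schedule as i, j and then indexes the streams as [j][i]; because the stream array is itself laid out in the reverse order of the schedule, the two reversals cancel out and the indices line up):

  • (3, 1): new data enters device 1, so no copy is needed.
  • (2, 2): the copy goes from (2,1) to (2,2).
  • (1, 3): the copy goes from (1,2) to (1,3).

As the figure below shows, these copies can run in parallel, and because the copy streams are separate from the default stream that runs the computation, they can also overlap with computation.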

         +             +            +             +            +            +             +
         |             |            |             |            |            |             |
 cuda:0  |    (1,1)    |   (2,1)    |   (3,1)     |   (4,1)    |            |             |
         |             |     +      |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |     +------------+       |            |            |             |
         |             |            |     |       |            |            |             |
         |             |            |     |       |            |            |             |
         |             |            |     |       |            |            |             |
         |             |            |     v       |            |            |             |
         |             |            |             |            |            |             |
 cuda:1  |             |   (1,2)    |   (2,2)     |   (3,2)    |  (4,2)     |             |
         |             |     +      |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |     +-----------+        |            |            |             |
         |             |            |    |        |            |            |             |
         |             |            |    |        |            |            |             |
         |             |            |    |        |            |            |             |
         |             |            |    v        |            |            |             |
 cuda:2  |             |            |   (1,3)     |   (2,3)    |  (3,3)     |     (4,3)   |
         |             |            |             |            |            |             |
         |             |            |             |            |            |             |
         |             |            |             |            |            |             |
         |   clock 1   |  clock 2   |   clock 3   |  clock 4   |  clock 5   |     clock 6 |
         +             +            +             +            +            +             +

+----------------------------------------------------------------------------------->  Time
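The copy rule illustrated above can be sketched in a few lines: for every `(i, j)` in a clock cycle's schedule except the first partition, the output of micro-batch `i` moves from partition `j-1` to partition `j`. The function `copies_for_cycle` is a hypothetical helper written for this illustration; it uses the 1-based indices of the example, so the "no copy" case is `j == 1` rather than the `j != 0` check in the real code.

```python
from typing import List, Tuple

Task = Tuple[int, int]  # (micro-batch i, partition j), 1-based as in the example


def copies_for_cycle(schedule: List[Task]) -> List[Tuple[Task, Task]]:
    """List the (source, destination) pairs copied before a cycle's compute step."""
    copies = []
    for i, j in schedule:
        if j == 1:
            # The first partition receives fresh input, nothing to copy.
            continue
        copies.append(((i, j - 1), (i, j)))
    return copies


third_cycle = [(3, 1), (2, 2), (1, 3)]
for src, dst in copies_for_cycle(third_cycle):
    print(f"copy {src} -> {dst}")
```

The pairs are independent of one another, which is what allows the real implementation to issue them on separate copy streams.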

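2.6.5 Compute

compute performs the following steps:

  • First, wait(batch, copy_streams[j][i], streams[j]) makes the compute stream wait on the copy stream, ensuring the copy has finished.
  • Next, the computation runs.
  • Finally, wait(batch, streams[j], copy_streams[j][i]) makes the copy stream wait on the compute stream, ensuring the computation has finished before its result is copied.

The code is as follows: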

    def compute(self,
                schedule: List[Tuple[int, int]],
                skip_trackers: List[SkipTrackerThroughPotals],
                in_queues: List[InQueue],
                out_queues: List[OutQueue],
                ) -> None:
        """Runs tasks with synchronization to copy streams."""
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                wait(batch, copy_streams[j][i], streams[j])

            # Determine whether checkpointing or not.
            checkpoint = (i < checkpoint_stop)
            if checkpoint:
                def function(input: TensorOrTensors,
                             partition: nn.Sequential = partition,
                             skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                             ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker):
                        return partition(input)

                chk = Checkpointing(function, batch)
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk

            else:
                def compute(batch: Batch = batch,
                            partition: nn.Sequential = partition,
                            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                            ) -> Batch:
                    with use_skip_tracker(skip_tracker):
                        return batch.call(partition)

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            in_queues[j].put(task)

        # Collect the results here; this is where synchronization happens
        for i, j in schedule:
            ok, payload = out_queues[j].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n-1:
                wait(batch, streams[j], copy_streams[j][i]) # synchronization happens here

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
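The queue handshake in compute can be illustrated without GPUs. The sketch below is a simplified stand-in, not torchgpipe's worker code: `worker`, `run_cycle`, and the plain-callable tasks are hypothetical names chosen for this example. It keeps two ideas from the listing above: every task of a cycle is submitted before any result is collected, and only the first exception is held and raised after all results have been drained.

```python
import sys
import threading
from queue import Queue
from typing import Callable, Dict, List, Tuple


def worker(in_queue: Queue, out_queue: Queue) -> None:
    """Run tasks from in_queue until None arrives; report (ok, payload)."""
    while True:
        task = in_queue.get()
        if task is None:
            break
        try:
            out_queue.put((True, task()))
        except Exception:
            out_queue.put((False, sys.exc_info()))


def run_cycle(schedule: List[Tuple[int, int]],
              make_task: Callable[[int, int], Callable[[], object]],
              n: int) -> Dict[Tuple[int, int], object]:
    """Run one clock cycle on n worker threads (one per partition)."""
    in_queues: List[Queue] = [Queue() for _ in range(n)]
    out_queues: List[Queue] = [Queue() for _ in range(n)]
    threads = [threading.Thread(target=worker, args=(in_queues[j], out_queues[j]))
               for j in range(n)]
    for t in threads:
        t.start()

    # Submit everything first so the partitions work in parallel.
    for i, j in schedule:
        in_queues[j].put(make_task(i, j))

    results: Dict[Tuple[int, int], object] = {}
    exc_info = None
    for i, j in schedule:
        ok, payload = out_queues[j].get()  # always drain, even after a failure
        if exc_info is not None:
            continue
        if not ok:
            exc_info = payload  # hold the first exception
            continue
        results[(i, j)] = payload

    for q in in_queues:
        q.put(None)  # tell each worker to stop
    for t in threads:
        t.join()

    if exc_info is not None:
        raise exc_info[1].with_traceback(exc_info[2])
    return results


schedule = [(2, 0), (1, 1), (0, 2)]  # 0-based (micro-batch, partition)
out = run_cycle(schedule, lambda i, j: (lambda: (i, j, "done")), n=3)
print(out)
```

Draining every output queue before raising mirrors the real loop, which keeps calling out_queues[j].get() after an error so that no worker is left with an unread result.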


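0x03 Recomputation

Next we look at recomputation, which the paper covers in the section "Autograd Functions with Shared Memory".

Since a similar part was covered in the earlier GPipe articles, it is included here only for completeness, so the analysis is brief.

3.1 Analysis

So far, this section has not discussed how to schedule the recomputation task \(F^{'}_{i,j}\) when gradient checkpointing is used. With checkpointing, it must be scheduled before the backpropagation task \(B_{i,j}\) and after \(B_{i+1,j}\) completes. This ordering must be encoded in the autograd engine and in the computation graph. PyTorch supports this through an internal autograd function that implements checkpointing.

A checkpoint in PyTorch is implemented by defining an autograd function that computes like an ordinary function in the forward pass but stores the input instead of the intermediate activations. In the backward pass, this function rebuilds a local computation graph for backpropagation by recomputing itself from the stored input, and computes gradients by backpropagating through that local graph.

However, this binds \(F^{'}_{i,j}\) and \(B_{i,j}\) tightly together. We would like to insert instructions between \(F^{'}_{i,j}\) and \(B_{i,j}\) that wait for the result of \(B_{i,j+1}\), namely \(dx^j_i\), to be copied from device \(j+1\) to device \(j\). This allows \(F^{'}_{i,j}\) and the copy to happen concurrently.

For this kind of fine-grained ordering control, torchgpipe implements checkpointing with two separate autograd functions, Checkpoint and Recompute. While the forward task \(F_{i,j}\) executes, a pair of Checkpoint and Recompute that share memory is generated. The shared memory is used in the backward pass to hand the local computation graph produced by running Recompute over to Checkpoint for backpropagation.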

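By arranging these functions, each backward pass performs:

  • \(F^{'}_{i,j}\)
  • a synchronization that receives \(dx^j_i\)
  • \(B_{i,j}\)

Because these three operations run in this order, recomputation and the copy can take place concurrently.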
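As a rough illustration of why the split matters, the following toy sketch mimics the three backward steps without PyTorch. Everything in it is hypothetical: `ToyCheckpointing`, the scalar function y = (w * x) ** 2, and the `log` list merely stand in for the real autograd functions. The point is only that once recomputation and backpropagation are separate steps joined by a one-slot shared buffer, something else (here a logged "wait for copy") can be placed between them.

```python
from collections import deque
from typing import Deque, List


class ToyCheckpointing:
    """Toy analogue of a Checkpoint/Recompute pair for y = (w * x) ** 2."""

    def __init__(self, w: float, x: float, log: List[str]) -> None:
        self.w = w
        self.x = x  # only the input is kept, not the intermediate w * x
        self.log = log
        # One-slot shared memory: recompute() fills it, backward() drains it.
        self.recomputed: Deque[float] = deque(maxlen=1)

    def forward(self) -> float:
        self.log.append("F")
        h = self.w * self.x  # intermediate value, discarded after this call
        return h * h

    def recompute(self) -> None:
        self.log.append("F'")
        self.recomputed.append(self.w * self.x)  # rebuild the intermediate

    def backward(self, grad_output: float) -> float:
        self.log.append("B")
        h = self.recomputed.pop()
        return grad_output * 2.0 * h * self.w  # dy/dx = 2 * h * w


log: List[str] = []
chk = ToyCheckpointing(w=3.0, x=2.0, log=log)
y = chk.forward()
chk.recompute()              # F': free to overlap with the incoming gradient copy
log.append("wait for copy")  # stand-in for synchronizing with the copy stream
grad_x = chk.backward(1.0)   # B
print(y, grad_x, log)
```

With a single fused checkpoint function there would be no place to put the wait step, so the copy would have to finish before recomputation could even start.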

Let us look at the source code.

3.2 Wrapper API

torchgpipe/checkpoint.py contains a checkpoint function, which offers a simple external API.

def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
    """Makes a checkpoint with a simple interface like
    :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
    :class:`Checkpoint` and :class:`Recompute` without boilerplate.
    """
    batch = Batch(input)

    chk = Checkpointing(function, batch)
    batch = chk.checkpoint()
    chk.recompute(batch)

    return batch.tensor_or_tensors

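For usage, see tests/test_checkpoint.py. Its log entries show how checkpoint behaves during the forward and backward passes at run time.

The final timeline is ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"],

where the entries group into pairs that correspond to the forward pass, Checkpoint(Log[b]), and Checkpoint(Log[a]) respectively.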

@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []

    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    _ = a + 1 + 2 + 3 + 4 + 5

    # This means that backward will actually run "a:forward", "a:backward"
    a = checkpoint(partial(Log.apply, "a"), a)

    a, phony = fork(a)
    b = join(b, phony)

    # This means that backward will actually run "b:forward", "b:backward"
    b = checkpoint(partial(Log.apply, "b"), b)

    c = torch.cat((a, b))

    out = c.sum()

    #                        +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                        +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
    #    |----------------------|  |-----------------------|  |-----------------------|
    #          forward pass            Checkpoint(Log[b])         Checkpoint(Log[a])

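The checkpoint API calls Checkpointing, so let us look at its implementation.

It provides two methods, checkpoint and recompute, which apply the Checkpoint and Recompute classes respectively.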

class Checkpointing:
    """Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""

    def __init__(self, function: Function, batch: Batch) -> None:
        self.function = function
        self.batch = batch

        # Shared memory between Checkpoint and Recompute. 1-length deque is
        # used for mutability and length limitation.
        self.recomputed: Deque[Recomputed] = deque(maxlen=1)
        self.rng_states: Deque[RNGStates] = deque(maxlen=1)

    def checkpoint(self) -> Batch:
        """Returns a batch applied by :class:`Checkpoint`."""
        input_atomic = self.batch.atomic
        input = tuple(self.batch)

        # Use a phony which requires grad to ensure that Checkpoint can be
        # tracked by the autograd engine even when none of the input tensors
        # require grad.
        phony = get_phony(self.batch[0].device, requires_grad=True)

        output = Checkpoint.apply(phony, self.recomputed, self.rng_states,
                                  self.function, input_atomic, *input)
        return Batch(output)

    def recompute(self, batch: Batch) -> None:
        """Applies :class:`Recompute` to the batch in place."""
        input_atomic = self.batch.atomic
        input = tuple(self.batch)

        # batch[0] is always requiring grad, because it has been passed
        # checkpoint with a phony requiring grad.
        batch[0], phony = fork(batch[0])
        phony = Recompute.apply(phony, self.recomputed, self.rng_states,
                                self.function, input_atomic, *input)
        batch[0] = join(batch[0], phony)

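3.3 Implementation

Checkpoint and Recompute (shown below) split the ordinary checkpoint code into two stages (the forward function is split in two, and so is the backward function), which makes better use of the pipeline.

In the paper this corresponds to:

We would like to insert instructions between \(F^{'}_{i,j}\) and \(B_{i,j}\) that wait for the result of \(B_{i,j+1}\), namely \(dx^j_i\), to be copied from device \(j+1\) to device \(j\). This allows \(F^{'}_{i,j}\) and the copy to happen concurrently.

For this kind of fine-grained ordering control, torchgpipe implements checkpointing with two separate autograd functions, Checkpoint and Recompute. While the forward task \(F_{i,j}\) executes, a pair of Checkpoint and Recompute that share memory is generated. The shared memory is used in the backward pass to hand the local computation graph produced by running Recompute over to Checkpoint for backpropagation.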

3.3.1 Checkpoint

class Checkpoint(torch.autograd.Function):
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> TensorOrTensors:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        # Save the RNG states
        save_rng_states(input[0].device, ctx.rng_states)

        ctx.function = function
        ctx.input_atomic = input_atomic
        # Prepare for backpropagation; not actually used at present
        ctx.save_for_backward(*input)

        # Run the forward computation
        with torch.no_grad(), enable_checkpointing():
            output = function(input[0] if input_atomic else input)

        return output

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
        # Pop the values stored by the recomputation
        output, input_leaf = ctx.recomputed.pop() 

        if isinstance(output, tuple):
            tensors = output
        else:
            tensors = (output,)
            
        if any(y.requires_grad for y in tensors):
            tensors = tuple([x for x in tensors if x.requires_grad])
            # Run automatic differentiation
            torch.autograd.backward(tensors, grad_output)

        grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
        grad_input.extend(x.grad for x in input_leaf)
        return tuple(grad_input)

3.3.2 Recompute

Recompute recomputes the intermediate variables from the saved information.

class Recompute(torch.autograd.Function):
  
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> Tensor:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        ctx.function = function
        ctx.input_atomic = input_atomic
        ctx.save_for_backward(*input)

        return phony

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:  
        input = ctx.saved_tensors
        input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)

        # Restore the saved RNG states and rerun the forward computation to obtain the intermediate variables
        with restore_rng_states(input[0].device, ctx.rng_states):
            with torch.enable_grad(), enable_recomputing():
                output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)

        # Store the results for Checkpoint to use
        ctx.recomputed.append((output, input_leaf))

        grad_input: List[None] = [None, None, None, None, None]
        grad_input.extend(None for _ in ctx.saved_tensors)
        return tuple(grad_input)
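The reason Checkpoint saves the RNG states and Recompute restores them can be shown with Python's standard `random` module as a stand-in for the torch generators. This is an analogy under that assumption, not the torchgpipe implementation; `noisy_layer` and `saved_states` are illustrative names. Without the restore, a dropout-like layer would draw a different mask during recomputation, and the recomputed activations would no longer match the ones produced in the original forward pass.

```python
import random
from collections import deque
from typing import Deque, List


def noisy_layer(xs: List[float]) -> List[float]:
    """Dropout-like layer: keep each element with probability 0.5."""
    return [x if random.random() < 0.5 else 0.0 for x in xs]


# One-slot buffer shared by the "forward" and "recompute" phases.
saved_states: Deque[object] = deque(maxlen=1)
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Forward phase: remember the RNG state, then compute.
saved_states.append(random.getstate())
first = noisy_layer(xs)

# Unrelated random work happens before backpropagation.
random.random()

# Recompute phase: temporarily restore the saved state, then put back the
# current one so later code is unaffected.
current = random.getstate()
random.setstate(saved_states.pop())
try:
    second = noisy_layer(xs)
finally:
    random.setstate(current)

print(first == second)
```

The try/finally plays the role of the restore_rng_states context manager in the listing above: the saved state is only in effect for the duration of the recomputation.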

3.4 Overall call

The overall calling code is as follows:

    def compute(self,
                schedule: List[Tuple[int, int]],
                skip_trackers: List[SkipTrackerThroughPotals],
                in_queues: List[InQueue],
                out_queues: List[OutQueue],
                ) -> None:
        """Runs tasks with synchronization to copy streams."""
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                wait(batch, copy_streams[j][i], streams[j])

            # Determine whether checkpointing or not.
            checkpoint = (i < checkpoint_stop)
            if checkpoint:
                def function(input: TensorOrTensors,
                             partition: nn.Sequential = partition,
                             skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                             ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker):
                        return partition(input)

                chk = Checkpointing(function, batch)
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk

            else:
                def compute(batch: Batch = batch,
                            partition: nn.Sequential = partition,
                            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                            ) -> Batch:
                    with use_skip_tracker(skip_tracker):
                        return batch.call(partition)

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            in_queues[j].put(task)

        for i, j in schedule:
            ok, payload = out_queues[j].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n-1:
                wait(batch, streams[j], copy_streams[j][i])

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])


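This concludes the analysis of PyTorch pipeline parallelism. The next plan is to review PyTorch parallel training systematically, starting with the fundamentals of its gradient machinery. Stay tuned.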

