[Source Code Analysis] How Facebook Trains Very Large Models --- (5)

羅西的思考 (Luo Xi's Thoughts), published 2022-01-26


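0x00 Abstract

As described in earlier posts, Microsoft's ZeRO can scale a trillion-parameter model across 4,096 NVIDIA A100 GPUs using 8-way model parallelism, 64-way pipeline parallelism, and 8-way data parallelism. FSDP (Fully Sharded Data Parallel) is Facebook's upgrade of PyTorch DDP. It borrows heavily from ZeRO and can be seen as its counterpart. At its core it is parameter sharding, that is, partitioning model parameters and related state across GPUs. This series studies the topic through papers, blog posts, and code from Google, Microsoft, and Facebook.

The previous post covered FSDP's support for mixed-precision training. This post looks at activation recomputation.

Other posts in this series:

[Source Code Analysis] PyTorch Distributed: ZeroRedundancyOptimizer

[Paper Translation] Distributed Training, Parameter Sharding: ZeRO

[Paper Translation] Distributed Training, Parameter Sharding: Google Weight Sharding

[Source Code Analysis] How Facebook Trains Very Large Models --- (1)

[Source Code Analysis] How Facebook Trains Very Large Models --- (2)

[Source Code Analysis] How Facebook Trains Very Large Models --- (3)

[Source Code Analysis] How Facebook Trains Very Large Models --- (4)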

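0x01 Background

Activation recomputation, also called activation checkpointing or gradient checkpointing (Chen et al., 2016, https://arxiv.org/abs/1604.06174), trades time for space: it spends extra computation to save memory. It reduces the memory needed to hold layer activations when training deep neural networks, at the cost of additional forward computation for every batch.

For example, the method splits an m-layer network evenly into d partitions, stores only the activations at partition boundaries, and exchanges those activations between workers. Because the backward pass still needs the intermediate activations at intra-partition layers to compute gradients, they are recomputed inside each partition during the backward pass.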
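
To make the trade-off concrete, here is a back-of-the-envelope sketch. It assumes that every layer's activation has the same size (counted as one unit) and ignores all other memory; the helper name `peak_activation_units` is made up for illustration. With d partitions, roughly d boundary activations stay resident, and recomputing one partition needs about m/d more.

```python
import math


def peak_activation_units(m: int, d: int) -> int:
    """Rough peak activation count for m layers split into d partitions.

    Assumption: every activation costs one unit. The d boundary activations
    are kept for the whole step, and during backward one partition of
    ceil(m / d) layers is recomputed and held at a time.
    """
    return d + math.ceil(m / d)


m = 64
no_checkpoint = m  # baseline: keep every layer's activation
best_d = min(range(1, m + 1), key=lambda d: peak_activation_units(m, d))

print(no_checkpoint)                             # 64
print(best_d, peak_activation_units(m, best_d))  # 8 16
```

Under these assumptions the minimum sits near d = sqrt(m), which is where the sublinear memory cost in the paper's title comes from.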

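The paper illustrates this with a schematic.

[Figure: activation recomputation schematic from the paper]

An earlier post, [Source Code Analysis] Deep Learning Pipeline Parallelism GPipe (3) ---- Recomputation, introduced recomputation. This post looks at how FairScale wraps and extends it.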

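0x02 Approach

2.1 Study Advice

Before getting to the approach, a word on how to analyze an open-source framework, or how to study source code in general. My suggestion is to go in this order: paper --> documentation --> user manual --> comments --> source code.

Why this order? Because it moves:

  • From abstract logic (the architecture) to concrete detail.
    • A paper is the author's thinking distilled into a logical, systematic form; documentation comes second. Rereading a classic paper also pays off along several dimensions.
    • The manual rounds out your picture of the framework from the angle of usage and caveats.
    • The source code shows you the details, in bulk.
  • From human thinking to machine thinking.
    • Comments are written for the reader; code is written for the machine.
    • Comments tell you why something is implemented this way (Why); code tells you how (How).

We should first look for a change in thinking and an updated, organized knowledge framework, and only then use the code to analyze and verify it (what is learned on paper always feels shallow, after all). Of course, often the source code is all we have. Then we can only work from its details to reconstruct the author's thinking, distill its essence, and try to resonate with the author across space and time. The more you resonate, the closer you get to the author.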

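2.2 The Approach in Detail

Next, the approach as described in the source documentation.

Activation checkpointing is a technique for reducing GPU memory usage during training. It works by:

  • Not storing intermediate activation tensors during the forward pass.
  • Recomputing the forward pass during the backward pass, starting from the saved original inputs.

The result is a modest increase in computation (about 33%) in exchange for not having to store large activation tensors, which allows a larger batch size and therefore higher net throughput.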
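
The 33% figure can be reproduced with a common rule of thumb. The rule is an assumption made here, not something the FairScale documentation states: a backward pass costs about twice as much as a forward pass.

```python
forward = 1.0             # cost of one forward pass (arbitrary unit)
backward = 2.0 * forward  # assumed: backward costs about 2x forward

baseline = forward + backward
with_recompute = forward + forward + backward  # one extra forward pass

overhead = with_recompute / baseline - 1.0
print(f"{overhead:.0%}")  # 33%
```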

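Activation checkpointing is implemented by subclassing torch.autograd.Function and overriding its forward and backward methods.

  • By using no_grad in the forward function, construction of the forward graph and materialization of intermediate activation tensors are put off for a long time (until the backward pass starts).
  • During the backward pass, the forward pass is run again first, followed by the backward pass.
    • The inputs to the forward pass are saved in the context object, so the backward pass can retrieve the original inputs from it.
    • The Random Number Generator (RNG) states of the forward and backward passes are also saved, because some layers (such as Dropout) depend on them.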
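
The steps above can be sketched as a small custom Function. This is a minimal illustration, not FairScale's or PyTorch's implementation: `TinyCheckpoint` and `grads` are names invented here, it handles only tensor inputs and a single tensor output, and it saves only the CPU RNG state.

```python
import torch
from torch import nn


class TinyCheckpoint(torch.autograd.Function):
    """Illustrative only: tensor inputs, one tensor output, CPU RNG."""

    @staticmethod
    def forward(ctx, run_function, *inputs):
        ctx.run_function = run_function
        ctx.fwd_rng_state = torch.get_rng_state()
        ctx.save_for_backward(*inputs)
        with torch.no_grad():  # no graph, no stored intermediate activations
            return run_function(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        # Detach the saved inputs so the recomputed graph starts from them.
        inputs = [t.detach().requires_grad_(t.requires_grad) for t in ctx.saved_tensors]
        bwd_rng_state = torch.get_rng_state()
        torch.set_rng_state(ctx.fwd_rng_state)  # replay the same dropout mask
        with torch.enable_grad():
            output = ctx.run_function(*inputs)  # second forward pass
        torch.set_rng_state(bwd_rng_state)
        torch.autograd.backward(output, grad_output)
        # One gradient per forward argument: None for run_function itself.
        return (None,) + tuple(t.grad for t in inputs)


def grads(use_checkpoint: bool):
    torch.manual_seed(0)
    block = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 2))
    x = torch.randn(3, 4, requires_grad=True)
    out = TinyCheckpoint.apply(block, x) if use_checkpoint else block(x)
    out.sum().backward()
    return x.grad, block[0].weight.grad


plain = grads(use_checkpoint=False)
ckpt = grads(use_checkpoint=True)
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))
```

The comparison at the end checks that the checkpointed run produces the same input and weight gradients as the plain run, including through the Dropout layer, which only works because the forward RNG state is replayed.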

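In PyTorch this functionality is implemented by torch.utils.checkpoint.checkpoint, which can be used to wrap modules during the forward pass. FairScale's wrapper offers more than the PyTorch API: a user can wrap an nn.Module with fairscale.nn.checkpoint.checkpoint_wrapper, which handles kwargs in the forward pass, offloads intermediate activations to the CPU, and handles non-tensor outputs returned from the forward function.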

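2.3 Best Practices

Best practices for fairscale.nn.checkpoint.checkpoint_wrapper:

  • Memory savings depend on the model and on how checkpoint wrapping segments it. In other words, the benefit depends on the memory footprint of each layer's activations.
  • When using BatchNormalization, you may need to freeze the computation of statistics, because the forward pass runs twice.
  • Make sure the input tensors' requires_grad attribute is set to True. This ensures that gradient tracking propagates from the input to the output and that the backward function is triggered.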

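0x03 Implementation

3.1 Wrapper

checkpoint_wrapper is the wrapper itself; internally it just calls other functions. Its docstring is instructive, though, so it is summarized here.

checkpoint_wrapper is a wrapper that performs activation checkpointing. It is friendlier than the PyTorch version and:

  • Wraps an nn.Module so that all subsequent calls use checkpointing.

  • Handles keyword arguments in the forward pass.

  • Handles non-tensor outputs from the forward pass.

  • Supports offloading activations to the CPU.

To understand the benefits of checkpointing and of offload_to_cpu, divide activations into two types relative to the checkpointed module:

  • Inner activations: their memory is saved by activation checkpointing.
  • Outer activations: their memory is saved by offload_to_cpu.

In terms of GPU memory savings:

  • When inner activations are large and outer activations are small, checkpointing helps a lot and offload_to_cpu may help only a little.

  • When inner activations are small and outer activations are large, checkpointing helps little and offload_to_cpu helps a lot.

  • When both are large, both help and the benefit is additive.

In addition, the first and last layers are unlikely to benefit from the offload_to_cpu flag, because:

  • There are typically other references to the first layer's input, so its GPU memory is not freed;
  • The last layer's input is used immediately by the backward pass, so no memory is saved.
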
def checkpoint_wrapper(
    module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
) -> nn.Module:
    """
    A friendlier wrapper for performing activation checkpointing.

    Compared to the PyTorch version, this version:

        - wraps an nn.Module, so that all subsequent calls will use checkpointing
        - handles keyword arguments in the forward
        - handles non-Tensor outputs from the forward
        - supports offloading activations to CPU

    Usage::

        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))

    To understand the benefits of checkpointing and the `offload_to_cpu` flag,
    let's divide activations into 2 types: inner activations and outer
    activations w.r.t. the checkpointed modules. The inner ones are saved
    by activation checkpointing, the outer ones are saved by offload_to_cpu.

    In terms of GPU memory savings:

        - When inner ones are large in size and outer ones are small,
          checkpointing helps a lot, offload_to_cpu may help a little.
        - When inner ones are small and outer ones are large,
          checkpointing helps little, offload_to_cpu helps a lot.
        - When both inner and outer are large, both help and the
          benefit is additive.

    ..Note::

        The first and last layers are not likely to benefit from the `offload_to_cpu` flag
        because (1) there are typically other references to the first layer's input, so
        the GPU memory won't be freed; (2) the input to the last layer is immediately
        used by the backward pass and won't result in memory savings.

    Args:
        module (nn.Module):
            The module to be wrapped
        offload_to_cpu (bool):
            Whether to offload activations to CPU.
        maintain_forward_counter (bool):
            If True, maintain a forward counter per inner module. The counter will first
            increases in forward calls of outer forward pass and then decreases in the
            forward calls of outer backward pass. It is used by FullyShardedDataParallel.

    Returns:
        (nn.Module):
            Wrapped module
    """
    # Patch the batchnorm layers in case there are any in this module.
    patch_batchnorm(module)

    if maintain_forward_counter:
        init_counter(module)

    # The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
    # When such cycle exists, gc won't collect the module when the module is freed.
    # That causes GPU memory to be leaked. See the unit test for how we catch that.
    #
    # We prefer this over a class wrapper since the class wrapper would have to
    # proxy a lot of fields and methods.
    module.forward = functools.partial(  # type: ignore
        _checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
    )
    return module  # the nn.Module is now wrapped: all subsequent calls use checkpointing
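
The weakref comment in the code above is worth a small experiment. The sketch below uses plain Python objects instead of nn.Module, and every name in it (`Layer`, `_wrapped_forward`, `wrap_with_weakref`, `wrap_with_strong_ref`, `freed_without_gc`) is invented for illustration. It relies on CPython's reference counting: an object without reference cycles is freed as soon as its last reference goes away, while an object in a cycle has to wait for the garbage collector.

```python
import functools
import gc
import weakref


class Layer:
    def forward(self, x):
        return x + 1


def _wrapped_forward(original_forward, get_self, x):
    module = get_self()
    return original_forward(module, x) * 2


def wrap_with_weakref(module):
    # Same shape as checkpoint_wrapper: the partial holds only a weak reference.
    module.forward = functools.partial(_wrapped_forward, type(module).forward, weakref.ref(module))
    return module


def wrap_with_strong_ref(module):
    # The lambda keeps the module alive: module -> forward -> lambda -> module.
    module.forward = functools.partial(_wrapped_forward, type(module).forward, lambda: module)
    return module


def freed_without_gc(wrap) -> bool:
    """Return True if a wrapped Layer is freed by reference counting alone."""
    gc.disable()
    try:
        obj = wrap(Layer())
        assert obj.forward(1) == 4  # (1 + 1) * 2, so the patched forward is in use
        probe = weakref.ref(obj)
        del obj
        return probe() is None
    finally:
        gc.enable()


print(freed_without_gc(wrap_with_weakref))     # True
print(freed_without_gc(wrap_with_strong_ref))  # False
```

For a module holding GPU tensors, the second case means the memory stays allocated until the cycle collector happens to run, which is the leak the FairScale comment warns about.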

3.1.1 How to Use

Here is some usage code taken from the source, for reference.

self.layers = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 8)),
    nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 4)),
    nn.Sequential(nn.Linear(4, 6), nn.Linear(6, 8), nn.Linear(8, 2)),
)

if enable_checkpoint:
    for i, layer in enumerate(self.layers):
        # Only middle layer needs to have offloading
        self.layers[i] = checkpoint_wrapper(layer, cpu_offload if i == 1 else False)

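3.2 _checkpointed_forward

Of the advantages over the PyTorch version listed earlier, this function covers the second and third points below (keyword arguments and non-tensor outputs):

  • Wraps an nn.Module so that all subsequent calls use checkpointing.

  • Handles keyword arguments in the forward pass.

  • Handles non-tensor outputs from the forward pass.

  • Supports offloading activations to the CPU.

The logic is:

  • If gradients are disabled, call the original .forward() directly. This also keeps the internal fwd counter from being incremented in the forward pass, which would otherwise be a problem during eval, since no corresponding backward pass would decrement it.
  • Autograd Functions in PyTorch work best with positional arguments, because the backward pass must return a gradient (or None) for every input argument. Flattening the keyword arguments makes this easier.
  • Call CheckpointFunction to perform the activation checkpointing. Note that a dummy tensor with requires_grad=True is passed in to ensure the backward pass is called. This is needed when the inputs to original_forward are non-tensors (for example, a tuple).
    • With tuple-type inputs, even setting the requires_grad flag on the tensors does not trigger the backward pass.
    • The dummy tensor also spares users from having to set requires_grad on their input tensors.
  • If the forward output is not a single tensor, recombine the tensor and non-tensor outputs into a tuple.

The code is: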

def _checkpointed_forward(
    original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
    module = weak_self()

    # If gradients are disabled, just use original `.forward()` method directly.
    # Doing so also ensures the internal fwd counter is not incremented in the forward pass,
    # which would be an issue during eval since there wouldn't be a corresponding backward pass
    # to decrement the fwd counter.
    # See https://github.com/facebookresearch/fairscale/pull/709.
    if not torch.is_grad_enabled():
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module,) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)  # flatten the inputs
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # Dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's input are non-tensor (i.e. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors's requires_grad flag. In the case
    # of tuple type inputs, setting the flag won't even trigger the backward pass.
    output = CheckpointFunction.apply(
        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    
    # Handle non-tensor outputs
    if not isinstance(output, torch.Tensor):
        # parent_ctx_dict["packed_non_tensor_outputs"] is filled in by CheckpointFunction
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            # Recombine tensors and non-tensors into a single tuple
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output

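3.2.1 Handling Inputs

Keyword arguments in the forward pass are handled with pack_kwargs. It flattens the arguments into two tuples: the keyword names, and all the values (positional arguments followed by keyword values). See the usage example in the docstring.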

def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """
    Turn argument list into separate key list and value list (unpack_kwargs does the opposite)
    Usage::

        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        assert kwarg_keys == ("a", "b")
        assert flat_args == (1, 2, 3, 4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)
    return tuple(kwarg_keys), tuple(flat_args)
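
The docstring mentions unpack_kwargs, which is not listed in this post. A sketch consistent with the documented round trip might look like the following. It is reconstructed from the docstring's example and is not necessarily FairScale's exact code; pack_kwargs is repeated in compact form so the snippet runs on its own.

```python
from typing import Any, Dict, Tuple


def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """Flatten positional and keyword arguments (same behavior as above)."""
    return tuple(kwargs.keys()), tuple(args) + tuple(kwargs.values())


def unpack_kwargs(
    kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Inverse of pack_kwargs: the last len(kwarg_keys) values are keyword values."""
    if not kwarg_keys:
        return tuple(flat_args), {}
    split = len(flat_args) - len(kwarg_keys)
    return tuple(flat_args[:split]), dict(zip(kwarg_keys, flat_args[split:]))


kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
print(kwarg_keys, flat_args)  # ('a', 'b') (1, 2, 3, 4)
print(args, kwargs)           # (1, 2) {'a': 3, 'b': 4}
```

Only the flat tuple goes through CheckpointFunction.apply as positional arguments; the key names travel separately, which is why the backward pass can return exactly one gradient (or None) per flattened value.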

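3.2.2 Non-Tensor Outputs

3.2.2.1 Packing Non-Tensors

split_non_tensors splits a tuple into a tuple of tensors plus the information needed to reconstruct the original tuple later.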

def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any, ...]]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]:
    """
    Split a tuple into a list of tensors and the rest with information
    for later reconstruction.

    Usage::

        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        assert tensors == (x, y)
        assert packed_non_tensors == {
            "is_tensor": [True, True, False, False],
            "objects": [None, 3],
        }
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None
    tensors: List[torch.Tensor] = []
    packed_non_tensors: Dict[str, List[Any]] = {"is_tensor": [], "objects": []}
    for o in mixed:
        if isinstance(o, torch.Tensor):
            packed_non_tensors["is_tensor"].append(True)
            tensors.append(o)
        else:
            packed_non_tensors["is_tensor"].append(False)
            packed_non_tensors["objects"].append(o)
    return tuple(tensors), packed_non_tensors

3.2.2.2 Unpacking Non-Tensors

unpack_non_tensors recombines the tensors and the packed non-tensor information into the original tuple.

def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]]
) -> Tuple[Any, ...]:
    """See split_non_tensors."""
    if packed_non_tensors is None:
        return tensors
    assert isinstance(packed_non_tensors, dict), type(packed_non_tensors)
    mixed: List[Any] = []
    is_tensor_list = packed_non_tensors["is_tensor"]
    objects = packed_non_tensors["objects"]

    obj_i = tnsr_i = 0
    for is_tensor in is_tensor_list:
        if is_tensor:
            mixed.append(tensors[tnsr_i])
            tnsr_i += 1
        else:
            mixed.append(objects[obj_i])
            obj_i += 1
    return tuple(mixed)

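3.3 CheckpointFunction

Next we analyze CheckpointFunction, the function that actually performs activation checkpointing. For PyTorch's version of CheckpointFunction, see [Source Code Analysis] Deep Learning Pipeline Parallelism GPipe (3) ---- Recomputation.

This corresponds to the advantage listed earlier: supports offloading activations to the CPU.

3.3.1 Forward Pass

The forward logic is:

  • Split the argument list into tensor inputs and non-tensor inputs.
    • If "offload" is set, record the devices and gradient requirements in the context and move the input tensors to the CPU.
  • Save the inputs for the backward pass.
  • Under torch.no_grad() and the enable_checkpointing() context, unpack the arguments and run the forward computation.
  • If the output is not a single tensor, split it, because Autograd Functions don't like non-tensor outputs. The non-tensor outputs are returned by reference through parent_ctx_dict and the tensor outputs are returned directly.
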
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but support non-Tensor outputs.

    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling :func:`unpack_non_tensors`.
    """

    @staticmethod
    def forward(  # type: ignore
        ctx: Any,
        dummy_tensor_requires_grad: torch.Tensor,
        run_function: Any,
        parent_ctx_dict: Dict[str, Any],
        kwarg_keys: Tuple[str, ...],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        torch_checkpoint.check_backward_validity(args)

        ctx.run_function = run_function  # save the forward function in the context
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = get_rng_state()  # save the forward-pass RNG state in the context
        ctx.had_autocast_in_fwd = is_autocast_enabled()

        # Split the argument list into tensor inputs and non-tensor inputs
        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        if parent_ctx_dict["offload"]:
            # Record devices and gradient requirements in the context, then move the input tensors to the CPU
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)  # devices used in the forward pass
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        # Save the inputs for the backward pass
        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad(), enable_checkpointing():  # run the forward pass without building a graph
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)  # rebuild args and kwargs
            outputs = run_function(*unpacked_args, **unpacked_kwargs)  # forward computation
            the_module = unpacked_args[0]
            inc_counter(the_module)

        if not isinstance(outputs, torch.Tensor):  # the output is not a single tensor
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
        return outputs


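3.3.2 Backward Pass

The backward logic is:

  • Retrieve the tensor inputs saved in the context.
  • If offloading was enabled in the forward pass (ctx.fwd_device is set):
    • Move the offloaded tensors back to the GPU.
    • Restore each tensor's requires_grad flag.
  • Recombine the non-tensor inputs with the tensor inputs.
  • Save the current RNG state.
  • Load the RNG state recorded during the forward pass from the context.
  • Rerun the forward pass.
  • Process the forward outputs.
  • Restore the backward-pass RNG state.
  • Find the output tensors that require gradients and pair each with its incoming gradient from the backward arguments.
  • Run the backward pass.
  • Return the gradients.
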
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but support non-Tensor outputs.

    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling :func:`unpack_non_tensors`.
    """

    @staticmethod
    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")

        # Retrieve the tensor inputs saved in the context
        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:  # offloading was enabled in the forward pass
            # Move the offloaded tensors back to their original devices
            tensor_inputs = tuple(t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs))
            for i, need_grad in enumerate(ctx.grad_requirements):  # restore the gradient requirements
                tensor_inputs[i].requires_grad = need_grad
        # Recombine the non-tensor inputs with the tensor inputs
        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

        # Store the current states.
        bwd_rng_state = get_rng_state()  # save the current (backward-pass) RNG state

        # Set the states to what it used to be before the forward pass.
        set_rng_state(ctx.fwd_rng_state)  # load the forward-pass RNG state from the context

        with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
            unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)  # rerun the forward pass
            tensor_outputs, _ = split_non_tensors(outputs)  # keep only the tensor outputs
            the_module = unpacked_args[0]
            dec_counter(the_module)

        # Set the states back to what it was at the start of this function.
        set_rng_state(bwd_rng_state)  # restore the backward-pass RNG state

        # Run backward() with only Tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        # Find the output tensors that require gradients
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i])  # the matching incoming gradient
        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, " "this checkpoint() is not necessary")

        # Run the backward pass
        torch.autograd.backward(outputs_with_grad, args_with_grad)

        # Read the gradients off the inputs
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
        return (None, None, None, None) + grads  # return the gradients
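
The RNG bookkeeping in backward is easy to get wrong, so here is the same save, replay, restore pattern in isolation. To stay dependency-free it uses Python's random module as a stand-in for get_rng_state and set_rng_state, and `noisy_forward` is an invented placeholder for a layer with randomness.

```python
import random


def noisy_forward(x: float) -> float:
    """Stand-in for a layer with randomness, such as dropout."""
    return x * random.random()


random.seed(0)

# "Forward pass": remember the RNG state before the random draw.
fwd_rng_state = random.getstate()
first_output = noisy_forward(2.0)

# Unrelated random draws happen between forward and backward.
random.random()

# "Backward pass": save the current state, replay the forward state,
# recompute, then restore so that later draws are unaffected.
bwd_rng_state = random.getstate()
random.setstate(fwd_rng_state)
recomputed_output = noisy_forward(2.0)
random.setstate(bwd_rng_state)
after_restore = random.random()

# Reference: what the next draw would have been without the recomputation.
random.setstate(bwd_rng_state)
expected_next = random.random()

print(recomputed_output == first_output)  # True
print(after_restore == expected_next)     # True
```

The first check is why the forward state is replayed: the recomputed activation has to match the one that produced the loss. The second is why the backward state is restored: the recomputation should not disturb any randomness that comes after it.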

The overall logic is summarized in the figure below:

[Figure: CheckpointFunction forward and backward flow]

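0x04 OffloadFunction

As seen in an earlier post, in OffloadModel's forward method, if _checkpoint_activation is set, OffloadFunction is called to offload the activation checkpoints to the CPU and its result is returned directly. Let's look at how OffloadFunction implements the activation-related operations.

This Function enables checkpointing of intermediate activations at shard boundaries by overriding the forward and backward passes of the nn.Module. Only the activations at partition boundaries are stored, and they are exchanged between workers.

The main differences between this section and the previous one are:

  • CheckpointFunction only moves the input tensors between GPU and CPU, and discards the inner activations.
  • OffloadFunction moves both the activations (which are kept, not discarded) and the model between GPU and CPU. Because each partition is one or more layers, only the activations at partition boundaries are exchanged between workers.

4.1 Forward Pass

In the FW pass, it iterates over the partitions. For each one, it drops the parameters of the previous shard, loads the parameters of the next shard, and runs that partition's forward computation. No computation graph is built during the FW pass, which is what makes it possible to offload the intermediate activations at shard boundaries.

A few notes:

  • model_instance.model_slices holds the model's shards; each shard contains one or more layers.
  • Except for the activations of the last partition, all inter-partition activations are kept on the CPU. The code assumes the target tensor is also on the GPU that runs the computation, so the output activation of the last layer should stay on that GPU too. If the output activation were moved to the CPU, the backward pass might not be able to find its gradient function.

The code is: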

class OffloadFunction(torch.autograd.Function):
    """
     This Function enables checkpointing of intermediate activations at
     shard boundaries by overriding the forward and backward pass of the nn.Module.

     - In the FW pass, it drops parameters in the previous shard and
     loads parameters for the next shard. No graph is constructed in the FW pass.
     This enables us to offload intermediate activations present at the shard
     boundaries.

     - In the BW pass, it does the reverse. We run the forward pass using the
     saved intermediate activations and calculate gradients as needed.
     The trade-off is latency vs memory when using activation checkpointing.

     - Follows heavily from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint.

     NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
     """

    @staticmethod
    @_conditional_amp_fwd_decorator  # type: ignore
    def forward(ctx: Any, inputs: Any, dummy_input: Any, model_instance: Any) -> Any:
        inputs = inputs if isinstance(inputs, tuple) else (inputs,)

        # Save what the backward pass needs in the context.
        ctx.inputs = inputs
        ctx.model_instance = model_instance
        # TODO(anj-s): We might need to store this for each boundary activation.
        # Currently we assume all boundary activation inputs require
        ctx.grad_requirements = tuple(x.requires_grad for x in inputs)
        ctx.fwd_rng_state = torch.get_rng_state()

        # List of input activations starting with the given input.
        model_instance._activations = [inputs]
        # Enumerate through layer shards and apply activations from the previous shard.
        for index, layer_shard in enumerate(model_instance.model_slices):  # iterate over the model shards
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
                # Bring in the current activations onto the device.
                model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
                # Bring in the current layer shard onto the device.
                layer_shard.forward_load()

            # Apply the FP and store the activations on the CPU.
            inputs = model_instance._activations[index]
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:no_grad_forward_pass"):
                with torch.no_grad():  # gradients are not tracked here; only activations are computed
                    output_list: List[Any] = []
                    for given_input in inputs:
                        given_input_list = torch.chunk(given_input, model_instance._num_microbatches)
                        given_output_list = []
                        for inputs in given_input_list:
                            output = layer_shard(inputs)  # forward computation
                            given_output_list.append(output)
                        given_output = torch.cat(given_output_list).squeeze(-1)
                        output_list.append(given_output)
                    output = tuple(output_list)  # collect the outputs

            output = output if isinstance(output, tuple) else (output,)
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
                # Move the activation used back for the current shard back to the CPU.
                model_instance._activations[index] = tuple([a.cpu() for a in list(model_instance._activations[index])])
                # The newly computed activations remain on the GPU ready for the next shard computation.
                model_instance._activations.append(output)
                # Move the layer shard back to the CPU.
                layer_shard.forward_drop()

        # The last instance will lose the gradient function if we move it to the CPU.
        # This is because all grad function are present on the device that ran the FW pass.
        # The last activation remains on the GPU and is the return value of this function.
        # Note that this assumes that the target is also on the GPU which is required for calculating
        # the loss.
        
        result = model_instance._activations[-1]  # activations of the last shard
        result = [r.cuda() for r in result]  # keep the last activations on the device; the others are already on the CPU
        for r in result:
            r.requires_grad = True
        return result[0] if len(result) == 1 else result
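
The load, compute, drop loop is easier to see with the tensors stripped away. The following is a pure-Python simulation, not FairScale code: `Shard` and `offload_forward` are hypothetical stand-ins, a float plays the role of an activation, and a boolean flag plays the role of device placement.

```python
from typing import Callable, List, Tuple


class Shard:
    """Hypothetical stand-in for a model shard that can be loaded and dropped."""

    def __init__(self, fn: Callable[[float], float]) -> None:
        self.fn = fn
        self.on_gpu = False

    def forward_load(self) -> None:
        self.on_gpu = True

    def forward_drop(self) -> None:
        self.on_gpu = False

    def __call__(self, x: float) -> float:
        assert self.on_gpu, "shard must be loaded before it runs"
        return self.fn(x)


def offload_forward(shards: List[Shard], x: float) -> Tuple[List[float], int]:
    activations = [x]  # boundary activations, starting with the input
    max_resident = 0   # most shards ever "on the GPU" at the same time
    for index, shard in enumerate(shards):
        shard.forward_load()  # bring the current shard in
        max_resident = max(max_resident, sum(s.on_gpu for s in shards))
        activations.append(shard(activations[index]))
        shard.forward_drop()  # evict it before the next shard loads
    return activations, max_resident


shards = [Shard(lambda v: v + 1), Shard(lambda v: v * 2), Shard(lambda v: v - 3)]
activations, max_resident = offload_forward(shards, 5.0)
print(activations)   # [5.0, 6.0, 12.0, 9.0]
print(max_resident)  # 1
```

The list ends up with one more entry than there are shards: entry i is the input of shard i, and the last entry is the model output. The backward pass relies on exactly this layout.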

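4.2 Backward Pass

In the BW pass, it does the reverse: the forward pass is rerun from the saved intermediate activations and gradients are computed as needed. Activation checkpointing here is a trade-off between latency and memory. The code relies on a few built-in PyTorch functions, so we look at their usage first.

4.2.1 no_grad

torch.no_grad() is a context manager; operations inside it are not tracked for gradients. An example: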

import torch

x = torch.tensor([2.2], requires_grad=True)
y = x * 3
print(y)
y.add_(2)
print(y)

with torch.no_grad():
    y.div_(3)
    print(y)

Output:

tensor([6.6000], grad_fn=<MulBackward0>) # the multiplication is recorded
tensor([8.6000], grad_fn=<AddBackward0>) # the add is tracked
tensor([2.8667], grad_fn=<AddBackward0>) # under no_grad the div is not tracked, so grad_fn is unchanged

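4.2.2 chunk

torch.chunk(input, chunks, dim) attempts to split a tensor into the given number of chunks along dimension dim and returns them as a tuple. It may return fewer chunks when the size along dim cannot be divided that way. For example: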

x = torch.Tensor([[1,2,3]])
y = torch.Tensor([[4,5,6], [7,8,9], [10,11,12]])
z = torch.cat((x,y), dim=0)
print(z)
print(z.size())
c = torch.chunk(z,4,dim=0)
print(c)
print(len(c))

Output:

# output after cat
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])
torch.Size([4, 3])

# output after chunk
(tensor([[1., 2., 3.]]), tensor([[4., 5., 6.]]), tensor([[7., 8., 9.]]), tensor([[10., 11., 12.]]))
4
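
OffloadFunction chunks each batch into _num_microbatches pieces, so it helps to know what happens when the size is not evenly divisible. The pure-Python sketch below mimics the ceiling-division rule that torch.chunk documents for the chosen dimension; `chunk_like` is an illustrative helper, not a PyTorch API, and it works on plain sequences.

```python
import math
from typing import List, Sequence


def chunk_like(seq: Sequence[int], chunks: int) -> List[List[int]]:
    """Split seq into pieces of size ceil(len(seq) / chunks)."""
    if not seq:
        return []
    size = math.ceil(len(seq) / chunks)
    return [list(seq[i:i + size]) for i in range(0, len(seq), size)]


print(chunk_like(range(4), 4))  # [[0], [1], [2], [3]]
print(chunk_like(range(5), 4))  # [[0, 1], [2, 3], [4]]
print(chunk_like(range(6), 4))  # [[0, 1], [2, 3], [4, 5]]
```

Asking for 4 chunks of 5 or 6 elements yields only 3, so code that assumes exactly _num_microbatches pieces should choose a batch size that divides evenly.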

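4.2.3 Backward Propagation

The backward pass of OffloadFunction follows. Note the reverse operations.

  • At the start, the model shards and the activations are iterated in reverse. The original shard and activation lists are not reversed in place: reversed() returns a reversed iterator and leaves the underlying data untouched. Gradients are computed from back to front, so the last element comes first, and so on, which makes it convenient to call backward_load and backward_drop.
  • At the end, because the reversal did not modify model_instance._activations, the gradients can be read directly off the inputs.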
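
A few lines of plain Python show both points: how the reversed zip pairs each shard with its own input activation, and that the original lists stay untouched. The strings are placeholders for shards and tensors.

```python
model_slices = ["shard0", "shard1", "shard2"]
# activations[i] is the input of shard i; the final entry is the model output.
activations = ["act0", "act1", "act2", "output"]

backward_order = list(zip(reversed(model_slices), reversed(activations[:-1])))
print(backward_order)
# [('shard2', 'act2'), ('shard1', 'act1'), ('shard0', 'act0')]

# The originals are untouched, so activations[0] is still the first input.
print(model_slices[0], activations[0])  # shard0 act0
```

Slicing off the last entry with [:-1] is what keeps the two sequences the same length and correctly aligned.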

The code is:

class OffloadFunction(torch.autograd.Function):

    # Ignore the following function for code coverage since the backward pass
    # is triggered by C++ code and cannot be calculated when overriding
    # autograd.Function
    @staticmethod
    @_conditional_amp_bwd_decorator
    def backward(ctx, *grad_outputs):  # type: ignore # pragma: no cover
        inputs = ctx.inputs
        model_instance = ctx.model_instance

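        # Restore the requires_grad flags recorded in the context
        for i, need_grad in enumerate(ctx.grad_requirements):
            inputs[i].requires_grad = need_grad

        # Start from the incoming gradients of the backward pass
        all_grads = [grad_outputs]

        # Iterate over the model shards and activations in reverse. reversed() does not modify the
        # original lists. Gradients flow from back to front, so the last shard is processed first,
        # which makes it convenient to call backward_load and backward_drop.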
        for model_shard, activation in zip(
            reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
        ):
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_load"):
                # Move the activation to the GPU.
                activation = tuple([a.cuda() for a in list(activation)])

                # Move the model shard to the GPU.
                model_shard.backward_load()

            # Store the BW pass state.
            bwd_rng_state = torch.get_rng_state()

            # TODO(anj-s): Why detach inputs?
            activation = torch.utils.checkpoint.detach_variable(activation)
            # Get the last gradient calculation.
            final_grads = all_grads[-1]  # the most recently computed gradients

            if isinstance(activation, torch.Tensor):
                activation = (activation,)
            if isinstance(final_grads, torch.Tensor):
                final_grads = (final_grads,)
            # Iterate through all the inputs/outputs of a shard (there could be multiple).
            chunked_grad_list: List[Any] = []
            # Chunk the activation and grad based on the number of microbatches that are set.
            for chunked_activation, chunked_grad in zip(
                torch.chunk(*activation, model_instance._num_microbatches),  # type: ignore
                torch.chunk(*final_grads, model_instance._num_microbatches),  # type: ignore
            ):
                # Set the states to what it used to be before the forward pass.
                torch.set_rng_state(ctx.fwd_rng_state)  # temporarily use the forward-pass RNG state

                # Wrap single tensors in a tuple
                if isinstance(chunked_activation, torch.Tensor):
                    chunked_activation = (chunked_activation,)  # type: ignore
                if isinstance(chunked_grad, torch.Tensor):
                    chunked_grad = (chunked_grad,)  # type: ignore

                # Since we need a grad value of a non leaf element we need to set these properties.
                for a in chunked_activation:
                    if a.dtype == torch.long:
                        continue
                    a.requires_grad = True  # the gradient of this tensor is needed, so mark it as requiring grad
                    a.retain_grad()

                with torch.autograd.profiler.record_function(
                    "fairscale.experimental.nn.offload:forward_pass_with_enable_grad"
                ):
                    with torch.enable_grad():
                        # calculate the output of the last shard wrt to the stored activation at the slice boundary.
                        outputs = model_shard(*chunked_activation)  # rerun the forward pass

                # Set the states back to what it was at the start of this function.
                torch.set_rng_state(bwd_rng_state)  # restore the RNG state

                with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_pass"):
                    torch.autograd.backward(outputs, chunked_grad)  # backward pass
                    
                intermediate_grads = []
                for a in chunked_activation:
                    if a.grad is not None:
                        intermediate_grads.append(a.grad)
                if None not in intermediate_grads:
                    chunked_grad_list += intermediate_grads
             
            if chunked_grad_list:
                # Append the list of grads to the all_grads list and this should be on the GPU.
                all_grads.append(torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore

            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_drop"):
                # Move the shard back to the CPU. This should move all the grad tensors to CPU as well.
                # We don't need to move activations since we are using a copy of the tensors on the GPU.
                model_shard.backward_drop()  # move the shard back to the CPU

        # The earlier reversal did not modify model_instance._activations
        detached_inputs = model_instance._activations[0]
        # Read the gradients off the inputs
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
        return (None, None) + grads  # return the gradients

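The expanded logic is shown in the figure below:

[Figure: OffloadFunction forward and backward flow]

This concludes the FSDP analysis. The next series will introduce model parallelism through NVIDIA Megatron. Stay tuned.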

0xFF References

https://arxiv.org/pdf/2101.06840.pdf

https://www.deepspeed.ai/tutorials/zero-offload/

DeepSpeed: Extreme-scale model training for everyone

https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/

https://www.marktechpost.com/2021/02/01/microsoft-and-the-university-of-california-merced-introduces-zero-offload-a-novel-heterogeneous-deeplearning-training-technology-to-train-multi-billion-parameter-models-on-a-single-gpu/

[1] Li et al. “PyTorch Distributed: Experiences on Accelerating Data Parallel Training” VLDB 2020.

[2] Cui et al. “GeePS: Scalable deep learning on distributed GPUs with a GPU-specialized parameter server” EuroSys 2016

[3] Shoeybi et al. “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.” arXiv preprint arXiv:1909.08053 (2019).

[4] Narayanan et al. “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.” arXiv preprint arXiv:2104.04473 (2021).

[5] Huang et al. “GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism.” arXiv preprint arXiv:1811.06965 (2018).

[6] Narayanan et al. “PipeDream: Generalized Pipeline Parallelism for DNN Training.” SOSP 2019.

[7] Narayanan et al. “Memory-Efficient Pipeline-Parallel DNN Training.” ICML 2021.

[8] Shazeer et al. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer.” arXiv preprint arXiv:1701.06538 (2017).

[9] Lepikhin et al. “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding.” arXiv preprint arXiv:2006.16668 (2020).

[10] Fedus et al. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.” arXiv preprint arXiv:2101.03961 (2021).

[11] Narang & Micikevicius, et al. “Mixed precision training.” ICLR 2018.

[12] Chen et al. 2016 “Training Deep Nets with Sublinear Memory Cost.” arXiv preprint arXiv:1604.06174 (2016).

[13] Jain et al. “Gist: Efficient data encoding for deep neural network training.” ISCA 2018.

[14] Shazeer & Stern. “Adafactor: Adaptive learning rates with sublinear memory cost.” arXiv preprint arXiv:1804.04235 (2018).

[15] Anil et al. “Memory-Efficient Adaptive Optimization.” arXiv preprint arXiv:1901.11150 (2019).

[16] Rajbhandari et al. “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.” arXiv preprint arXiv:1910.02054 (2019).
