[原始碼分析] Facebook如何訓練超大模型---(4)

羅西的思考發表於2022-01-24

[原始碼分析] Facebook如何訓練超大模型 --- (4)

0x00 摘要

我們在前文介紹過,微軟 ZeRO 可以對一個萬億引數模型可以使用 8 路模型並行、64 路管道並行和 8 路資料並行在 4,096 個 NVIDIA A100 GPU 上進行擴充套件。而FSDP(Fully Sharded Data Parallel)是Facebook 深度借鑑微軟ZeRO之後提出的PyTorch DDP升級版本,可以認為是對標微軟 ZeRO,其本質是 parameter sharding。Parameter sharding 就是把模型引數等切分到各個GPU之上。我們會以 Google,微軟和 Facebook 的論文,部落格以及程式碼來進行學習分析。

之前文章之中我們談到了FSDP支援混合精度訓練,所以我們再來看看相關知識。

本系列其他文章如下:

[原始碼解析] PyTorch 分散式之 ZeroRedundancyOptimizer

[論文翻譯] 分散式訓練 Parameter sharding 之 ZeRO

[論文翻譯] 分散式訓練 Parameter Sharding 之 Google Weight Sharding

[原始碼分析] Facebook如何訓練超大模型---(1)

[原始碼分析] Facebook如何訓練超大模型 --- (2)

[原始碼分析] Facebook如何訓練超大模型 --- (3)

0x01 背景知識

1.1 單精度、雙精度和半精度浮點格式的區別

我們從NVIDIA官博 What’s the Difference Between Single-, Double-, Multi- and Mixed-Precision Computing?摘錄如下:

IEEE 浮點算術標準是在計算機上用二進位制表示數字的通用約定。在雙精度格式中,每個數字佔用 64 位。單精度格式使用 32 位,而半精度只有 16 位。

在傳統的科學記數法中,pi 寫為 3.14 x \(10^0\)。但是計算機將這些資訊以二進位制形式儲存為浮點數,即表示數字及其相應指數的一系列 1 和 0,在本例中為 1.1001001 x\(2^1\)

在單精度 32 位格式中,一位用於判斷數字是正數還是負數。為指數保留了八位,指數(因為它是二進位制的)是 2 的某個冪。剩餘的 23 位用於表示組成數字的數字,稱為有效數。

相反,雙精度為指數保留 11 位,為有效數保留 52 位,大大擴充套件了它可以表示的數字的範圍和大小。半精度佔據了更小的部分,只有 5 個位用於指數,10 個位用於有效數。

以下是 pi 在每個精度級別的樣子

1.2 多精度和混合精度計算的區別

多精度計算意味著使用能夠以不同精度進行計算的處理器——在需要時使用雙精度,並依賴於應用程式的其他部分的半精度或單精度演算法。

混合精度,也稱為超精度,計算改為在單個操作中使用不同的精度級別,以在不犧牲精度的情況下實現計算效率。在混合精度中,計算從快速矩陣數學的半精度值開始。但是隨著數字的計算,機器以更高的精度儲存結果。例如,如果將兩個 16 位矩陣相乘,則答案大小為 32 位。

使用這種方法,當應用程式完成計算時,累積的答案在準確度上可與在雙精度算術中執行整個事情相媲美。這種技術可以將傳統雙精度應用程式的速度提高多達 25 倍,同時減少執行它們所需的記憶體、執行時間和功耗。它可用於 AI 和模擬 HPC 工作負載。

1.3 混合精度

採用FP16的優勢如下:

  • 記憶體佔用更少。如果採用FP16,則模型佔用是FP32的一半,這樣可以訓練更大的模型,使用更大的batch size,通訊量更少。
  • 計算更快。FP16的加速優化可以加快訓練和推理的計算。
  • 另外,隨著NVIDIA Tensor Core 的普及,FP6計算也越來越快。

FP16的問題主要是其表示範圍比FP32狹窄,所以會帶來兩個問題:溢位錯誤 和 舍入誤差。因此,百度和NVIDIA聯手在論文之中提出了一些技術。

  • 保留一份FP32格式的權重主備份。
  • 使用loss scale來避免梯度過小。
  • 使用FP16計算但是用FP32進行累加。

比如,對於主備份,論文之中圖例如下:

1.4 訓練過程

上面介紹的三種技術對於訓練過程是一個良好的補充,我們從NVIDIA官方文件 https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html 摘錄訓練過程具體如下。

  1. 維護一份FP32的引數主副本。
  2. 對於每次迭代:
    1. 製作權重的FP16副本。
    2. 使用FP16權重和啟用進行向前傳播。
    3. 將得到的損失乘以比例因子S。
    4. 使用FP16權重,啟用和它們的梯度進行後向傳播。
    5. 將權重梯度乘以1/S。
    6. 完成權重更新(包括gradient clipping等)。

一個更穩健的方法是動態地選擇損失比例因子。其基本思想是以一個大的比例因子開始,然後在每次訓練迭代中重新考慮它。如果在選定的迭代次數N中沒有發生溢位,則增加比例因子。如果發生溢位,則跳過權重更新,降低比例因子。我們發現,只要不頻繁地跳過更新,訓練計劃就不必調整,就可以達到與FP32訓練相同的精度。

請注意,N有效地限制了我們可以溢位和跳過更新的頻率。縮放因子的更新率可以通過選擇增加/減少的乘數以及N(增加前的非溢位迭代次數)來調整。

動態損失縮放方法對應了了以下訓練流程:

  1. 在FP32中保持一份權重的主副本。
  2. 將S初始化為一個大的數值。
  3. 對於每個迭代
    1. 製作一個權重的FP16副本。
    2. 前向傳播(FP16權重和啟用)。
    3. 用比例因子S乘以所得的損失。
    4. 後向傳播(FP16權重、啟用和它們的梯度)。
    5. 如果權重梯度中存在Inf或NaN。
      1. 減少S。
      2. 跳過權重更新,進入下一次迭代。
    6. 將權重梯度與1/S相乘。
    7. 完成權重更新(包括梯度剪裁等)。
    8. 如果在過去的N次迭代中沒有出現Inf或NaN,則增加S。

圖片來自 https://developer.nvidia.com/automatic-mixed-precision

0x02 PyTorch

2.1 英偉達算力

英偉達的Volta及Turing架構GPU在使用FP16計算時的特點如下:

  • FP16的記憶體頻寬和儲存需求相比FP32來說可以降低一半,這樣開發者在相同的硬體條件下可以使用更大更復雜的模型和更大的batch size。
  • 英偉達Volta和Turing架構GPU提供了Tensor Cores技術。Tensor Cores的FP16計算吞吐量是FP32的8倍。

因此,在相同的超引數下,使用半精度浮點(FP16)和單精度(FP32)浮點的混合精度訓練就可以達到與使用純單精度(FP32)訓練相同的準確率,而且模型訓練速度可以大大加速。

2.2 Torch.cuda.amp

PyTorch之中的混合精度主要是依賴 torch.cuda.amp 這個庫,這就說明這個功能是依賴於CUDA的。

前面分析提到了為何要混合計算的原因,這是因為:

  • 在某些場合下對精度損失不敏感,區域性精度損失對最終訓練效果影響非常微弱,並且能利用Tensor Cores進行加速,此時FP16有優勢。
  • 某些場合下對精度損失特別敏感,此時FP32有優勢。

PyTorch 之中,與混合精度相關的張量是torch.FloatTensor和torch.HalfTensor,這兩個混合起來使用就是混合精度了。而框架會根據實際需要來自動(有時需要手動調整)調整一個張量的型別,在torch.FloatTensor和torch.HalfTensor 之中切換,這就是automatic mixed precision(AMP)的來由。

2.2.1 使用

具體使用上,PyTorch 就是使用了autocast + GradScaler。我們從 https://github.com/NVIDIA/DeepLearningExamples 官方例子找出來看看。

GradScaler 的作用是放大loss,防止梯度underflow,但這只是在反向傳播傳遞梯度時候使用,更新權重時候還需要把梯度縮放回原來的大小。

autocast上下文應該只是包括前向傳播和loss計算,因為反向傳播會自動使用前向傳播同樣的型別。

from torch.cuda.amp import autocast as autocast

def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    use_amp,
    cfg,
    dllogger,
    per_iter_end_callback_fn=None,
):
    # 模型預設的是torch.FloatTensor
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()

    if use_amp:
        # 構建GradScaler
        scaler = torch.cuda.amp.GradScaler(init_scale=8192.0)
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        iteration = iteration + 1
        images = images.to(device)
        targets = [target.to(device) for target in targets]

        if use_amp:
            with torch.cuda.amp.autocast(): # 前向傳播開啟autocast
                loss_dict = model(images, targets)
        else:
            loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        if use_amp:        
            scaler.scale(losses).backward() # 放大梯度
        else:
            losses.backward()

        def _take_step():
            if use_amp:
                scaler.step(optimizer) # 在方法內部,如果梯度正常,則更新權重,否則忽略此次更新
                scaler.update() # 是否需要增大scaler
            else:
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        
        if not cfg.SOLVER.ACCUMULATE_GRAD:
            _take_step()
        else:
            if (iteration + 1) % cfg.SOLVER.ACCUMULATE_STEPS == 0:
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad.data.div_(cfg.SOLVER.ACCUMULATE_STEPS)
                _take_step()

2.2.2 多Model,losses和優化器

如果你的網路有多個損失,你必須在每個網路之中單獨呼叫scaler.scale。如果網路有多個優化器,你可以在它們之中任意一個單獨呼叫scaler.unscale,並且你必須在每個之中都單獨呼叫scaler.step。

但是,在迭代之中所有優化器都完成step操作之後,才可以呼叫 scaler.update,並且只能呼叫一次。

每個優化器檢查梯度是否為 infs/NaN,並獨立決定是否跳過該步驟。這可能會導致一個優化器跳過該步驟,而另一個則沒有。由於很少發生跳步(每幾百次迭代可能才有一次),這不應妨礙收斂。

scaler = torch.cuda.amp.GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer0.zero_grad()
        optimizer1.zero_grad()
        with autocast():
            output0 = model0(input)
            output1 = model1(input)
            loss0 = loss_fn(2 * output0 + 3 * output1, target)
            loss1 = loss_fn(3 * output0 - 5 * output1, target)

        # (retain_graph here is unrelated to amp, it's present because in this
        # example, both backward() calls share some sections of graph.)
        scaler.scale(loss0).backward(retain_graph=True)
        scaler.scale(loss1).backward()

        # You can choose which optimizers receive explicit unscaling, if you
        # want to inspect or modify the gradients of the params they own.
        scaler.unscale_(optimizer0)

        scaler.step(optimizer0)
        scaler.step(optimizer1)

        scaler.update()

2.2.3 分散式

torch.nn.DataParallel 在每個裝置上產生一個執行緒來執行正向傳遞。autocast state 是執行緒本地的,因此以下內容將不起作用:

model = MyModel()
dp_model = nn.DataParallel(model)

# Sets autocast in the main thread
with autocast():
    # dp_model's internal threads won't autocast.  The main thread's autocast state has no effect.
    output = dp_model(input)
    # loss_fn still autocasts, but it's too late...
    loss = loss_fn(output)

修復很簡單。在MyModel.forward之中使用 autocast。

MyModel(nn.Module):
    ...
    @autocast()
    def forward(self, input):
       ...

# Alternatively
MyModel(nn.Module):
    ...
    def forward(self, input):
        with autocast():
            ...

以下程式碼在dp_model的執行緒(執行forward)和主執行緒(執行loss_fn)中自動轉換:

model = MyModel()
dp_model = nn.DataParallel(model)

with autocast():
    output = dp_model(input)
    loss = loss_fn(output)

torch.nn.parallel.DistributedDataParallel 的文件建議每個程式使用一個 GPU 以獲得最佳效能。在這種情況下,DistributedDataParallel不會在內部產生執行緒,因此autocastGradScaler的使用不受影響。

或者在 forward 方法內部使用with autocast(),這樣可以保證autocast在程式內部生效,比如。

def _forward(self, sample):
    loss = None
    oom = 0
    try:
        if sample is not None:
            with amp.autocast(enabled=self.args.amp):
                # calculate loss and sample size
                logits, _ = self.model(**sample['net_input'])
                target = sample['target']
                probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                loss = self.criterion(probs, target)
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
                self.args.distributed_rank), force=True)
            oom = 1
            loss = None
        else:
            raise e
    return loss, oom

0x03 FSDP 使用

torch.cuda.amp.autocast 與FSDP完全相容,但是使用者需要把mixed_precision 設定為True,具體示例程式碼如下:

offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=True,
    num_microbatches=1,
)

torch.cuda.set_device(0)
device = torch.device("cuda")

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

# To train 1 epoch.
offload_model.train()
for batch_inputs, batch_outputs in dataloader:
    batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
    start = time.time_ns()
    optimizer.zero_grad()
    inputs = batch_inputs.reshape(-1, num_inputs * num_inputs)
    
    with torch.cuda.amp.autocast(): # 設定使用 amp
        output = model(inputs)
        loss = criterion(output, target=batch_outputs)
        loss.backward()
    optimizer.step()

我們接下來看看FSDP的相關原始碼,

3.1 成員變數

因為涉及了 CPU offload 和分割槽等因素,所以FSDP不能簡單使用amp,需要和 CPU offload 和分割槽結合起來看,比如FP16引數也需要分割槽和offload,因為amp不會自動分割槽和offload,所以FSDP需要把這部分活承擔過來,顯式的進行部分切換工作

前文程式碼提到了一些與混合精度訓練相關的成員變數,這裡就是把32,16位引數分別進行分割槽操作,也會相應做offload操作

  • _fp32_shard:full precision的單個引數分片(通常為fp32,但這取決於使用者傳入的模型資料型別)。這可以在CPU或GPU上進行,具體取決於cpu_offload的值。
  • _fp16_shard:在混合精度模式下,我們在計算裝置上維護一個降低精度(通常是FP16)的引數分片,用於在前向/後向傳遞中執行計算。這就是``_fp16_shard,如果 mixed_precisionTrue`,這將是fp16中引數的單個shard,用於all-gather。
  • _full_param_padded:在向前和向後傳播中用於計算的全部權重(被填充為可被world_size均勻整除)。這將原地調整大小,並僅在需要時具體化(通過all-gather)。

程式碼之中也需要做相應設定,如果我們計劃將FP32/FP16引數保留在CPU上,那麼固定記憶體允許我們以後在將FP32/FP16引數碎片移動到計算裝置時使用非阻塞傳輸。分割槽操作是 FP32,FP16 統一處理的。

3.2 Scaler

在 Scaler 方法,FSDP也推出了有特色的 ShardedGradScaler。PyTorch自動混合精度的實際使用情況將取決於OSS是與DDP還是與ShardedDDP一起使用。

  • 如果OSS與DDP一起使用,那麼就可以使用正常的PyTorch GradScaler,不需要做任何改變。
  • 如果OSS與ShardedDDP一起使用(為了獲得梯度分片),那麼可以使用一個非常類似的流程,但它需要一個感知梯度的GradScaler。它可以在fairscale.optim.grad_scaler中使用。

在這兩種情況下,Autocast都可以照常使用,並且損失將以同樣的方式被縮放和處理。

我們看看ShardedGradScaler程式碼,會發現其特色在於使用 dist.all_reduce 在 ranks 之間進行規約。

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
from torch.optim import Optimizer

from .oss import OSS


class GradScaler(TorchGradScaler):
    def _unscale_grads_(
        self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
    ) -> Dict[torch.device, torch.Tensor]:
        return super()._unscale_grads_(optimizer, inv_scale, found_inf, True)


class ShardedGradScaler(TorchGradScaler):
    """
    A shard-aware :class:`GradScaler<torch.cuda.amp.GradScaler>`, to be used in conjunction with
    :class:`OSS` and :class:`ShardedOptimizer`.

    Interface and usecases are not changed, more explanations can be found in the corresponding pytorch
    documentation https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True,
        process_group: Any = dist.group.WORLD,
    ) -> None:
        super().__init__(
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )
        self.display_warning = True
        self.group = process_group

    def unscale_(self, optimizer: Optimizer) -> None:
        # Could be a mistake, this scaler is supposed to work with ZeroRedundancyOptimizer only
        if self.display_warning and not isinstance(optimizer, OSS):
            logging.warning(
                "ShardedGradScaler is to be used in combination with a sharded optimizer, this could not be checked"
            )
        self.display_warning = False  # Only warn once

        # Call the upstream unscale_ method which will only act on this rank's gradients
        super().unscale_(optimizer)

        # Synchronize the detected inf across the ranks
        optimizer_state = self._per_optimizer_states[id(optimizer)]
        last_handle = None
        
        # 使用了 AllReduce
        for v in optimizer_state["found_inf_per_device"].values():
            last_handle = dist.all_reduce(v, async_op=True, group=self.group)

        # Make sure that the calls are done before moving out.
        # The calls are executed in sequence, waiting for the last one is enough
        if last_handle is not None:
            last_handle.wait()

3.3 初始化

我們接著看看 offload 和混合精度如何使用。在初始化方法 _init_param_attributes 之中,也有操作會為移動到CPU做準備,比如放到鎖頁記憶體之中,也會為混合精度建立張量,比如_fp16_shard。

@torch.no_grad()
def _init_param_attributes(self, p: Parameter) -> None:

    if hasattr(p, "_fp32_shard"):
        return

    # A single shard of the parameters in full precision.
    p._fp32_shard = p.data

    if self.mixed_precision:

        # 為移動到CPU做準備
        if self.move_params_to_cpu:
            # If we plan to keep the FP32 parameters on CPU, then pinning
            # memory allows us to later use non-blocking transfers when moving
            # the FP32 param shard to compute_device.
            p._fp32_shard = p._fp32_shard.pin_memory() 
            p.data = p._fp32_shard

        # 在混合精度模式下,我們在計算裝置上維護一個降低精度(通常是FP16)的引數分片,
        # 用於在前向/後向傳遞中執行計算。    
            
        # In mixed precision mode, we maintain a reduced precision
        # (typically FP16) parameter shard on compute_device for performing
        # the computation in the forward/backward pass. We resize the
        # storage to size 0 at init (here) and re-materialize (by copying
        # from _fp32_shard) as needed.
        p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
        free_storage_(p._fp16_shard)
    else:
        p._fp16_shard = None  # use _fp32_shard

    # We also maintain a full-sized parameter of type self.compute_dtype
    # (FP16 for mixed_precision or FP32 otherwise). We resize the
    # storage to size 0 at init (here) and only materialize as needed. The
    # storage may contain padding elements so that it is evenly divisible by
    # world_size, although these padding elements will be removed before the
    # relevant computation.
    if p._is_sharded:
        p._full_param_padded = torch.zeros(
            p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
        )
        free_storage_(p._full_param_padded)

    # 為移動到CPU做準備     
        
    if self.move_grads_to_cpu: 
        # We can optionally move the grad shard to CPU during the backward
        # pass. In this case, it's important to pre-allocate the CPU grad
        # shard in pinned memory so that we can do a non-blocking transfer.
        p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()

邏輯如下:

3.4 重建

我們以 _rebuild_full_params 為例。因為前面分析過,這裡只是把相關程式碼摘錄,程式碼會依據各種配置進行切換,比如如果指定了強制全精度,則還需要從FP16轉換為FP32,然後再進行all-gather。

@torch.no_grad()
def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:

    output_tensors: List[Tuple[torch.Tensor, bool]] = []

    def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
        """
        Helper function to update p.data pointer.
        """
        if custom_output_tensor is not None:
             # 省略
        elif not p._is_sharded:
            if self.mixed_precision and not force_full_precision: # 切換到 FP16
                p.data = p._fp16_shard
                output_tensors.append((p.data, True))
            else:
                # Here p.data == p._fp32_shard, so it's not safe to free.
                output_tensors.append((p.data, False))
        else:
            # 省略
        # Trim any padding and reshape to match original size.
        p.data = p.data[: p._orig_size.numel()].view(p._orig_size)


    with torch.cuda.stream(self._streams["all_gather"]):
      
        if self.mixed_precision and not force_full_precision:
            self._cast_fp32_param_shards_to_fp16() # 從fp32切換到fp16

        for p in self.params:
            if not p._is_sharded:  # e.g., when world_size == 1
                update_p_data()
            else:
                # If self.move_params_to_cpu and force_full_precision, we need to cast
                # the FP32 CPU param to CUDA for the all-gather.
                
                # 拷貝到GPU
                p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

                p_size = p._full_param_padded.size()
                if self.mixed_precision and force_full_precision:
                    # Allocate fresh tensor in full precision since we are in
                    # mixed precision and full precision rebuild is asked.
                    
                    # 在全精度中分配新的張量,因為我們處於混合精度中,需要進行全精度重建。
                    output_tensor = p_data.new_zeros(p_size)
                else:
                    if p._full_param_padded.storage().size() != p_size.numel():
                        alloc_storage_(p._full_param_padded, size=p_size)
                    output_tensor = p._full_param_padded

                # Fill output_tensor with (p.data for each shard in self.world_size)
                dist.all_gather(chunks, p_data, group=self.process_group) # 簡化版本程式碼

                if self.mixed_precision and not force_full_precision:
                    self._free_fp16_param_shard([p]) # 釋放記憶體
                    
                # 省略    

邏輯如下:

3.5 cast操作

可以從 _cast_fp32_param_shards_to_fp16 之中看到如何做轉換操作。

@torch.no_grad()
def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
    """Cast FP32 param shard to FP16 for a list of params."""
    if params is None:
        params = self.params
    with torch.cuda.stream(self._streams["fp32_to_fp16"]):
        for p in params:
            alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
            p._fp16_shard.copy_(
                # If cpu_offload is True, this will be non-blocking because
                # _fp32_shard is pinned, otherwise it's a no-op.
                p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
            )
            p.data = p._fp16_shard
    torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

3.6 _post_reduction_hook

_post_backward_hook 之中會設定 callback_fn,就是在 reduce-scatter 之後呼叫_post_reduction_hook,可以理解,就是做完這個操作之後,可以把梯度移動到CPU了。

callback_fn = functools.partial(self._post_reduction_hook, param)

具體程式碼如下,和offload相關的是把梯度移動到CPU的操作,和混合精度相關的是把梯度轉換為引數張量的型別

def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
    """Hook to call on each param after the reduce-scatter."""

    param.grad.data = reduced_grad
    if self.gradient_postdivide_factor > 1:
        # Average grad by world_size for consistency with PyTorch DDP.
        param.grad.data.div_(self.gradient_postdivide_factor)
        
    # Cast grad to param's dtype (typically FP32). Note: we do this
    # before the move_grads_to_cpu step so that this entire hook remains
    # non-blocking. The downside is a bit more D2H transfer in that case.
    if self.mixed_precision:
        orig_param_grad_data = param.grad.data # 把梯度進行轉換,一般來說是切換回 FP32
        param.grad.data = param.grad.data.to(dtype=param.data.dtype)
        # Don't let this memory get reused until after the transfer.
        orig_param_grad_data.record_stream(torch.cuda.current_stream())
        
    if hasattr(param, "_saved_grad_shard") and param._saved_grad_shard is not None:
        param.grad.data += param._saved_grad_shard
        delattr(param, "_saved_grad_shard")
        
    # Optionally move gradients to CPU, typically used if one is running
    # the optimizer on the CPU.
    if self.move_grads_to_cpu: # 把梯度移動到CPU
        param._cpu_grad.copy_(param.grad.data, non_blocking=True)
        # Don't let this memory get reused until after the transfer.
        param.grad.data.record_stream(torch.cuda.current_stream())
        param.grad.data = param._cpu_grad

至此,混合精度分析完畢,我們下一篇看看 FSDP 如何使用 Activation recomputation,敬請期待。

0xFF 參考

https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-multiple-models-losses-and-optimizers

PyTorch的自動混合精度(AMP)

混合精度訓練最佳實踐

ZeRO-Offload: Democratizing Billion-Scale Model Training

https://www.deepspeed.ai/tutorials/zero-offload/

DeepSpeed: Extreme-scale model training for everyone

https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/

https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/

https://www.marktechpost.com/2021/02/01/microsoft-and-the-university-of-california-merced-introduces-zero-offload-a-novel-heterogeneous-deeplearning-training-technology-to-train-multi-billion-parameter-models-on-a-single-gpu/

相關文章