[Source Code Analysis] How Facebook Trains Super-Large Models --- (2)

Posted by 羅西的思考 on 2022-01-19

0x00 Abstract

As introduced in earlier articles, Microsoft's ZeRO can scale a trillion-parameter model across 4,096 NVIDIA A100 GPUs using 8-way model parallelism, 64-way pipeline parallelism, and 8-way data parallelism.

FSDP (Fully Sharded Data Parallel) is Facebook's upgrade of PyTorch DDP, designed after studying Microsoft's ZeRO closely; it can be seen as Facebook's counterpart to ZeRO, and its essence is parameter sharding, i.e. splitting the model parameters (and related state) across the GPUs. We will study it through the papers, blog posts, and code from Google, Microsoft, and Facebook.

The previous article covered how to use FSDP; this article looks at the source code to see how FSDP implements parameter sharding.

Other articles in this series:

[Source Code Analysis] PyTorch Distributed: ZeroRedundancyOptimizer

[Paper Translation] Parameter Sharding for Distributed Training: ZeRO

[Paper Translation] Parameter Sharding for Distributed Training: Google Weight Sharding

[Source Code Analysis] How Facebook Trains Super-Large Models --- (1)

0x01 Review

1.1 ZeRO

Let's first review ZeRO.

During deep model training, GPU memory is mainly occupied by two parts: Model States and Activations. Model States include:

  • Optimizer States: the data the optimizer needs when applying gradient updates, e.g. the Momentum buffers in SGD.
  • Gradients: the gradients produced by back propagation.
  • Model Parameters: the model parameters, i.e. the information "learned" from data during training.

ZeRO addresses the memory occupied by the Model States. It comes in three levels, shown in the figure below, corresponding to partitioning the Model States to different degrees.

(Figure: the three ZeRO stages)

1.1.1 ZeRO-1

This level partitions the Optimizer States.

The optimizer plays no role in either the forward or the backward pass; the Optimizer States are only used together with the model parameters after the gradients have been produced, when the update computes the new parameters. ZeRO-1 therefore partitions the Optimizer States: with N workers, each worker holds only 1/N of the Optimizer States, uses that 1/N to update the corresponding 1/N of the parameters, and all the partitions are then stitched back together into the full model (concretely via broadcast or all-gather, so that every rank receives the freshly updated parameter values). A single-process sketch of this idea follows.
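Below is a minimal single-process sketch of ZeRO-1, not DeepSpeed's actual API: the flat parameter vector is split into world_size slices, each simulated rank keeps optimizer state (an SGD momentum buffer) only for its own slice and updates that slice, and the slices are then concatenated, standing in for the broadcast/all-gather.

import torch

# Single-process sketch of ZeRO-1 (illustrative only; the rank loop simulates N workers).
world_size = 4
flat_param = torch.randn(1024)                 # flattened model parameters
flat_grad = torch.randn(1024)                  # in ZeRO-1 every rank still holds the full gradient
param_slices = list(flat_param.chunk(world_size))
grad_slices = list(flat_grad.chunk(world_size))

# Optimizer state (momentum) exists only for the local slice on each rank.
momentum = [torch.zeros_like(s) for s in param_slices]
lr, beta = 0.1, 0.9

for rank in range(world_size):                 # in reality each rank runs only its own branch
    momentum[rank].mul_(beta).add_(grad_slices[rank])
    param_slices[rank] -= lr * momentum[rank]  # update only the local 1/N of the parameters

# Exchange the updated slices (broadcast / all-gather in the real implementation).
flat_param = torch.cat(param_slices)
print(flat_param.shape)                        # torch.Size([1024])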

1.1.2 ZeRO-2

ZeRO-2 partitions both the Optimizer States and the Gradients.

ZeRO-2 builds on ZeRO-1: since ZeRO-1 already stores the Optimizer States partitioned across workers, each worker naturally only needs the gradients that correspond to its own slice of the Optimizer States, which gives us gradient sharding.

The gradients of all workers are aggregated with AllReduce; each worker then keeps only the slice of the gradient it needs and discards the rest, as sketched below.
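A similarly minimal single-process sketch of the ZeRO-2 gradient handling (illustrative only, not DeepSpeed code): the gradients are summed, standing in for AllReduce, and the rank keeps only the slice that matches its own optimizer-state partition.

import torch

# ZeRO-2 sketch: sum gradients across ranks, keep only the local slice.
world_size, rank = 4, 1
grads_per_rank = [torch.randn(1024) for _ in range(world_size)]  # each rank's full local gradient

reduced = torch.stack(grads_per_rank).sum(dim=0)         # stand-in for dist.all_reduce
my_grad_shard = reduced.chunk(world_size)[rank].clone()  # keep the slice for this rank's partition
print(my_grad_shard.numel())                             # 256; the other slices can be discarded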

1.1.3 ZeRO-3

ZeRO-3 partitions the Optimizer States, Gradients, and Parameters. On top of ZeRO-1 and ZeRO-2, each worker keeps only a shard of the model, and all workers cooperate to provide the complete Model States, gathering and releasing parameters as the computation requires. One point to emphasize: ZeRO-3 gathers and releases at the granularity of individual parameters; we will analyze this in detail together with the code later.

1.2 DDP VS FSDP

Let's first pull out a figure from an early version of the source code to compare DDP with FSDP, as a refresher.

0x02 Overall Logic

2.1 FSDP

Let's first recall the overall FSDP logic:

  • Model shard: each GPU holds only a shard of the model.
  • All-gather: each GPU gathers all weights from the other GPUs via all-gather so it can compute the forward pass locally. This is the underlined part of P_p in the paper.
  • Forward (local): run the forward pass locally. Both the forward and backward computations use the full model.
  • All-gather: the weight gathering is executed again before the backward pass. This is the underlined part of P_p in the paper.
  • Backward (local): run the backward pass locally. Both forward and backward use the full model, and at this point every GPU also holds the full gradients.
  • Reduce-scatter: after the backward pass, the local gradients are aggregated and sharded across the GPUs via reduce-scatter; the gradient on each shard is the aggregated portion corresponding to that partition. This is the underlined part of P_g in the paper.
  • Update Weight (local): each GPU updates its local weight shard.

2.2 Original ZeRO

Next, let's look at how the original Microsoft ZeRO code handles this. You can compare it against the FSDP approach above, and later, during the FSDP code analysis, you can also compare the two implementations in detail.

2.2.1 Initialization

At initialization, ZeRO evenly partitions the parameters across processes (see the sketch after this list). It will:

  • Flatten the original parameter into 1D.
  • Have each worker locate, based on its rank, the start and end positions of its slice within the 1D parameter, and copy the corresponding data.
  • Record the original tensor's information (shape, numel, etc.) in _convert_to_deepspeed_param, so that the later padding and partitioning do not lose the original characteristics of the data.
  • Release the original parameter, replacing it with a scalar-like placeholder tensor.
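The following sketch illustrates the flatten-and-slice step described above; the helper name and the ceil-division shard size are my own illustration (DeepSpeed's real logic lives in its partitioning code and _convert_to_deepspeed_param), but it shows how a rank locates and copies its 1/N slice, and that the original shape must be recorded separately.

import torch
import torch.nn.functional as F

def partition_param(param: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Hypothetical helper: flatten a parameter and return this rank's padded slice."""
    flat = param.detach().flatten()                         # 1. flatten to 1D
    numel_per_rank = (flat.numel() + world_size - 1) // world_size
    start = rank * numel_per_rank                           # 2. locate this rank's slice
    end = min(start + numel_per_rank, flat.numel())
    local = flat[start:end].clone()                         # copy only the local data
    if local.numel() < numel_per_rank:                      # pad the last rank so all shards match
        local = F.pad(local, [0, numel_per_rank - local.numel()])
    return local

p = torch.randn(5, 7)                                       # 35 elements
shard = partition_param(p, rank=2, world_size=4)            # each rank holds ceil(35/4) = 9 elements
# the original shape/numel is recorded separately (ZeRO: on the deepspeed param; FSDP: p._orig_size)
print(shard.numel(), p.size(), p.numel())                   # 9 torch.Size([5, 7]) 35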

Because both the forward and backward passes need the complete parameters, ZeRO must know how to reconstruct them, so it builds the control logic at initialization time: it registers 4 hooks on every submodule.

  • _pre_forward_module_hook, which gathers the module parameters before the submodule's forward pass starts.
  • _post_forward_module_hook, which releases the module parameters after the submodule's forward pass finishes.
  • _pre_backward_module_hook, which gathers the module parameters before the submodule's backward pass starts.
  • _post_backward_module_hook, which releases the module parameters after the submodule's backward pass finishes.

The relevant code is:

# Pre forward hook
module.register_forward_pre_hook(_pre_forward_module_hook)
# Post forward hook
module.register_forward_hook(_post_forward_module_hook)
# Pre backward hook
module.register_forward_hook(_pre_backward_module_hook)
# post backward hook
module.register_forward_pre_hook(_post_backward_module_hook)

Two classes, PartitionedParameterCoordinator and PrefetchCoordinator, are then built; they perform the actual gathering and releasing and are called by each hook.

2.2.2 Forward Pass

Before the forward pass starts, _pre_forward_module_hook gathers the weights from all partitions and reconstructs the original parameters. There are some nice techniques here.

Since training proceeds layer by layer, ZeRO performs prefetching: while gathering the current layer's parameters, it also gathers the next layer's parameters, which saves communication time (a sketch follows this list). Specifically:

  • During the first iteration, record a complete execution trace of the model, i.e. the order in which each nn.Module runs.
  • Based on the recorded trace, gather the parameters of the current layer and the next layer, and rebuild the original large parameters using the original tensor information recorded earlier in _convert_to_deepspeed_param.
  • Run this submodule's forward and record how far the execution has progressed, so we know which layer's parameters to prefetch next.
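Here is a minimal sketch of the "record the execution order, then prefetch" idea. ExecutionTrace and its methods are hypothetical stand-ins for DeepSpeed's PrefetchCoordinator; a real implementation would launch an all-gather for the next submodule instead of merely looking up its name.

import torch
import torch.nn as nn

class ExecutionTrace:
    """Hypothetical sketch: record leaf-submodule execution order in iteration 1, replay it later."""
    def __init__(self, model: nn.Module):
        self.order = []            # filled during the first iteration
        self.recording = True
        self.step = 0
        for name, m in model.named_modules():
            if len(list(m.children())) == 0:                 # leaf submodules only
                m.register_forward_pre_hook(self._make_hook(name))

    def _make_hook(self, name):
        def hook(module, inputs):
            if self.recording:
                self.order.append(name)                      # iteration 1: record the order
            else:
                nxt = self.order[self.step + 1] if self.step + 1 < len(self.order) else None
                # a real implementation would start the all-gather for `nxt` here (prefetch)
                self.step += 1
        return hook

    def finish_first_iteration(self):
        self.recording = False                               # switch to replay/prefetch mode
        self.step = 0

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
trace = ExecutionTrace(model)
model(torch.randn(2, 8))           # iteration 1 records the order: ['0', '1', '2']
trace.finish_first_iteration()
model(torch.randn(2, 8))           # iteration 2: at every step the next submodule is known
print(trace.order)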

After the forward pass finishes, _post_forward_module_hook is called to release the large original parameters rebuilt for this layer.

This is the key point: the all-gather / release is performed layer by layer. Within each iteration, every layer's original large parameters are built up and then released in turn; at no point does a GPU hold a complete model of all layers, it only holds each layer's original parameters one after another. For example, with 6 layers, when computation reaches the 4th layer (Layer 3), a GPU looks like the diagram below: only Layer 3 has its full parameters, the first three layers have already been released, and the last two layers have not been gathered yet.

+---------------------------------------------+
| GPU n                                       |
|                        +  gather            |
|                        |  forward           |
|                        |  release           |
|                        v                    |
|                    +---+----+               |
|  Layer 0           |        |               |
|                    +---+----+               |
|                        |  gather            |
|                        |  forward           |
|                        |  release           |
|                        v                    |
|                    +---+----+               |
|  Layer 1           |        |               |
|                    +---+----+               |
|                        |  gather            |
|                        |  forward           |
|                        |  release           |
|                        v                    |
|                    +---+----+               |
|  Layer 2           |        |               |
|                    +---+----+               |
|                        |                    |
|                        |  gather            |
|                        |  forward           |
|                        v                    |
|           +------------+------------------+ |
|  Layer 3  |                               | |
|           +-------------------------------+ |
|                                             |
|                    +--------+               |
|  Layer 4           |        |               |
|                    +--------+               |
|                                             |
|                    +--------+               |
|  Layer 5           |        |               |
|                    +--------+               |
+---------------------------------------------+

In this way, each worker's submodule only needs to gather/build its parameters before the forward computation and release them after it, which eliminates the redundant memory. If we look at a single layer's forward and backward passes in isolation, the execution mechanism is as follows: parameter partitioning is accomplished through a few hooks.

+-------------------------------------------------------------+
| submodule                                                   |
|                                                             |
|        _pre_forward_module_hook()      gather & rebuild     |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|                 forward                                     |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|        _post_forward_module_hook()      release             |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|        _pre_backward_module_hook()      gather & rebuild    |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|                 backward                                    |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|        _post_backward_module_hook()     release             |
|                                                             |
+-------------------------------------------------------------+

2.2.3 Backward Pass

_pre_backward_module_hook, like its forward counterpart, gathers and prefetches parameters and records the execution steps.

_post_backward_module_hook, like its forward counterpart, releases the redundant parameters that the computation no longer needs.

The only twist is that PyTorch does not support module pre-backward hooks, so an autograd.Function is set up inside register_forward_hook whose purpose is to run a custom operation right before the module's backward pass; this is how the all-gather and reduce-scatter operations get attached to every submodule.
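The mechanism can be reproduced in a few lines. The sketch below is not DeepSpeed's code, just an illustration: a forward hook wraps the module output in a custom autograd.Function, and that Function's backward, which autograd runs right before the wrapped module's own backward, is where the parameter gather would be launched.

import torch

class PreBackward(torch.autograd.Function):
    """Emulated "pre-backward hook": its backward() runs before the wrapped module's backward."""
    @staticmethod
    def forward(ctx, tensor, hook):
        ctx.hook = hook
        return tensor.detach()      # the graph edge is this Function node itself
    @staticmethod
    def backward(ctx, grad):
        ctx.hook()                  # e.g. all-gather the module's parameters here
        return grad, None

def gather_params():                # placeholder for the real all-gather
    print("pre-backward: gather parameters")

module = torch.nn.Linear(4, 4)
# The forward hook replaces the module output with the wrapped version.
module.register_forward_hook(lambda mod, inp, out: PreBackward.apply(out, gather_params))

out = module(torch.randn(2, 4))
out.sum().backward()                # prints the message before Linear's weight gradients are computed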

2.3 FSDP Code

Now let's give an overview based on the code.

2.3.1 Initialization

The main job of initialization is to partition the Model Parameters so that every worker holds a share of them.

The sharding itself is implemented by viewing each parameter as a 1D tensor and retaining only a single slice, with the slice size determined by the number of data-parallel workers. Note that the model parameters must be split before being loaded onto the GPUs; only then are they loaded onto each worker's GPU.

Because there is padding and partitioning later, and to prevent the original data characteristics from being lost, FSDP uses PyTorch's data.size() to record the original characteristics in p._orig_size.

VS ZeRO: at this point FSDP does not set up any hook-based control.

2.3.2 Forward Pass

The core of this part is: run the forward pass on each GPU while setting up the control relationships for the backward pass, so that the backward pass knows how to gather and release parameters. The concrete steps are:

  • First, since the forward pass uses the complete model, All-gather is used to collect all weights from the other GPUs. This is done by calling _rebuild_full_params(), which rebuilds all model parameters, using the original information stored in p._orig_size to reconstruct the original parameters.
  • Call _register_post_backward_hooks to set up the reduce-scatter for the backward pass.
  • Run the forward computation.
  • Call _register_pre_backward_hooks(outputs) to register the all-gather for the backward pass.

The corresponding simplified code is:

self._rebuild_full_params() # all-gather before the forward computation
self._register_post_backward_hooks() # register reduce-scatter for the backward pass
outputs = self.module(*args, **kwargs) # model forward pass
outputs = self._register_pre_backward_hooks(outputs) # register all-gather for the backward pass

VS ZeRO: this is where FSDP sets up its hook-based control, but instead of the various module hooks it registers hooks directly on the output tensors and gradient functions via register_hook.

2.3.3 Per-layer Optimization

You may wonder: this differs from the original ZeRO code, which gathers/discards at every layer, whereas FSDP here appears to do a single forward over the whole model, without per-layer execution.

In fact, the code above is only the baseline implementation, or rather it treats the whole model as a single layer, with no per-layer gather/discard. FSDP does handle the layered case, as follows.

To maximize memory efficiency, we can discard the full weights after each layer's forward pass, saving memory for the subsequent layers. This is achieved by applying FSDP wrapping to every layer of the network (use auto_wrap to wrap each layer, and set reshard_after_forward=True). The pseudocode is:

FSDP forward pass:
    for layer_i in layers:
        all-gather full weights for layer_i # weights
        forward pass for layer_i
        discard full weights for layer_i # weights

FSDP backward pass:
    for layer_i in layers:
        all-gather full weights for layer_i # weights
        backward pass for layer_i
        discard full weights for layer_i # weights
        reduce-scatter gradients for layer_i # gradients
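For completeness, the wrapping itself looks roughly like the sketch below. It assumes fairscale's enable_wrap/auto_wrap helpers (from fairscale.nn.wrap; names and defaults may differ between versions) and an already-initialized torch.distributed process group; the keyword arguments mirror the __init__ signature shown in section 0x03.

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, auto_wrap

# Assumes torch.distributed.init_process_group(...) has already been called.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

fsdp_args = dict(reshard_after_forward=True, flatten_parameters=True, mixed_precision=False)
with enable_wrap(wrapper_cls=FSDP, **fsdp_args):
    model = auto_wrap(model)        # eligible sub-modules each get their own FSDP wrapper
model = FSDP(model, **fsdp_args)    # the outer wrapper shards any remaining parameters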

2.3.4 Summary

As we can see, once the model parameters are sharded, the local optimizer only optimizes the parameters assigned locally, so the optimizer state is automatically sharded, and consequently the gradients are automatically sharded as well. This is the bottommost \(P_{os+g+p}\) in the figure.

(Figure: ZeRO memory analysis, with \(P_{os+g+p}\) as the bottom row)
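A tiny experiment makes this concrete. It is a sketch under the assumption that the local shard already holds only 1/4 of a 1000-element parameter: building a standard optimizer on top of that shard allocates state only for the shard.

import torch

full = torch.randn(1000)
shard = torch.nn.Parameter(full.chunk(4)[0].clone())   # this rank keeps 250 elements
opt = torch.optim.Adam([shard])
shard.grad = torch.randn_like(shard)                   # pretend the reduce-scattered grad shard arrived
opt.step()
print(opt.state[shard]["exp_avg"].numel())             # 250: the optimizer state is sharded automatically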

0x03 Initialization

The main job of initialization is to partition the Model Parameters, with every worker holding a share. For example, with 3 workers each one keeps 1/3 and frees the other 2/3 (which is now redundant); the three workers' parameter shards combined are exactly the full set of model parameters.

Let's first look at the initialization method as a whole to get a general impression, and then analyze it step by step.

class FullyShardedDataParallel(nn.Module):

    def __init__(
        self,
        module: nn.Module,
        process_group: Optional[ProcessGroup] = None,
        reshard_after_forward: bool = True,
        mixed_precision: bool = False,
        fp32_reduce_scatter: bool = False,
        flatten_parameters: bool = True,
        move_params_to_cpu: bool = False,
        compute_dtype: Optional[torch.dtype] = None,
        buffer_dtype: Optional[torch.dtype] = None,
        move_grads_to_cpu: Optional[bool] = None,
        bucket_cap_mb: int = 25,
        compute_device: Optional[torch.device] = None,
        no_broadcast_optim_state: Optional[bool] = False,
        state_dict_device: Optional[torch.device] = None,
        clear_autocast_cache: bool = False,
        force_input_to_fp32: bool = False,
        verbose: bool = False,
        cpu_offload: bool = False,
    ):
        init_start = time.time()
        super().__init__()
        self.process_group = process_group or get_process_group_cached()
        self.rank = self.process_group.rank()
        self.world_size = self.process_group.size()
        self.reshard_after_forward = reshard_after_forward
        self.mixed_precision = mixed_precision
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.flatten_parameters = flatten_parameters
        self.move_params_to_cpu = move_params_to_cpu or cpu_offload
        self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
        self.buffer_dtype = buffer_dtype or self.compute_dtype
        self.move_grads_to_cpu = self.move_params_to_cpu if move_grads_to_cpu is None else move_grads_to_cpu
        self.bucket_cap_mb = bucket_cap_mb
        self.compute_device = compute_device or _get_default_cuda_device(module)
        self.uncollected_opt_state: Dict[int, Dict] = {}
        self.no_broadcast_optim_state = no_broadcast_optim_state
        self.state_dict_device = state_dict_device or self.compute_device
        self.clear_autocast_cache = clear_autocast_cache
        self.force_input_to_fp32 = force_input_to_fp32
        self.verbose = verbose

        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

        self.numel_padded_per_param: List[int] = []
        self._tstart = time.time()

        # skip validation if the process group was created above
        if process_group:
            validate_process_group(self.compute_device, self.process_group)

        # enable pytorch sync_bn just in case model contains sync_bn layers.
        enable_pytorch_sync_bn(module)

        # 1. Flatten parameters
        
        # Only handle params which are not already sharded. This enables
        # sharding individual layers of a Module, with an outer wrapper to
        # shard any leftover parameters.
        param_names = []
        params = []
        
        # 1.1 Traverse the model parameters and collect them into params
        for param_name, param in module.named_parameters():
            if not hasattr(param, "_is_sharded"):
                param_names.append(param_name)
                params.append(param)

        self._has_params = len(params) > 0

        # 1.2 Collect the parameters to be flattened into to_be_flatten_params
        to_be_flatten_params: List[List[Parameter]] = [[]]
        non_flatten_params = params
        param_name_groups = [[n] for n in param_names]
        if self.flatten_parameters:
            to_be_flatten_params = [params]
            non_flatten_params = []
            param_name_groups = [param_names]
        del param_names

        # 1.3 Use FlattenParamsWrapper to flatten the parameters
        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params)
        del module  # free original module in case it helps garbage collection

        # Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
        # params for doing sharding, gradient hooks, etc. Note, the ordering of the
        # list matters: flatten params are always in the front.
        #
        # The self._num_flatten_params and self._param_name_groups are computed
        # and kept here to support summon_full_params and shard-to-full weight
        # consolidation.
        
        # 1.4 Concatenate the flattened parameters and the remaining parameters into self.params
        self.params = cast(List[Parameter], self._fsdp_wrapped_module.flat_params) + non_flatten_params
        self._num_flatten_params = len(self._fsdp_wrapped_module.flat_params)
        self._param_name_groups = param_name_groups

        # 2. Shard the parameters
        
        # Shard module parameters in place
        self._shard_parameters_()

        # 3. Lazy initialization
        self._reset_lazy_init()

        # Flag to indicate if we require gradient reduction in the backward
        # pass. This will be False when inside the no_sync context manager.
        self._require_backward_grad_sync: bool = True

        # Enum to indicate if we're in the forward/backward pass, idle, etc.
        self.training_state = TrainingState.IDLE

        # Flag to indicate if the full params are gathered.
        self.has_full_params: bool = False

        # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
        # prefix and before load_state_dict() to add it back.
        self._register_state_dict_hook(_post_state_dict_hook)
        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)

        # Flag to indicate whether state_dict() should automatically summon the
        # full params. This defaults to True, but may be set to False if the
        # user explicitly requests the local state dict via local_state_dict().
        self._return_full_state_dict = True
        init_end = time.time()

        # Flag to guard multiple pre-backward hook being executed per iteration.
        # This is reset at the end of the backward pass.
        self._pre_backward_hook_has_run = False

3.1 Processing Parameters

The first step of the initialization method is to process the parameters; the numbers below correspond to the comments in the code.

  • 1.1 Traverse the model parameters and collect them into params.
  • 1.2 Collect the parameters to be flattened into to_be_flatten_params.
  • 1.3 Use FlattenParamsWrapper to flatten the parameters.
    • We now have a list, self.params, holding the to-be-flattened and not-to-be-flattened parameters used for sharding, gradient hooks and so on. The ordering matters: the flattened parameters always come first.
    • self._num_flatten_params and self._param_name_groups are also computed, to support summon_full_params and shard-to-full weight consolidation.
  • 1.4 Concatenate the flattened parameters and the remaining parameters into self.params.

3.2 Sharding

Initialization then calls _shard_parameters_ to shard. As the comments explain, at initialization a module with full parameters is wrapped and its parameters are sharded in place. Sharding is implemented by viewing each parameter as a 1D tensor and retaining only a single slice, with the slice size determined by the number of data-parallel workers.

Note that the model parameters must be split before being loaded onto the GPUs; only then are they loaded onto each worker's GPU.

@torch.no_grad()
def _shard_parameters_(self) -> None:
    """
    At initialization we wrap a module with full parameters and shard the
    parameters in-place. Sharding is implemented by viewing each parameter
    as a 1D Tensor and retaining only a single slice, where the slice size
    is determined by the number of data parallel workers.

    Wrapping modules with many small parameters (or with a very large data
    parallel world size) will result in many small parameter shards and slow
    performance. In this case it's better to set *``flatten_parameters``* to
    ``True``, so that all of the small parameters in the module are combined
    into a single contiguous Tensor and sharded once.

    After this initial sharding is complete, the user can initialize a
    ``torch.optim.Optimizer`` in the usual way, i.e.::

    The optimizer will see only a single slice of parameters and will thus
    allocate less memory for optimizer state, avoiding redundancy across
    data parallel workers.
    """
    self.numel_padded_per_param = []
    for p in self.params: # iterate over the list of model parameters

        # If world_size is 1, then we all-reduce grads instead of sharding.
        p._is_sharded = self.world_size > 1
        p._orig_size = p.data.size() # record the tensor's original info (shape, numel, etc.)

        if not p._is_sharded:
            self.numel_padded_per_param.append(0)
            continue
        p._is_sharded = True

        # Replace p.data with the relevant shard.
        orig_data = p.data # keep a handle to the original data
        p.data, num_padded = self._get_shard(p.data) # get this parameter's local shard
        self.numel_padded_per_param.append(num_padded)
        free_storage_(orig_data) # free the redundant data

_get_shard does the actual partitioning, but it only returns the shard corresponding to the local rank.

In the original ZeRO code, each parameter tensor is wrapped via _convert_to_deepspeed_param so that the tensor's original characteristics (shape, numel, etc.) are recorded and not lost to the later padding and partitioning. FSDP does not take that approach; instead it records them in p._orig_size, using PyTorch's data.size() method.

def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor."""
    
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    # Flatten the input tensor and split it into a list of world_size chunks
    chunks = list(torch.flatten(tensor).chunk(self.world_size))
    # Pad the list up to world_size entries
    while len(chunks) < self.world_size:
        chunks.append(chunks[0].new_empty(0)) # append an empty tensor

    # Determine the number of padding elements
    num_to_pad = chunks[0].numel() - chunks[self.rank].numel()

    # Get the shard corresponding to the local rank
    shard = chunks[self.rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad]) # pad
    # Return the shard and the amount of padding
    return shard, num_to_pad
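To see the chunk-and-pad behaviour concretely, here is the same logic reproduced outside the class with hard-coded numbers (world_size = 4, rank = 3, a 10-element tensor); this is only an illustration of the method above.

import torch
import torch.nn.functional as F

tensor = torch.arange(10.)                                 # 10 elements to shard across 4 ranks
world_size, rank = 4, 3
chunks = list(torch.flatten(tensor).chunk(world_size))     # chunk sizes: 3, 3, 3, 1
num_to_pad = chunks[0].numel() - chunks[rank].numel()      # the last rank needs 2 padding elements
shard = F.pad(chunks[rank].clone(), [0, num_to_pad])
print(shard)                                               # tensor([9., 0., 0.]): every rank ends up with 3 elements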

3.3 Lazy Initialization

Lazy initialization happens inside forward (or other methods). Specifically it:

  • Calls _init_param_attributes to initialize the parameter attributes; parameters that will be moved are placed in pinned memory in preparation for later moving them to the CPU.
  • Calls _set_is_root to set up the root, which is mainly configuration related to the process group.
  • Calls _setup_streams to create CUDA streams: separate streams for "fp32_to_fp16", "all_gather" and "post_backward", plus a ReduceScatterBucketer.
  • Calls _wait_for_previous_optim_step to wait for the streams to finish.

def _lazy_init(self) -> None:
    """Initialization steps that should happen lazily, typically right
    before the first forward pass.
    """
    # Initialize param attributes lazily, in case the param's dtype or
    # device changes after __init__.
    for p in self.params:
        self._init_param_attributes(p) # 1. Initialize parameter attributes

    # Initialize _is_root and setup streams. These steps would ideally
    # happen in __init__, but _is_root can only be determined after the
    # entire model hierarchy is setup, thus we run it lazily.
    if self._is_root is None:
        self._set_is_root()
        self._setup_streams()

    if self._is_root:
        # Buffers stay on GPU, and don't get sharded. Since _cast_buffers
        # applies recursively, we only call this from the root instance.
        self._cast_buffers()

        # Don't free the full params for the outer-most (root) instance,
        # since those params will be needed immediately after for the
        # backward pass.
        self.reshard_after_forward = False

        # Due to the use of streams, we need to make sure the previous
        # ``optim.step()`` is done before we all-gather parameters.
        self._wait_for_previous_optim_step()

3.3.1 Initializing Parameter Attributes

The following attributes are set here; this is where the mixed-precision switching becomes visible:

  • _fp32_shard: a single full-precision shard of the parameters (typically fp32, but this depends on the dtype of the model passed in by the user). It may live on CPU or GPU depending on the value of cpu_offload.
  • _fp16_shard: if mixed_precision is True, this is a single fp16 shard of the parameters, used for the all-gather.
  • _full_param_padded: the full weights used for computation in the forward and backward passes (padded to be evenly divisible by world_size). It is resized in place and only materialized (via all-gather) when needed.

The main logic is: to prepare for later moves to the CPU, certain parameters are placed in pinned memory, and a _full_param_padded large enough to hold all the weights is created.

@torch.no_grad()
def _init_param_attributes(self, p: Parameter) -> None:
    """
    We manage several attributes on each Parameter instance. The first two
    are set by :func:`_shard_parameters_`:

        ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
            if the Parameter is intentionally not sharded (in which case we
            will all-reduce grads for this param).
        ``_orig_size``: the size of the original Parameter (before sharding)

    The remaining attributes are set here:
        ``_fp32_shard``: a single shard of the parameters in full precision
            (typically FP32, but this is dependent on the dtype of the model
            as it's passed in by the user). This can be on CPU or GPU
            depending on the value of *``cpu_offload``*.
        ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be
            a single shard of the parameters in FP16, used for all-gather.
        ``_full_param_padded``: the full weight (padded to be evenly
            divisible by ``world_size``), used for computation in the
            forward and backward pass. This will be resized in place and
            only materialized (via all-gather) as needed.
    """
    if hasattr(p, "_fp32_shard"):
        return

    # A single shard of the parameters in full precision.
    p._fp32_shard = p.data

    if self.mixed_precision:
        if self.move_params_to_cpu: 
            # To prepare for later moves to CPU, place the shard in pinned memory
            # If we plan to keep the FP32 parameters on CPU, then pinning
            # memory allows us to later use non-blocking transfers when moving
            # the FP32 param shard to compute_device.
            p._fp32_shard = p._fp32_shard.pin_memory() 
            p.data = p._fp32_shard

        # In mixed precision mode, we maintain a reduced precision
        # (typically FP16) parameter shard on compute_device for performing
        # the computation in the forward/backward pass. We resize the
        # storage to size 0 at init (here) and re-materialize (by copying
        # from _fp32_shard) as needed.
        p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
        free_storage_(p._fp16_shard)
    else:
        p._fp16_shard = None  # use _fp32_shard

    # We also maintain a full-sized parameter of type self.compute_dtype
    # (FP16 for mixed_precision or FP32 otherwise). We resize the
    # storage to size 0 at init (here) and only materialize as needed. The
    # storage may contain padding elements so that it is evenly divisible by
    # world_size, although these padding elements will be removed before the
    # relevant computation.
    if p._is_sharded:
        p._full_param_padded = torch.zeros( # _full_param_padded holds the full weights
            p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
        )
        free_storage_(p._full_param_padded)

    if self.move_grads_to_cpu: 
        # To prepare for later moves to CPU, place the grad shard in pinned memory
        # We can optionally move the grad shard to CPU during the backward
        # pass. In this case, it's important to pre-allocate the CPU grad
        # shard in pinned memory so that we can do a non-blocking transfer.
        p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()

3.3.2 Setting the Root

This mainly performs some setup related to the process group.

def _set_is_root(self) -> None:
    """If ``True``, implies that no other :class:`FullyShardedDataParallel`
    instance wraps this one. Called once by :func:`_lazy_init`.
    Also sets self.children_share_process_group = True if all child
    instances share the same process group. If some child instances use a
    different process group, self.clip_grad_norm_ will raise an error.
    """
    if self._is_root is not None:
        return
    # No FSDP instance wraps this, else _is_root would be set to False.
    self._is_root = True
    # As the root, we now set all children instances to False and
    # give them a closure to try to queue a wait_for_post_backward.
    self.children_share_process_group = True
    for n, m in self.named_modules():
        # `n != ""` excludes self.
        if n != "" and isinstance(m, FullyShardedDataParallel):
            # We relax the assert for non-root instance, when the nested inialized module is wrapped
            # again in FSDP later, for example after training to run inference.
            assert m._is_root is None or not m._is_root
            if m._is_root is None:
                m._is_root = False
            if m.process_group != self.process_group:
                self.children_share_process_group = False

            # if child instance in its own (smaller) world, that was probably an attempt to avoid OOM.
            # Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
            m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
                (m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
            )

3.3.3 Creating CUDA Streams

Separate CUDA streams are created for "fp32_to_fp16", "all_gather" and "post_backward", and a ReduceScatterBucketer is created.

def _setup_streams(self) -> None:
    """Create streams to overlap data transfer and computation."""
    if len(self._streams) > 0 or not self._is_root:
        return

    if torch.cuda.is_available():
        # Stream to move main FP32 params (may be on CPU) to FP16 for forward.
        self._streams["fp32_to_fp16"] = torch.cuda.Stream()
        # Stream for all-gathering parameters.
        self._streams["all_gather"] = torch.cuda.Stream()
        # Stream for overlapping grad reduction with the backward pass.
        self._streams["post_backward"] = torch.cuda.Stream()

    # Helper for bucketing reduce-scatter ops. This is also shared with
    # children instances to improve bucket utilization.
    self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
    # We share streams with all children instances, which allows them to
    # overlap transfers across the forward pass without synchronizing with
    # the default stream.
    for n, m in self.named_modules():
        if n != "" and isinstance(m, FullyShardedDataParallel):
            m._streams = self._streams
            m._reducer = self._reducer

3.3.4 Synchronization

Wait for the streams to finish.

def _wait_for_previous_optim_step(self) -> None:
    """
    The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root
    instance) needs to synchronize with the default stream to ensure the
    previous optimizer step is done.
    """
    if not torch.cuda.is_available():
        return
    if self.mixed_precision:
        self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
    else:
        self._streams["all_gather"].wait_stream(torch.cuda.current_stream())

So the current state is as follows. Suppose there are 2 GPUs and the model parameters are sharded across them, and suppose the model has two parameters, Parameter 0 and Parameter 1. Each parameter is split into two segments stored on the two GPUs: Parameter 0 is split into Parameter 0_0 and Parameter 0_1, and Parameter 1 into Parameter 1_0 and Parameter 1_1.

                  Model Parameter
                  +----------------------------+
                  |       Parameter 0          |
                  |                            |
                  |       Parameter 1          |
                  |                            |
                  +------------+---------------+
                               |
                               | split
                               v
                         +-----+-----+
                         |           |
                         |           |
 GPU 0                   v           v                       GPU 1
+------------------------+----+   +--+---------------------------+
|  Model Parameter Shard 0    |   |  Model Parameter Shard 1     |
| +-------------------------+ |   | +--------------------------+ |
| |    Parameter 0_0        | |   | |      Parameter 0_1       | |
| |                         | |   | |                          | |
| |    Parameter 1_0        | |   | |      Parameter 1_1       | |
| |                         | |   | |                          | |
| +-------------------------+ |   | +--------------------------+ |
+-----------------------------+   +------------------------------+

0x04 Forward Propagation

The core of this part is gathering/using/releasing parameters precisely, as required by parameter sharding. The gathering corresponds to the All-gather, and the gradient reduction after release corresponds to the Reduce-Scatter.

4.1 forward

From the earlier analysis we know the forward phase consists of two parts:

  • All-gather: each GPU gathers all weights from the other GPUs via all-gather so it can compute the forward pass locally.
  • Forward (local): run the forward computation locally. Both the forward and backward computations use the full model.

In terms of code, the concrete logic is:

  1. If mixed precision is used, convert the inputs to FP16.
  2. If mixed precision is not used and force_input_to_fp32 is set, convert the inputs to FP32.
  3. Call _rebuild_full_params() to perform the all-gather before the forward computation, rebuilding all model parameters.
  4. Since parameters are gathered/released during both the forward and backward passes, the forward pass must set things up for the backward pass; concretely, call _register_post_backward_hooks to set up the reduce-scatter for the backward pass.
  5. Run the forward computation.
  6. If reshard_after_forward is set, discard the redundant full parameters (and, under mixed precision, the FP16 shards).
  7. Switch back to the main FP32 parameter shard. We maintain this invariant throughout the code: after every function, p.data == p._fp32_shard. Since the optimizer state is typically initialized lazily in optim.step(), this also ensures that after the first forward the optimizer state is initialized with the correct dtype and (sharded) size.
  8. Call _register_pre_backward_hooks(outputs) to register the all-gather for the backward pass. The hook is registered on the final output tensors, so it is the first thing invoked during the backward pass, which makes the all-gather happen naturally. Because it must be registered on the final outputs, _register_pre_backward_hooks is only called at the very end of the forward pass.

def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    self._lazy_init()

    # Start of a forward pass.
    self.training_state = TrainingState.FORWARD

    # 1. If mixed precision is used, convert the inputs to FP16
    
    # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
    # the conversion).
    if self._is_root and self.mixed_precision:
        args, kwargs = cast_floats_to_right_precision(True, True, *args, **kwargs)

    # 2. If mixed precision is not used and force_input_to_fp32 is set, convert the inputs to FP32
        
    # If enabled, convert the input to FP32 if we are in full precision.
    # no_grad is not used because the input might be for a non-root instance,
    # which mean autograd needs to go through the conversion.
    if self.force_input_to_fp32 and not self.mixed_precision:
        args, kwargs = cast_floats_to_right_precision(False, False, *args, **kwargs)

    # 3. Call _rebuild_full_params() to all-gather before the forward computation
        
    # All-gather full parameters. This will also transfer FP32 parameters to
    # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
    self._rebuild_full_params() # all-gather before the forward computation

    # 4. Call _register_post_backward_hooks to set up the reduce-scatter for the backward pass
    
    # Register backward hooks to reshard params and reduce-scatter grads.
    # These need to be re-registered every forward pass.
    self._register_post_backward_hooks() # set up reduce-scatter for the backward pass

    # 5. Run the forward computation
    
    outputs = self.module(*args, **kwargs)

    # 6. Discard redundant model parameters
    if self.reshard_after_forward:
        self._free_full_params()
        if self.mixed_precision:
            self._free_fp16_param_shard()
            
    # 7. Switch back to the main FP32 parameter shard (see the comment below)

    # Switch to main FP32 param shard. We maintain this invariant throughout
    # the code, i.e., ``p.data == p._fp32_shard`` after each function. This
    # also ensures that after the first forward, the optimizer state will be
    # initialized with the correct dtype and (sharded) size, since optimizer
    # state is typically initialized lazily in ``optim.step()``.
    self._use_fp32_param_shard()

    # 8. Call _register_pre_backward_hooks(outputs) to register the all-gather for the backward pass
    
    # Register pre-backward hooks to all-gather the params for the backward
    # pass (if output's grad was needed). This won't register anything if
    # we are in eval mode.
    #
    # Some model does forward pass multiple times, we need to register the
    # pre-backward hook on every output since the last output's hook has to
    # fire first to setup for backward. However, we use ``self._pre_backward_hook_has_run``
    # to prevent repeated overhead from multiple hook callbacks.
    outputs = self._register_pre_backward_hooks(outputs) # register all-gather for the backward pass

    # Done with a forward pass.
    self.training_state = TrainingState.IDLE

    # Only need to clear cache during forward. During backward, the cache is not used.
    # TODO (Min): Future PyTorch versions may provide a way to completely disable this
    #     cache. Update this when that's available.
    if self.clear_autocast_cache:
        torch.clear_autocast_cache()

    return outputs

Next, let's look at how each part is implemented.

4.1.1 All-gather

self._rebuild_full_params() performs the all-gather before the forward computation.

# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self._rebuild_full_params()

4.1.1.1 _rebuild_full_params

As with OSS, this synchronizes all model parameters. The logic is:

  1. If we already have the full parameters and do not need full precision, exit early.

  2. Switch subsequent operations onto the "all_gather" stream.

  3. Perform the precision conversion.

  4. Iterate over all model parameters:

    4.1 If world_size == 1, update directly, since there is only one rank.

    4.2 Move the data from CPU to CUDA.

    4.3 Perform the all-gather.

    4.4 Update the local tensor with the all-gather result, rebuilding it from the original information stored in p._orig_size.

@torch.no_grad()
def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
    """
    Gather all shards of params.

    Args:
        force_full_precision (bool, Optional): by default params will be gathered
            in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
            ``True``, in which case they will be gathered in full precision
            (e.g., FP32), possibly in fresh storage. The parameter that's being
            rebuilt will end up in full precision as well.

    Returns:
        A list of tuples, where the first element is the full-sized param
        and the second element is a bool indicating if it's safe for the
        caller to free the full-sized param. This will be ``None`` if
        ``force_full_precision=False`` and the full params are already gathered.
    """
    output_tensors: List[Tuple[torch.Tensor, bool]] = []

    def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
        """
        Helper function to update p.data pointer.

        Args:
            custom_output_tensor (torch.Tensor, Optional): if not None, this
            tensor contains the data we just gathered.
        """
        if custom_output_tensor is not None:
            assert p._is_sharded
            p.data = custom_output_tensor
            output_tensors.append((p.data, True))
        elif not p._is_sharded:
            if self.mixed_precision and not force_full_precision:
                assert p._fp16_shard is not None
                p.data = p._fp16_shard
                output_tensors.append((p.data, True))
            else:
                # Here p.data == p._fp32_shard, so it's not safe to free.
                output_tensors.append((p.data, False))
        else:
            p.data = p._full_param_padded
            output_tensors.append((p.data, True))
        # Trim any padding and reshape to match original size.
        p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

    # 1. If we already have the full params and don't need full precision, exit early.
        
    # Early exit if we already have full params and don't need full precision.
    if self.has_full_params and not force_full_precision:
        for p in self.params:
            update_p_data()
        return output_tensors

    self.has_full_params = True

    # 2. Use the "all_gather" stream
    with torch.cuda.stream(self._streams["all_gather"]):
      
        # 3. Perform the precision conversion
        
        if self.mixed_precision and not force_full_precision:
            self._cast_fp32_param_shards_to_fp16()

        # 4. Iterate over all model parameters
            
        for p in self.params: 
            if not p._is_sharded:  # e.g., when world_size == 1
                update_p_data() # 4.1 If world_size == 1, update directly since there is only one rank
            else:
              
                # 4.2 Move the data from CPU to CUDA
                
                # If self.move_params_to_cpu and force_full_precision, we need to cast
                # the FP32 CPU param to CUDA for the all-gather.
                p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

                p_size = p._full_param_padded.size()
                if self.mixed_precision and force_full_precision:
                    # Allocate fresh tensor in full precision since we are in
                    # mixed precision and full precision rebuild is asked.
                    output_tensor = p_data.new_zeros(p_size)
                else:
                    if p._full_param_padded.storage().size() != p_size.numel():
                        # Allocate based on full size from all shards.
                        alloc_storage_(p._full_param_padded, size=p_size)
                    output_tensor = p._full_param_padded

                # 4.3 Perform the all-gather
                
                # Fill output_tensor with (p.data for each shard in self.world_size)
                if hasattr(dist, "_all_gather_base"):
                    # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
                    dist._all_gather_base(output_tensor, p_data, group=self.process_group)  # type: ignore
                else:
                    chunks = list(output_tensor.chunk(self.world_size))
                    dist.all_gather(chunks, p_data, group=self.process_group)

                # 4.4 Update the local tensor with the all-gather result
                
                # Set p.data = output_tensor (with padding trimmed)
                update_p_data(output_tensor)

                if self.mixed_precision and not force_full_precision:
                    self._free_fp16_param_shard([p])
                    
    torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
    return output_tensors

4.1.1.2 Precision Conversion

_cast_fp32_param_shards_to_fp16 converts an FP32 parameter shard into an FP16 parameter shard.

@torch.no_grad()
def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
    """Cast FP32 param shard to FP16 for a list of params."""
    if params is None:
        params = self.params
    with torch.cuda.stream(self._streams["fp32_to_fp16"]):
        for p in params:
            alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
            p._fp16_shard.copy_(
                # If cpu_offload is True, this will be non-blocking because
                # _fp32_shard is pinned, otherwise it's a no-op.
                p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
            )
            p.data = p._fp16_shard
    torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

4.1.2 Discarding Redundant Parameters

There are two cases here: either freeing the full parameters, as in _free_full_params:

@torch.no_grad()
def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
    """Free up storage for full parameters."""
    if params is None:
        params = self.params
    self.has_full_params = False
    current_stream = torch.cuda.current_stream()
    for p in params:
        if not p._is_sharded:  # e.g., world_size == 1
            if self.mixed_precision:
                self._free_fp16_param_shard([p])
            continue
        # Don't let PyTorch reuse this memory until all work in the current
        # stream is complete.
        p._full_param_padded.record_stream(current_stream)
        # There may be external references to the Tensor Storage that we
        # can't modify, such as references that are created by
        # ctx.save_for_backward in the forward pass. Thus when we
        # unshard parameters, we should reuse the original Tensor
        # Storage object and unshard it in-place. For now, just resize
        # the Storage to 0 to save memory.
        free_storage_(p._full_param_padded)

Or freeing the FP16 parameter shards:

@torch.no_grad()
def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
    """Free storage for FP16 shards for a list of params."""
    if params is None:
        params = self.params
    current_stream = torch.cuda.current_stream()
    for p in params:
        if p._fp16_shard is not None:
            # _fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
            # free it until the work in the current stream completes.
            p._fp16_shard.record_stream(current_stream)
            free_storage_(p._fp16_shard)

4.3 Configuring the Backward Pass

This part sets up the backward pass so that it performs an all-gather at the start and a reduce-scatter at the end. In terms of execution logic, these configurations become:

  • All-gather: the weight gathering is executed again before the backward pass. This is the underlined part of P_p in the paper.
  • Backward (local): run the backward pass locally. Both the forward and backward computations use the full model, and at this point every GPU also holds the full gradients.
  • Reduce-scatter: after the backward pass, the local gradients are aggregated and sharded across the GPUs via reduce-scatter; the gradient on each shard is the aggregated portion corresponding to that partition. This is the underlined part of P_g in the paper.

In the earlier code this corresponds to:

        self._register_post_backward_hooks() # register reduce-scatter for the backward pass
        outputs = self.module(*args, **kwargs) # model forward pass
        outputs = self._register_pre_backward_hooks(outputs) # register all-gather for the backward pass

Let's analyze these one by one.

4.3.1 _register_post_backward_hooks

_register_post_backward_hooks registers the hooks that are called after the backward pass; these hooks perform the re-sharding and the reduce-scatter. Let's walk through its comments.

_register_post_backward_hooks is called during the forward pass. The goal is to attach a hook to each parameter's gradient-generating function (grad_acc below) so that the hook is called after all the gradients for that parameter have been computed.

Our goals are:

  1. We want the hook to fire once and only once, after all gradients for the parameter have been accumulated.
  2. If it fires more than once, we end up incorrectly sharding the gradient multiple times (which can make the dimension too small).
  3. If it fires once but too early, or does not fire at all, the gradients are left unsharded (which can make the dimension too large).

Due to multi-pass forward, this function can be called multiple times on the same parameter within a single forward pass. If we register the hook multiple times, the hook ends up being called multiple times. We could try to take a new hook each time and delete the previously registered one.

However, for unknown reasons, in mixed-precision mode we get two different grad_acc objects across different calls of this function (within the same forward pass). If we keep the last one, the hook fires too early. In full-precision mode we are lucky enough to get the same grad_acc object, so deleting and re-registering still ensures the hook fires exactly once after all gradients are generated.

Empirically, keeping the first hook registered per forward pass works best. We do need to remove the hook at the end of the backward pass; otherwise, the next forward pass will not register a new hook, which the new forward pass needs.

Beyond the comments, there are a few special tricks here:

  • Why register the hook on grad_fn.next_functions[0][0] rather than directly on the tensor p?

This is fairly involved; here is only a brief explanation, and interested readers can dig into the source themselves.

First, AccumulateGrad is a Node in the autograd graph (it derives from Node, as does TraceableFunction).

其次,Node 之中,有兩種 hook。

std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;

If register_hook is executed on the tensor p, the hook is registered among the pre_hooks_ of the corresponding grad function; at that point the gradient computation has not yet happened, so the gradient seen there is only the input to the gradient function, a temporary value that has not yet been accumulated into the actual grad storage. Tensor hooks are therefore generally used to observe these temporary gradients.

If register_hook is executed on the AccumulateGrad object, the hook is registered among that node's post_hooks_; at that point the gradient computation and accumulation have finished, so the real gradient is available, and p and its grad are not freed.

A grad_fn hook receives no arguments by default; to implement an allreduce one usually needs to pass in the parameter p (the variable that this grad_fn belongs to), so functools.partial is used to bind the argument.

So, to obtain the correct gradient, a post hook should be used, i.e. register_hook should be called directly on the gradient-accumulation function.

  • The use of expand_as.

This is because when _register_post_backward_hooks is called, no forward computation has happened yet, so the gradient function grad_fn for p has not been created. expand_as serves exactly this purpose: it creates that grad_fn without producing an actual gradient.

The following code demonstrates this more clearly.

a = torch.tensor(1.0, requires_grad=True)
print(a.grad_fn) # None: no forward computation yet, so no grad function

a_temp = a.expand_as(a) # even without a real forward pass this creates a grad function, without producing an actual gradient
print(a_temp.grad_fn) # ExpandBackward
print(a_temp.grad_fn.next_functions[0][0]) # AccumulateGrad
  
# Output:
None
<ExpandBackward object at 0x7fef9794e898> 
<AccumulateGrad object at 0x7fef9794e7f0> # this is where the hook is registered

Note: I learned the above tricks from a friend, Huper (https://www.zhihu.com/people/huper-52/answers).

The hook registration code is as follows:

def _register_post_backward_hooks(self) -> None:
    """
    Register backward hooks to reshard params and reduce-scatter grads.

    This is called during forward pass. The goal is to attach a hook
    on each of the parameter's gradient generating function (``grad_acc``
    below) so that the hook is called *after* all gradients for that
    param are computed.

    Goals:

    1. We want the hook to fire once and only once *after* all gradients
    are accumulated for a param.
    2. If it fires more than once, we end up incorrectly shard the grad
    multiple times. (could lead to dimension too small)
    3. If it fires once but too early or doesn't fire, we leave gradients
    unsharded. (could lead to dimension too large)

    Due to multiple-pass forward, this function can be called on
    the same parameter multiple times in a single forward pass. If we register
    the hook multiple time, we end up getting called multiple times. We
    could try to get a new hook every time and delete the previous one
    registered. However, due to *unknown reason* (I have debugged it for
    a long time!), in mixed precision mode, we get two different ``grad_acc``
    objects below during different calls of this function (in the same
    forward pass). If we keep the last one, the hook end up firing too
    early. In full precision mode, we luckily get the *same* ``grad_acc``
    object, so deleting and re-registering still ensured the hook fire
    once after all gradients are generated.

    Empirically, keep the first hook register per forward pass seems to
    work the best. We do need to remove the hook at the end of the
    backward pass. Otherwise, the next forward pass will not register
    a new hook, which is needed for a new forward pass.
    """
    if not torch.is_grad_enabled():
        return  # don't register grad hooks if grad isn't enabled
    for p in self.params:
        if p.requires_grad:
            if hasattr(p, "_shard_bwd_hook"):
                continue
            # Register a hook on the first call, empirically, autograd
            # fires it at the end for this param, which makes sense.
            p_tmp = p.expand_as(p)  # Get a grad_fn on p_tmp.
            assert p_tmp.grad_fn is not None
            grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
            handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
            p._shard_bwd_hook = (grad_acc, handle)

4.3.2 _post_backward_hook

_post_backward_hook is the hook function itself; it sets up _post_reduction_hook and calls self._reducer.reduce_scatter_async.

At the start of _post_backward_hook, param.grad contains the full gradient for the local batch. The reduce-scatter operation replaces param.grad with a single shard of the gradient summed across all GPUs, namely the shard corresponding to the current rank. For example:

    before reduce_scatter:
        param.grad (GPU #0): [1, 2, 3, 4]
        param.grad (GPU #1): [5, 6, 7, 8]

    after reduce_scatter:
        param.grad (GPU #0): [6, 8]    # 1+5, 2+6
        param.grad (GPU #1): [10, 12]  # 3+7, 4+8

The local GPU's optim.step is responsible for updating a single shard of the parameters, also corresponding to the current GPU's rank. This alignment is established by _shard_parameters_, which ensures that the local optimizer only sees the relevant parameter shard.

Some checking logic has been removed from the listing below.

@torch.no_grad()
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
    """
    At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
    full gradient for the local batch. The reduce-scatter op will replace
    ``param.grad`` with a single shard of the summed gradient across all
    GPUs. This shard will align with the current GPU rank. 

    The local GPU's ``optim.step`` is responsible for updating a single
    shard of params, also corresponding to the current GPU's rank. This
    alignment is created by :func:`_shard_parameters_`, which ensures that
    the local optimizer only sees the relevant parameter shard.
    """
    # First hook callback will see PRE state. If we have multiple params,
    # then subsequent hook callbacks will see POST state. When checkpoint
    # fwd counter is used, IDLE is also possible since the pre-backward hook
    # is not triggered (see ``auto_wrap_bn`` below, we have to use
    # FSDP(checkpoint(conv, FSDP(bn), ...)), with reshard_after_forward=False).

    self.training_state = TrainingState.BACKWARD_POST

    # If this is a checkpointed module, we check if the following
    # counter reaches 0. If not, it is not the final backward call
    # for this module yet. Therefore, we early return in that case.
    if hasattr(self._fsdp_wrapped_module, "_checkpoint_fwd_counter"):
        if self._fsdp_wrapped_module._checkpoint_fwd_counter != 0:
            return

    if self._require_backward_grad_sync or self.reshard_after_forward:
        # Free full params. As a special case, we don't free the full params
        # when in a ``no_sync`` context (as inversely indicated by
        # ``self._require_backward_grad_sync``), since the params will not
        # get updated before the next forward. This saves networking
        # bandwidth but uses more GPU memory.
        self._free_full_params([param])

    if self.mixed_precision:
        # This is a no-op if reshard_after_forward is True, since we already
        # free the param shard when rebuilding the full params in the
        # pre_backward_hook.
        self._free_fp16_param_shard([param])

    # Switch to FP32 shard after backward.
    self._use_fp32_param_shard([param])

    # Wait for all work in the current stream to finish, then start the
    # reductions in post_backward stream.
    self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(self._streams["post_backward"]):
        orig_grad_data = param.grad.data

        if self.mixed_precision and self.fp32_reduce_scatter:
            # Cast grad to FP32.
            param.grad.data = param.grad.data.to(param.dtype)

        if self.gradient_predivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            param.grad.data.div_(self.gradient_predivide_factor)

        # Perform the reduce-scatter operation
            
        callback_fn = functools.partial(self._post_reduction_hook, param)
        if param._is_sharded:
            grad_chunks = chunk_and_pad(param.grad.data, self.world_size)
            self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn)
        else:
            # Currently the only way for _is_sharded to be False is if
            # world_size == 1. This could be relaxed in the future, in which
            # case grads should be all-reduced here.
            callback_fn(param.grad.data)

        # After _post_backward_hook returns, orig_grad_data will eventually
        # go out of scope, at which point it could otherwise be freed for
        # further reuse by the main stream while the div/reduce_scatter/copy
        # are underway in the post_backward stream. See:
        # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
        orig_grad_data.record_stream(self._streams["post_backward"])

4.3.3 Reduce Scatter

ReduceScatterBucketer bundles multiple reduce-scatter operations on small tensors into larger reduce-scatter operations to improve communication efficiency. Reduce-scattering a list of tensors asynchronously lets the smaller reductions be bucketed together. The given callback (callback_fn) will be called with the reduced result at some later time; flush() can be called to force all queued operations and callbacks to execute. Note that large inputs are reduced immediately, and this function may also flush the relevant bucket to make room for input_list.

class ReduceScatterBucketer:
    """
    Helper for bucketing multiple reduce-scatter operations on small tensors
    into larger reduce-scatter ops to improve communication efficiency.
    """

    def __init__(self, bucket_cap_mb: int = 25):
        self.bucket_cap_mb = bucket_cap_mb
        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

    @torch.no_grad()
    def reduce_scatter_async(
        self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_list``.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. List
                should contain ``group.size()`` tensors and each tensor should
                have identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.
        """
        world_size = group.size()
        first_input = input_list[0]
        first_input_size = first_input.numel()

        bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
        if first_input_size > bucket_shard_size:
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(input_list[0])
            if hasattr(dist, "_reduce_scatter_base"):
                input_flattened = torch.cat(input_list)
                dist._reduce_scatter_base(output, input_flattened, group=group)  # type: ignore
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.data.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()

        # copy data from input_list into bucket
        stacked_input = torch.stack(input_list).view(world_size, first_input_size)
        offset = bucket.offset
        bucket.data[:, offset : offset + first_input_size].copy_(stacked_input)
        bucket.offset += first_input_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input)
            bucket.callbacks.append(functools.partial(callback_fn, result_view))

4.3.4 _register_pre_backward_hooks

Here a hook is registered that will be called before the backward pass; inside the hook, _rebuild_full_params is called, which internally performs the all-gather. Because the hook must be registered on the final outputs, _register_pre_backward_hooks is only called at the very end of the forward pass.

def _register_pre_backward_hooks(self, outputs: Any) -> Any:
    """Register pre-backward hook to run before the wrapped module's
    backward. Hooks should be attached to all outputs from the forward.

    Returns:
        outputs: new outputs with hooks registered if they requires gradient.
    """
    if not torch.is_grad_enabled():
        return outputs  # don't register hooks if grad isn't enabled

    if self._is_root:
        # This actually means that only root instance has
        # _post_backward_callback_queued defined. Accidentally accessing this field
        # will assert on all other instances, giving us a nice bug checker.
        self._post_backward_callback_queued = False

    def _pre_backward_hook(*unused: Any) -> None:
        # try to queue final backward callback only once for root, so
        # that final backward callback is attached to the outer most
        # backward graph task and called after all the backward
        # calls are completed.
        if self._is_root:
            self._queue_wait_for_post_backward()

        if self._pre_backward_hook_has_run:
            return  # only run once (from multiple outputs or multiple forward passes)
        self._pre_backward_hook_has_run = True

        # Start of a backward pass.
        self.training_state = TrainingState.BACKWARD_PRE

        # All-gather full parameters.
        if self.reshard_after_forward:
            self._rebuild_full_params() # all-gather is called here
        else:
            self._use_full_params()

        # Prepare p.grad.
        self._prep_grads_for_backward()

    def _register_hook(t: torch.Tensor) -> torch.Tensor:
        if t.requires_grad:
            t.register_hook(_pre_backward_hook)
        return t

    # Attach hooks to Tensor outputs.
    outputs = apply_to_tensors(_register_hook, outputs)

    return outputs

The execution logic is roughly as follows:

+---------------------------------------------+     +--------------------------------------------------+
| forward                                     |     | backward                                         |
|                                             |     |                                                  |
| +                                           |     |                                                  |
| |       all_gather()                        |     |                ^                               ^ |
| |           +                               |     |                |                               | |
| |           |                               |     |                |                               | |
| |           |                               |     |                |                               | |
| |           |                               +     +                |                               | |
| |           v                               register               |                               | |
| | _register_post_backward_hooks()  +--------+-----+--> _post_backward_hook() +--> reduce_scatter() | |
| |                                           |     |                                                | |
| |           +                               |     |                ^                               | |
| |           |                               |     |                |                               | |
| |           |                               |     |                |                               | |
| |           v                               |     |                +                               | |
| |    outputs = self.module(*args, **kwargs) |     |         compute gradient                       | |
| |           +                               |     |                                                | |
| |           |                               |     |                ^                               | |
| |           |                               |     |                |                               | |
| |           |                               +     +                |                               | |
| |           v                               register               +                               | |
| | _register_pre_backward_hooks(outputs) +---+-----+--> _pre_backward_hook() +---> all_gather()     | |
| v                                           |     |                                                + |
|                                             |     |                                                  |
| Timeline                                    |     |                                         Timeline |
|                                             |     |                                                  |
+---------------------------------------------+     +--------------------------------------------------+


This concludes our look at how FSDP shards model parameters to reduce GPU memory usage. In the next article we will see how Offload saves even more memory.

0xFF References

https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html

https://developer.nvidia.com/automatic-mixed-precision

https://blogs.nvidia.com/blog/2019/11/15/whats-the-difference-between-single-double-multi-and-mixed-precision-computing/

https://www.paddlepaddle.org.cn/documentation/docs/en/advanced_guide/performance_improving/amp/amp.html

https://on-demand.gputechconf.com/gtc-taiwan/2018/pdf/5-1_Internal Speaker_Michael Carilli_PDF For Sharing.pdf

http://bindog.github.io/blog/2020/04/12/model-parallel-with-apex/

Optimizer state sharding (ZeRO)
