【PyTorch Lightning】簡介

棠雪清芬發表於2020-10-24

目錄

一、Lightning 簡約哲學

1.1 研究程式碼 (Research code)

1.2 工程程式碼 (Engineering code)

1.3 非必要程式碼 (Non-essential code)

二、典型的 AI 研究專案

三、生命週期

四、使用 Lightning 的好處


PyTorch 已足夠簡單易用,但簡單易用不等於方便快捷。特別是做大量實驗時,很多東西都會變得複雜,程式碼也會變得龐大,此時就容易出錯。

針對該問題,就有了 PyTorch Lightning。它可以重構你的 PyTorch 程式碼,抽出複雜重複部分,讓你專注於核心的構建,讓你的實驗更快速更便捷地開展迭代。


一、Lightning 簡約哲學

大部分的 DL/ML 程式碼都可以分為以下這三部分:

  • 研究程式碼 Research code
  • 工程程式碼 Engineering code
  • 非必要程式碼 Non-essential code

1.1 研究程式碼 (Research code)

這部分屬於模型(神經網路)部分,一般處理模型的結構、訓練等定製化部分。

在 Linghtning 中,這部分程式碼抽象為 LightningModule 類。

1.2 工程程式碼 (Engineering code)

這部分程式碼很重要的特點是:重複性強,如設定 early stopping、16位精度、GPUs 分佈訓練等。

在 Linghtning 中,這部分抽象為 Trainer 類。

1.3 非必要程式碼 (Non-essential code)

這部分程式碼有利於實驗的進行,但和實驗沒有直接關係,甚至可以不用。如檢查梯度、向 tensorboard 輸出 log。

在 Linghtning 中,這部分抽象為 Callbacks 類。


二、典型的 AI 研究專案

在大多數研究專案中,研究程式碼 通常可以歸納到以下關鍵部分:

  • 模型
  • 訓練/驗證/測試 資料
  • 優化器
  • 訓練/驗證/測試 計算

上面已經提到,研究程式碼 在 Lightning 中,是抽象為 LightningModule 類;而該類與我們平時在 PyTorch 中使用的 torch.nn.Module 是一樣的 (在原有程式碼中直接替換 Module 而不改其他程式碼也可以)。但不同的是,Lightning 圍繞 torch.nn.Module 做了很多功能性的補充,把上面 4 個關鍵部分都囊括了進來。

如此設定的意義在於:我們的 研究程式碼 都是圍繞 神經網路模型 來執行的,所以 Lightning 把這部分程式碼都集合在一個類裡。所以接下來的介紹,都圍繞 LightningModule 類來展開。


三、生命週期

為先呈現一個總體的概念,此處先介紹 LightningModule 中執行的生命流程。

以下所有函式都在 LightningModule 類中。

這部分是訓練開始之後的執行 “一般(預設)順序”

  • 首先是準備工作,包括初始化 LightningModule,準備資料 和 配置優化器。

這部分程式碼 只執行一次

1. `__init__()`(初始化 LightningModule )
2. `prepare_data()` (準備資料,包括下載資料、預處理等等)
3. `configure_optimizers()` (配置優化器)
  • 測試 “驗證程式碼”。

提前來做的意義在於:不需要等待漫長的訓練過程才發現驗證程式碼有錯。

這部分就是提前執行 “驗證程式碼”,所以和下面的驗證部分是一樣的。

1. `val_dataloader()`
2. `validation_step()`
3. `validation_epoch_end()`
  • 開始載入dataloader,用來給訓練載入資料

1. `train_dataloader()` 
2. `val_dataloader()` (如果你定義了)
  • 下面部分就是迴圈訓練了,_step() 指按 batch 進行的部分;_epoch_step() 指所有 batch 執行完後 (一個 epoch) 要進行的部分。

# 迴圈訓練與驗證
1. `training_step()`
2. `validation_step()`
3. `validation_epoch_end()`
  • 最後訓練完了,就要進行測試,但測試部分需手動呼叫 .test(),以避免誤操作。

# 測試(需要手動呼叫)
1. `test_dataloader()` 
2. `test_step()`
3. `test_epoch_end()`

不難總結,在訓練部分,主要包含三部分:_dataloader / _step / _epoch_end。Lightning 把訓練的三部分抽象成三個函式,而使用者只需要“填鴨式”地補充這三部分,就可以完成模型訓練部分程式碼的編寫。

為更清晰地展現這三部分的具體位置,以下用 PyTorch 實現方式 來展現其位置。

for epoch in epochs:
    for batch in train_dataloader:
        # train_step
        # ....
        # train_step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    for batch in val_dataloader:
        # validation_step
        # ....
        # validation_step
    
    # *_step_end
    # ....
    # *_step_end

四、使用 Lightning 的好處

  • 只需專注於 研究程式碼

不需要寫一大堆的 .cuda() 和 .to(device),Lightning 會自動處理。如果要新建一個 tensor,可以使用 type_as 來使得新tensor 處於相同的處理器上。

def training_step(self, batch, batch_idx):
    x, y = batch

    # 把z放在和x一樣的處理器上
    z = sample_noise()
    z = z.type_as(x)

此處需要注意的是,不是所有的在 LightningModule 的 tensor 都會被自動處理,而是只有從 Dataloader 裡獲取的 tensor 才會被自動處理,所以對於 transductive learning 的訓練,最好自己寫 Dataloader 的處理函式。 

  • 工程程式碼引數化

平時寫模型訓練時,這部分程式碼會不斷重複,但又不得不做,比如 early stopping、精度調整、視訊記憶體記憶體間資料轉移等。這部分程式碼雖然不難,但減少這部分程式碼會使得 研究程式碼 更加清晰,整體也更加簡潔。

下面是簡單的展示,表示使用 LightningModule 建立好模型後,如何進行訓練。

model = LightningModuleClass()
trainer = pl.Trainer(gpus="0",  # 用來配置使用什麼GPU
                     precision=32, # 用來配置使用什麼精度,預設是32
                     max_epochs=200 # 迭代次數
                     )

trainer.fit(model)  # 開始訓練
trainer.test()  # 訓練完之後測試

參考文獻

https://zhuanlan.zhihu.com/p/120331610

https://pytorch-lightning.readthedocs.io/en/latest/

相關文章