【PyTorch Lightning】簡介
目錄
1.3 非必要程式碼 (Non-essential code)
PyTorch 已足夠簡單易用,但簡單易用不等於方便快捷。特別是做大量實驗時,很多東西都會變得複雜,程式碼也會變得龐大,此時就容易出錯。
針對該問題,就有了 PyTorch Lightning。它可以重構你的 PyTorch 程式碼,抽出複雜重複部分,讓你專注於核心的構建,讓你的實驗更快速更便捷地開展迭代。
一、Lightning 簡約哲學
大部分的 DL/ML 程式碼都可以分為以下這三部分:
- 研究程式碼 Research code
- 工程程式碼 Engineering code
- 非必要程式碼 Non-essential code
1.1 研究程式碼 (Research code)
這部分屬於模型(神經網路)部分,一般處理模型的結構、訓練等定製化部分。
在 Linghtning 中,這部分程式碼抽象為 LightningModule 類。
1.2 工程程式碼 (Engineering code)
這部分程式碼很重要的特點是:重複性強,如設定 early stopping、16位精度、GPUs 分佈訓練等。
在 Linghtning 中,這部分抽象為 Trainer 類。
1.3 非必要程式碼 (Non-essential code)
這部分程式碼有利於實驗的進行,但和實驗沒有直接關係,甚至可以不用。如檢查梯度、向 tensorboard 輸出 log。
在 Linghtning 中,這部分抽象為 Callbacks 類。
二、典型的 AI 研究專案
在大多數研究專案中,研究程式碼 通常可以歸納到以下關鍵部分:
- 模型
- 訓練/驗證/測試 資料
- 優化器
- 訓練/驗證/測試 計算
上面已經提到,研究程式碼 在 Lightning 中,是抽象為 LightningModule 類;而該類與我們平時在 PyTorch 中使用的 torch.nn.Module 是一樣的 (在原有程式碼中直接替換 Module 而不改其他程式碼也可以)。但不同的是,Lightning 圍繞 torch.nn.Module 做了很多功能性的補充,把上面 4 個關鍵部分都囊括了進來。
如此設定的意義在於:我們的 研究程式碼 都是圍繞 神經網路模型 來執行的,所以 Lightning 把這部分程式碼都集合在一個類裡。所以接下來的介紹,都圍繞 LightningModule 類來展開。
三、生命週期
為先呈現一個總體的概念,此處先介紹 LightningModule 中執行的生命流程。
以下所有函式都在 LightningModule 類中。
這部分是訓練開始之後的執行 “一般(預設)順序”。
-
首先是準備工作,包括初始化 LightningModule,準備資料 和 配置優化器。
這部分程式碼 只執行一次。
1. `__init__()`(初始化 LightningModule )
2. `prepare_data()` (準備資料,包括下載資料、預處理等等)
3. `configure_optimizers()` (配置優化器)
-
測試 “驗證程式碼”。
提前來做的意義在於:不需要等待漫長的訓練過程才發現驗證程式碼有錯。
這部分就是提前執行 “驗證程式碼”,所以和下面的驗證部分是一樣的。
1. `val_dataloader()`
2. `validation_step()`
3. `validation_epoch_end()`
-
開始載入dataloader,用來給訓練載入資料
1. `train_dataloader()`
2. `val_dataloader()` (如果你定義了)
-
下面部分就是迴圈訓練了,_step() 指按 batch 進行的部分;_epoch_step() 指所有 batch 執行完後 (一個 epoch) 要進行的部分。
# 迴圈訓練與驗證
1. `training_step()`
2. `validation_step()`
3. `validation_epoch_end()`
-
最後訓練完了,就要進行測試,但測試部分需手動呼叫 .test(),以避免誤操作。
# 測試(需要手動呼叫)
1. `test_dataloader()`
2. `test_step()`
3. `test_epoch_end()`
不難總結,在訓練部分,主要包含三部分:_dataloader / _step / _epoch_end。Lightning 把訓練的三部分抽象成三個函式,而使用者只需要“填鴨式”地補充這三部分,就可以完成模型訓練部分程式碼的編寫。
為更清晰地展現這三部分的具體位置,以下用 PyTorch 實現方式 來展現其位置。
for epoch in epochs:
for batch in train_dataloader:
# train_step
# ....
# train_step
loss.backward()
optimizer.step()
optimizer.zero_grad()
for batch in val_dataloader:
# validation_step
# ....
# validation_step
# *_step_end
# ....
# *_step_end
四、使用 Lightning 的好處
- 只需專注於 研究程式碼
不需要寫一大堆的 .cuda() 和 .to(device),Lightning 會自動處理。如果要新建一個 tensor,可以使用 type_as 來使得新tensor 處於相同的處理器上。
def training_step(self, batch, batch_idx):
x, y = batch
# 把z放在和x一樣的處理器上
z = sample_noise()
z = z.type_as(x)
此處需要注意的是,不是所有的在 LightningModule 的 tensor 都會被自動處理,而是只有從 Dataloader 裡獲取的 tensor 才會被自動處理,所以對於 transductive learning 的訓練,最好自己寫 Dataloader 的處理函式。
-
工程程式碼引數化
平時寫模型訓練時,這部分程式碼會不斷重複,但又不得不做,比如 early stopping、精度調整、視訊記憶體記憶體間資料轉移等。這部分程式碼雖然不難,但減少這部分程式碼會使得 研究程式碼 更加清晰,整體也更加簡潔。
下面是簡單的展示,表示使用 LightningModule 建立好模型後,如何進行訓練。
model = LightningModuleClass()
trainer = pl.Trainer(gpus="0", # 用來配置使用什麼GPU
precision=32, # 用來配置使用什麼精度,預設是32
max_epochs=200 # 迭代次數
)
trainer.fit(model) # 開始訓練
trainer.test() # 訓練完之後測試
參考文獻
相關文章
- pytorch簡介PyTorch
- 簡單介紹pytorch中log_softmax的實現PyTorch
- 簡單介紹PyTorch中in-place operation的含義PyTorch
- pytorch 包介紹PyTorch
- PyTorch 介紹 | DATSETS & DATALOADERSPyTorch
- 簡單介紹Pytorch實現WGAN用於動漫頭像生成PyTorch
- Pytext 簡介——Facebook 基於 PyTorch 的自然語言處理 (NLP) 框架PyTorch自然語言處理框架
- pytorch-torch.nn介紹PyTorch
- TiDB EcoSystem Tools 原理解讀系列(二)TiDB-Lightning Toolset 介紹TiDB
- PyTorch 介紹 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRADPyTorch
- 想讀讀PyTorch底層程式碼?這份核心機制簡介送給你PyTorch
- 簡單使用PyTorch搭建GAN模型PyTorch模型
- [PyTorch 學習筆記] 5.1 TensorBoard 介紹PyTorch筆記ORB
- 簡介
- Jira使用簡介 HP ALM使用簡介
- USB-A, Micro, lightning and USB-C
- 最快影片轉繪-AnimateDiff-Lightning
- Pytorch 實現簡單線性迴歸PyTorch
- BookKeeper 介紹(1)--簡介
- ggml 簡介
- PCIe簡介
- valgrind簡介
- SpringMVC簡介SpringMVC
- HTML 簡介HTML
- 核心簡介
- DPDK簡介
- Docker簡介Docker
- SpotBugs 簡介
- webservice簡介Web
- OME 簡介
- Spring 簡介Spring
- 【QCustomPlot】簡介
- DuckDB簡介
- SDL簡介
- swagger簡介Swagger
- MongoDb簡介MongoDB
- RabbitMQ簡介MQ
- JetCache 簡介