PyTorch儲存模型斷點以及載入斷點繼續訓練

BooTurbo發表於2023-04-27

 

 

 

在訓練神經網路時,用到的資料量可能很大,訓練週期較長,如果半途中斷了訓練,下次從頭訓練就會很費時間,這時我們就想斷點續訓。

一、神經網路模型的儲存,基本兩種方式:
1. 儲存完整模型model, torch.save(model, save_path) 

2. 只儲存模型的引數, torch.save(model.state_dict(), save_path) ,多卡訓練的話,在儲存引數時,使用 model.module.state_dict( ) 。

二、儲存模型訓練的斷點checkpoint

斷點dictionary中一般儲存訓練的網路的權重引數、最佳化器的狀態、學習率變化scheduler 的狀態以及epoch 。

checkpoint = {'parameter': model.module.state_dict(),
              'optimizer': optimizer.state_dict(),
              'scheduler': scheduler.state_dict(),
              'epoch': epoch}
torch.save(checkpoint, './models/checkpoint/ckpt_{}.pth'.format(epoch+1))

三、載入斷點繼續訓練

if resume:                                                                            # True
load_ckpt = torch.load(ckpt_dir, map_location=device)                                 # 從斷點路徑載入斷點,指定載入到CPU記憶體或GPU
load_weights_dict = {k: v for k, v in load_ckpt['parameter'].items()
                                      if model.state_dict()[k].numel() == v.numel()}  # 簡單驗證
model.load_state_dict(load_weights_dict, strict=False) 

# 如果是多卡訓練,載入weights後要設定DDP模式,其後先定義一下optimizer和scheduler,之後再載入斷點中儲存的optimizer和scheduler以及設定epoch,
optimizer.load_state_dict(load_ckpt[
'optimizer']) # 載入最佳化器狀態 scheduler.load_state_dict(load_ckpt['scheduler']) # 載入scheduler狀態
start_epoch
= load_ckpt['epoch']+1 # 設定繼續訓練的epoch起點 iter_epochs = range(start_epoch, args.epochs) # arg.epochs指出訓練的總epoch數,包括斷點前的訓練次數

 

 

 

 

 

Enjoy it!

相關文章