在訓練神經網路時,用到的資料量可能很大,訓練週期較長,如果半途中斷了訓練,下次從頭訓練就會很費時間,這時我們就想斷點續訓。
一、神經網路模型的儲存,基本兩種方式:
1. 儲存完整模型model, torch.save(model, save_path)
2. 只儲存模型的引數, torch.save(model.state_dict(), save_path) ,多卡訓練的話,在儲存引數時,使用 model.module.state_dict( ) 。
二、儲存模型訓練的斷點checkpoint
斷點dictionary中一般儲存訓練的網路的權重引數、最佳化器的狀態、學習率變化scheduler 的狀態以及epoch 。
checkpoint = {'parameter': model.module.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'epoch': epoch}
torch.save(checkpoint, './models/checkpoint/ckpt_{}.pth'.format(epoch+1))
三、載入斷點繼續訓練
if resume: # True
load_ckpt = torch.load(ckpt_dir, map_location=device) # 從斷點路徑載入斷點,指定載入到CPU記憶體或GPU
load_weights_dict = {k: v for k, v in load_ckpt['parameter'].items()
if model.state_dict()[k].numel() == v.numel()} # 簡單驗證
model.load_state_dict(load_weights_dict, strict=False)
# 如果是多卡訓練,載入weights後要設定DDP模式,其後先定義一下optimizer和scheduler,之後再載入斷點中儲存的optimizer和scheduler以及設定epoch,
optimizer.load_state_dict(load_ckpt['optimizer']) # 載入最佳化器狀態
scheduler.load_state_dict(load_ckpt['scheduler']) # 載入scheduler狀態
start_epoch = load_ckpt['epoch']+1 # 設定繼續訓練的epoch起點
iter_epochs = range(start_epoch, args.epochs) # arg.epochs指出訓練的總epoch數,包括斷點前的訓練次數
Enjoy it!