首先注意pytorch中模型儲存有兩種格式,pth和pkl,其中,pth是pytorch預設格式,pkl還支援pickle庫,不過一般如果沒有特殊需求的時候,推薦使用預設pth格式儲存
pytorch中有兩種資料儲存方法,一種是儲存整個模型,一種只儲存引數
方法一:儲存整個模型
#儲存
torch.save(model1, 'net.pth')
#讀取
model1 = torch.load('net.pth')
方法二:儲存模型引數
#儲存
torch.save(model.state_dict(), 'checkpoint.pth')
#提取
state_dict = torch.load('checkpoint.pth')
model.load_state_dict(state_dict)
state_dict說明
state_dict 包含了模型使用的所有引數(Parameter型別),如果自定義的模型引數沒有用Parameter封裝,那麼不會出現在state_dict中, 所以使用的時候,自定義引數一定不要忘記使用Parameter進行封裝。
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.w1 = torch.randn(10,2)
self.w2 = nn.Parameter(torch.randn(2,1))
self.l1 = nn.Linear(10,1)
def forward(self,x):
pass
net = MLP()
net.state_dict()
輸出,可以發現只有w2和l1
OrderedDict([('w2',
tensor([[0.9826],
[0.4665]])),
('l1.weight',
tensor([[ 0.3098, 0.0985, -0.2566, -0.1024, 0.0449, -0.1681, -0.1743, 0.2985,
-0.0644, -0.0181]])),
('l1.bias', tensor([-0.2871]))])
中間狀態儲存
在訓練的時候,可以儲存訓練中的中間狀態,只需要把引數都儲存到state字典中就可以了。 例如,在斷點續傳任務中,可以把epoch,模型狀態,優化器狀態,初始learning rate 等進行儲存。
state = {
'state_dict': net.state_dict(),
'optimizer': optim.optimizer.state_dict(),
'lr_base': optim.lr_base
'epoch': epoch
}
torch.save(
state,
self.CKPTS_PATH +
'ckpt_' + self.VERSION +
'/epoch'+ str(epoch) +
'.pkl'
)
載入
state = torch.load(
self.CKPTS_PATH +
'ckpt_' + self.VERSION +
'/epoch'+ str(epoch) +
'.pkl'
)
net.load_state_dict(state['state_dict'])
optim.optimizer.load_state_dict(state['optimizer'])
optim.lr_base = state['lr_base']
start_epoch = state['epoch']