Pytorch系列:(四)IO操作

Q發表於2021-05-06

首先注意pytorch中模型儲存有兩種格式,pth和pkl,其中,pth是pytorch預設格式,pkl還支援pickle庫,不過一般如果沒有特殊需求的時候,推薦使用預設pth格式儲存

pytorch中有兩種資料儲存方法,一種是儲存整個模型,一種只儲存引數

方法一:儲存整個模型

#儲存

torch.save(model1, 'net.pth')

#讀取

model1 = torch.load('net.pth')

方法二:儲存模型引數

#儲存

torch.save(model.state_dict(), 'checkpoint.pth')

#提取

state_dict = torch.load('checkpoint.pth')

model.load_state_dict(state_dict)

state_dict說明

state_dict 包含了模型使用的所有引數(Parameter型別),如果自定義的模型引數沒有用Parameter封裝,那麼不會出現在state_dict中, 所以使用的時候,自定義引數一定不要忘記使用Parameter進行封裝。

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.w1 = torch.randn(10,2)
        self.w2 = nn.Parameter(torch.randn(2,1))
        self.l1 = nn.Linear(10,1)

    def forward(self,x):
        pass 


net = MLP()

net.state_dict()

輸出,可以發現只有w2和l1

OrderedDict([('w2',
              tensor([[0.9826],
                      [0.4665]])),
             ('l1.weight',
              tensor([[ 0.3098,  0.0985, -0.2566, -0.1024,  0.0449, -0.1681, -0.1743,  0.2985,
                       -0.0644, -0.0181]])),
             ('l1.bias', tensor([-0.2871]))])

中間狀態儲存

在訓練的時候,可以儲存訓練中的中間狀態,只需要把引數都儲存到state字典中就可以了。 例如,在斷點續傳任務中,可以把epoch,模型狀態,優化器狀態,初始learning rate 等進行儲存。

state = {
          'state_dict': net.state_dict(),
          'optimizer': optim.optimizer.state_dict(),
          'lr_base': optim.lr_base
          'epoch': epoch
        }
            
torch.save(
            state,
            self.CKPTS_PATH +
            'ckpt_' + self.VERSION +
            '/epoch'+ str(epoch) +
            '.pkl'
          )

載入

state = torch.load(
                    self.CKPTS_PATH +
                    'ckpt_' + self.VERSION +
                    '/epoch'+ str(epoch) +
                    '.pkl'
                   )  

 
net.load_state_dict(state['state_dict'])

optim.optimizer.load_state_dict(state['optimizer'])
optim.lr_base = state['lr_base']
start_epoch = state['epoch']

相關文章