torch.save(),torch.load(),state_dict(),load_state_dict()

做梦当财神發表於2024-04-16

這些函式是PyTorch中用於模型儲存和載入的重要函式。下面是對它們的詳細解析:

  1. torch.save(obj, file):

    • 作用:將PyTorch模型儲存到檔案中。

    • 引數:

      • obj: 要儲存的物件,可以是模型、張量或字典。
      • file: 要儲存到的檔案路徑。
    • 示例:

      torch.save(model.state_dict(), 'model.pth')
      
  2. torch.load(file):

    • 作用:從檔案中載入儲存的PyTorch模型。

    • 引數:

      • file: 要載入的檔案路徑。
    • 返回值:載入的物件。

    • 示例:

      model.load_state_dict(torch.load('model.pth'))
      
  3. state_dict():

    • 作用:返回包含模型所有引數的字典物件。

    • 示例:

      model_state = model.state_dict()
      
  4. load_state_dict(state_dict, strict=True):

    • 作用:載入預訓練的引數字典到模型中。

    • 引數:

      • state_dict: 要載入的引數字典。
      • strict(可選): 如果為True(預設值),則要求state_dict中的鍵與模型的引數名完全匹配。
    • 示例:

      model.load_state_dict(torch.load('pretrained.pth'))
      

這些函式在訓練過程中非常有用,可以幫助儲存模型的狀態以及載入預訓練的引數,使得模型的訓練和部署更加方便。