How to Restore Specific Weights in PyTorch

Posted by weixin_34292287 on 2018-08-28

1. How to extract the weights of specific layers from a trained model

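import torch
# vgg is the official model definition, copied locally from:
# https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
# It must be importable so that torch.load can unpickle the saved model.
import vgg

# logs/vgg16.pkl holds a whole pickled model, not just a state_dict.
# (On recent PyTorch releases, loading a full model may require weights_only=False.)
model = torch.load('logs/vgg16.pkl')

restore_param = ['classifier.2.bias']
# To skip these layers instead of extracting them, change the condition to `if k not in restore_param`
selected = {v for k, v in model.state_dict().items() if k in restore_param}
print(selected)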


------>:
{tensor([-0.0048,  0.0048], device='cuda:0')}
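Because `logs/vgg16.pkl` is not available outside the author's setup, here is a minimal self-contained sketch of the same filtering. It assumes a small hypothetical `TinyNet` whose parameter names follow the `classifier.<index>.<weight|bias>` pattern, plus a hypothetical helper `select_params`. Unlike the set above, it returns a dict so the layer names are kept, and it also covers the inverse filter mentioned in the code comment.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the trained VGG16 (not the real architecture)."""

    def __init__(self):
        super().__init__()
        # Index 1 is a ReLU with no parameters, so the keys are classifier.0.* and classifier.2.*
        self.classifier = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))


def select_params(state_dict, names, exclude=False):
    """Hypothetical helper: keep the entries named in `names`, or all others if exclude=True."""
    if exclude:
        return {k: v for k, v in state_dict.items() if k not in names}
    return {k: v for k, v in state_dict.items() if k in names}


model = TinyNet()
wanted = {"classifier.2.bias"}

kept = select_params(model.state_dict(), wanted)
skipped = select_params(model.state_dict(), wanted, exclude=True)

print(list(kept))     # ['classifier.2.bias']
print(list(skipped))  # ['classifier.0.weight', 'classifier.0.bias', 'classifier.2.weight']
```

The tensors returned by `state_dict()` share storage with the model's parameters, so call `.clone()` on them if you need a copy that stays fixed while the model keeps training.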

2. How to load part of a model's parameters and update them

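import torch
# vgg: local copy of the official torchvision model definition
import vgg

# logs/vgg16.pkl holds a whole pickled model with the trained weights
model = torch.load('logs/vgg16.pkl')
# Freshly initialized network that will receive some of those weights
vgg16 = vgg.vgg16().cuda()
vgg16_dict = vgg16.state_dict()
for k, v in vgg16_dict.items():
    print(v)

print()
print('##################################################################################')
print()

restore = ['classifier.2.bias']
# Keep only the trained entries named in `restore`
restore_param = {k: v for k, v in model.state_dict().items() if k in restore}
vgg16_dict.update(restore_param)
# update() only changes the dict; load_state_dict copies the values into vgg16
vgg16.load_state_dict(vgg16_dict)
for k, v in vgg16.state_dict().items():
    print(v)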


------>:
tensor([[[[-0.0198,  0.0425, -0.0221],
          [ 0.0636,  0.0193, -0.0661],
          [-0.0035,  0.0031, -0.0395]],

         [[-0.0525,  0.0796,  0.0263],
          [-0.0669,  0.1537,  0.1025],
          [ 0.0002, -0.0456, -0.0086]],

         [[-0.0344,  0.0566, -0.0090],
          [ 0.0915,  0.0133, -0.0007],
          [-0.0228, -0.0143,  0.0841]]],
...
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
tensor([[ 2.7670e-03, -1.6860e-02, -6.6972e-03,  ...,  6.7144e-03,
         -7.2912e-03,  2.0684e-03],
        [ 4.2978e-03, -9.8524e-03,  1.2163e-02,  ...,  6.3420e-03,
         -5.1077e-03,  6.4550e-03]], device='cuda:0')
tensor([0., 0.], device='cuda:0')

##################################################################################

tensor([[[[-0.0198,  0.0425, -0.0221],
          [ 0.0636,  0.0193, -0.0661],
          [-0.0035,  0.0031, -0.0395]],

         [[-0.0525,  0.0796,  0.0263],
          [-0.0669,  0.1537,  0.1025],
          [ 0.0002, -0.0456, -0.0086]],

         [[-0.0344,  0.0566, -0.0090],
          [ 0.0915,  0.0133, -0.0007],
          [-0.0228, -0.0143,  0.0841]]],
...
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
tensor([[ 2.7670e-03, -1.6860e-02, -6.6972e-03,  ...,  6.7144e-03,
         -7.2912e-03,  2.0684e-03],
        [ 4.2978e-03, -9.8524e-03,  1.2163e-02,  ...,  6.3420e-03,
         -5.1077e-03,  6.4550e-03]], device='cuda:0')
tensor([-0.0048,  0.0048], device='cuda:0')

As the output shows, the value of classifier.2.bias changed from [0., 0.] to [-0.0048, 0.0048], while the other tensors shown are unchanged.
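Note that calling `dict.update` on the result of `state_dict()` only replaces entries in that dictionary; the network's parameters change only when the dictionary is passed to `load_state_dict`. The following is a minimal self-contained sketch of partial loading, assuming a small hypothetical `TinyNet` in place of VGG16 and a hypothetical helper `load_partial`. It additionally skips keys whose shapes differ (an assumption that helps when the new model's last layer has a different size) and shows the `strict=False` option of `load_state_dict` as an alternative.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for VGG16 (not the real architecture)."""

    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))


def load_partial(model, source_state, names):
    """Hypothetical helper: copy only the tensors named in `names` into `model`."""
    own_state = model.state_dict()
    # Skip names that are missing from the target or whose shapes differ
    picked = {
        k: v
        for k, v in source_state.items()
        if k in names and k in own_state and v.shape == own_state[k].shape
    }
    own_state.update(picked)          # changes only the dictionary
    model.load_state_dict(own_state)  # copies the values into the parameters
    return sorted(picked)


# "Trained" model: give classifier.2.bias the values seen in the output above
trained = TinyNet()
with torch.no_grad():
    trained.classifier[2].bias.copy_(torch.tensor([-0.0048, 0.0048]))

# Fresh model whose classifier.2.bias starts at zero
fresh = TinyNet()
with torch.no_grad():
    fresh.classifier[2].bias.zero_()

loaded = load_partial(fresh, trained.state_dict(), {"classifier.2.bias"})
print(loaded)  # ['classifier.2.bias']

# Alternative: pass only the subset and let strict=False ignore the missing keys
other = TinyNet()
subset = {"classifier.2.bias": trained.state_dict()["classifier.2.bias"]}
result = other.load_state_dict(subset, strict=False)
print(sorted(result.missing_keys))  # ['classifier.0.bias', 'classifier.0.weight', 'classifier.2.weight']
```

`strict=False` is shorter, but it silently accepts every missing key, so it is worth inspecting `missing_keys` and `unexpected_keys` on the returned object to confirm that only the intended layers were loaded.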
