全面解析Pytorch框架下模型儲存,載入以及凍結

ZhiboZhao發表於2021-07-01

最近在做試驗中遇到了一些深度網路模型載入以及儲存的問題,因此整理了一份比較全面的在 PyTorch 框架下有關模型的問題。首先我們們先定義一個網路來進行後續的分析:

1、本文通用的網路模型

import torch
import torch.nn as nn
'''
定義網路中第一個網路模組 Net1
'''
class Net1(nn.Module):
    def __init__(self):
        super().__init__()
        
        # input size [B, 1, 3, 3] ==> [B, 1, 3, 3]
        self.n = nn.Conv2d(1, 2, 3, padding=1)
    def forward(self, x):
        x = self.n(x)
        return x
'''
定義網路中第二個網路模組 Net2
'''
class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.n = nn.Sequential(
            # input size [B, 1, 3, 3] ==> [B, 2, 3, 3]
            nn.Conv2d(2, 2, 3, padding=1),
            
            # input size [B, 2, 3, 3] ==> [B, 1, 1, 1]
            nn.Conv2d(2, 1, 3, padding=0),
            )
    def forward(self, x):
        x = self.n(x)
        return x
'''
定義網路中主網路模組 Network
'''
class Network(nn.Module):
    def __init__(self):
        super().__init__()     
        self.head = Net1()
        self.tail = Net2()   
    def forward(self, x):
        x = self.head(x)
        x = self.tail(x)
        return x

網路模組已經搭建好,我們先例項化一個模型然後列印看一下網路結構是否正確:

model = Network()	# 例項化網路模型
print(model)	# 輸出網路結構
Input = torch.randn(1,1,3,3)	# 自定義資料輸入
Output = model(Input)	# 計算網路輸出
print("Input 的維度為:{},Output 的維度為:{}".format(Input.shape, Output.shape))

則輸出結果為:

Network(
  (head): Net1(
    (n): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (tail): Net2(
    (n): Sequential(
      (0): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1))
    )
  )
)
Input 的維度為:torch.Size([1, 1, 3, 3]),Output 的維度為:torch.Size([1, 1, 1, 1])

從輸出結果看,網路包含兩個子模組 headtail,這兩個子模組分別是類 Net1Net2 的例項化物件。在 Net2 的定義中,使用了 nn.Sequential() 函式,它能夠將包含在裡面的網路按照輸入順序進行組合,封裝成一個新的模組,適用於網路中大量重複的結構,比如 Conv-ReLU-Conv 等模組。

2、對模型進行訓練得到權重

我們先對網路做一個簡單的訓練,訓練程式碼如下:

model = Network()	# 例項化網路模型
print(model) # 輸出網路結構

torch.manual_seed(0) # 固定隨機種子,確保每次產生的隨機輸入一致,方便我們評估訓練結果
Input = torch.randn(1,1,3,3) # 自定義資料輸入

Iter_num = 10	# 定義最大的迭代次數
Label = torch.tensor(1.0) # 定義有監督訓練的label,這裡的label必須是float型別的Tensor,否則會出錯
criterion = nn.MSELoss()	# 定義損失函式,這裡選用MSE

import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr = 0.01)	#定義優化器,這裡採用隨機梯度下降(SGD)

for index in range(Iter_num):
    Output = model(Input)	# 計算網路輸出
    loss = criterion(Output, Label) # 計算loss
    loss.backward()	# 反向傳播計算梯度
    optimizer.step()	# 梯度更新
    print("Iter:{}/{}\tloss:{}\tOutput:{}".format(index, Iter_num, loss.data, Output.data))

訓練過程如下:

Iter:0/10	loss:1.4089158773422241	Output:tensor([[[[-0.1870]]]])
Iter:1/10	loss:1.3796569108963013	Output:tensor([[[[-0.1746]]]])
Iter:2/10	loss:1.323099136352539	Output:tensor([[[[-0.1503]]]])
Iter:3/10	loss:1.2428957223892212	Output:tensor([[[[-0.1149]]]])
Iter:4/10	loss:1.143916130065918	Output:tensor([[[[-0.0695]]]])
Iter:5/10	loss:1.0316702127456665	Output:tensor([[[[-0.0157]]]])
Iter:6/10	loss:0.9117376208305359	Output:tensor([[[[0.0452]]]])
Iter:7/10	loss:0.7892979979515076	Output:tensor([[[[0.1116]]]])
Iter:8/10	loss:0.6688111424446106	Output:tensor([[[[0.1822]]]])
Iter:9/10	loss:0.5538586378097534	Output:tensor([[[[0.2558]]]])

3、模型儲存

3.1 模型引數一起儲存與載入
'''
這種方式儲存模型的引數,而非整個模型
'''
torch.save(model.state_dict(), model_path)	# 儲存網路模型的引數
checkpoint = torch.load(model_path)	# 先載入模型的引數
model.load_state_dict(checkpoint)	# 再將載入的引數填入例項化的網路模型中
'''
這種方式儲存整個模型
'''
torch.save(model,model_path)	# 直接儲存整個模型,包括模型結構和引數
model = torch.load(model_path)	# 不用例項化,直接載入就可以用

儲存整個模型與儲存模型引數的區別:

  1. 整個模型:是儲存整個網路結構和引數,使用時會載入結構和其中的引數,即邊搭框架邊填充引數;
  2. 僅引數:僅儲存網路模型中的引數,在使用時需要先用訓練時的模型例項化,再往裡面填入引數,即需要先搭好框架再往框架裡填引數。

下面我們就分別通過這兩種方式進行模型儲存與載入:

model_path_dict = './ckpt_dict.pth'	# 模型引數的儲存路徑
torch.save(model.state_dict(), model_path_dict)

model_path_model = './ckpt_model.pth' # 整個模型的儲存路徑
torch.save(model, model_path_model)

model_test = Network()	# 重新例項化一個網路物件
test_out = model_test(Input)	# 先看一下初始化輸出
print("test_out: ", test_out.data)
 
checkpoint = torch.load(model_path_dict)	# 採用載入引數的方式載入與訓練模型
model_test.load_state_dict(checkpoint)
print("test_out1: ", model_test(Input).data)	# 檢視預訓練模型載入後的輸出

model_test2 = torch.load(model_path_model)	# 直接載入整個模型
print("test_out1: ", model_test2(Input).data)	# 檢視預訓練模型載入後的輸出

對應的輸出結果如下:

test_out:   tensor([[[[0.1190]]]])  # 網路剛開始的輸出結果
test_out1:  tensor([[[[0.2558]]]])	# 載入引數後的網路輸出
test_out2:  tensor([[[[0.2558]]]])  # 載入整個模型後的網路輸出

從結果中可以看出,這兩種方式載入網路模型的效果是一樣的,但是隻儲存引數的模型所佔空間為 2731位元組,整個模型所佔的空間為4071位元組,所以一般建議採取第一種方法。

3.2 模型引數分開儲存
model_path_dict2 = './ckpt_dict2.pth'	# 模型的儲存路徑
torch.save({
    'net1':model.head.state_dict(),
    'net2':model.tail.state_dict(),
     }, model_path_dict2)	# 將模型的head和tail模組分開儲存
model3 = Network()	# 例項化一個新的網路
print("test_out: ", model3(Input).data)	# 測試一下原始輸出

checkpoint = torch.load(model_path_dict2)
model3.head.load_state_dict(checkpoint['net1'])	# 給不同的模組分別載入不同的模型
model3.tail.load_state_dict(checkpoint['net2'])	
print("test_out: ", model3(Input).data)	#測試一下最後的輸出
test_out:  tensor([[[[-0.1870]]]])
test_out:  tensor([[[[0.2558]]]])

4、載入模型的部分引數

很多時候我們在訓練過程中或多或少都會遇到如下問題:

  1. 已經有了與網路匹配的預訓練模型,根據情況需要在網路中新增一個小模組,但是還想利用之前的與訓練模型
  2. 雖然用的是同一個網路結構,但是由於定義的方法不一樣,導致與訓練模型的 key 對應不上

在這些情況下,上述載入模型的方式不能很好地解決這些問題,因此在載入模型時需要更精細的控制才能滿足我們的要求。首先我們要先了解一下網路載入模型的實質,其實網路和模型都是按照字典的格式進行儲存的,如下所示:

net_dic = model.state_dict()	# 載入網路的字典
for key, value in net_dic.items():	# 顯示網路的 key value 值
    print(key)
    print(value)
for key, value in checkpoint.items():	# 顯示模型的 key value 值
    print(key)
    print(value)

輸出結果如下:

"""
這是網路的key-value
"""
head.n.weight
tensor([[[[-0.2744,  0.2048, -0.0635],
          [-0.1417,  0.2827, -0.2909],
          [ 0.0396, -0.0686,  0.2342]]],
          ...])
head.n.bias
tensor([-0.2389,  0.0188])
tail.n.0.weight
tensor([[[[-0.1658, -0.1408, -0.1394],
          [ 0.1010, -0.1735, -0.0215],
          [ 0.0153,  0.1298, -0.2054]]
          ...]])
tail.n.0.bias
tensor([0.0328, 0.1939])
tail.n.1.weight
tensor([[[[ 0.0598,  0.2197,  0.1340],
          [-0.1290,  0.1500, -0.1595],
          [-0.1066,  0.0536,  0.1065]],
          ...]])
tail.n.1.bias
tensor([0.0029])
"""
這是與訓練模型的key-value
"""
head.n.weight
tensor([[[[-0.2744,  0.2048, -0.0635],
          [-0.1417,  0.2827, -0.2909],
          [ 0.0396, -0.0686,  0.2342]]],
       ...])
head.n.bias
tensor([-0.2389,  0.0188])
tail.n.0.weight
tensor([[[[-0.1658, -0.1408, -0.1394],
          [ 0.1010, -0.1735, -0.0215],
          [ 0.0153,  0.1298, -0.2054]],
        ...]])
tail.n.0.bias
tensor([0.0328, 0.1939])
tail.n.1.weight
tensor([[[[ 0.0598,  0.2197,  0.1340],
          [-0.1290,  0.1500, -0.1595],
          [-0.1066,  0.0536,  0.1065]],
			...]])
tail.n.1.bias
tensor([0.0029])

因此模型載入的實質可以總結為:找到網路與模型相同的key,將模型對應的引數填入到網路中去。因此若要解決上述問題,只需要在載入模型引數時,進行 if-else 判斷進行選擇特定的網路層或者篩選特定的模型引數。所以 3.1節中載入模型引數可以寫成:

checkpoint = torch.load(model_path_dict)	# 採用載入引數的方式載入與訓練模型
model_stic = model.state_dict()	# 提取網路的字典
state_dic = {k:v for k,v in checkpoint.items() if k in model_stic.keys()}	# 找出待載入模型中與網路key一樣的引數
model_stic.update(state_dic) # 更新網路引數
print("test_out1: ", model_test(Input).data)	# 檢視預訓練模型載入後的輸出

5、凍結模型的部分引數

在訓練網路的時候,有的時候不一定需要網路的每個結構都按照同一個學習率更新,或者有的模組乾脆不更新,因此這就需要凍結部分模型引數的梯度,但是又不能截斷反向傳播的梯度流,不然就會導致網路無法正常訓練。

5.1 方法一:requires_grad = false
for name, para in model.named_parameters():
    if 'tail' in name:
        para.requires_grad = False	# 將 tail 模組的梯度更新關閉,即凍結tail的引數
  
for para in model.parameters():	# 在訓練前輸出一下網路引數,與訓練後進行對比
    print(para)
    
for index in range(Iter_num):
    Output = model(Input)
    loss = criterion(Output, Label)
    loss.backward()
    optimizer.step()
    print("Iter:{}/{}\tloss:{}\tOutput:{}".format(index, Iter_num, loss.data, Output.data))
    
for para in model.parameters():	# 輸出訓練後的模型引數
    print(para)

訓練前的網路的部分引數:

Parameter containing:
tensor([[[[ 0.1211,  0.2768, -0.0686],
          [ 0.2494, -0.0537,  0.0353],
          [ 0.3018, -0.3092, -0.2098]]],
					...], requires_grad=True)
Parameter containing:
tensor([0.1487, 0.1616], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208,  0.0399],
          [-0.2201, -0.1703, -0.1215],
          [ 0.1487,  0.1382, -0.1045]],
				...]])
Parameter containing:
tensor([ 0.0469, -0.2050])
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
          [ 0.2094,  0.1792, -0.2351],
          [ 0.0441, -0.0397, -0.0388]],
				...]])
Parameter containing:
tensor([0.1177])

訓練後網路的引數:

Parameter containing:
tensor([[[[ 0.1256,  0.2754, -0.0720],
          [ 0.2429, -0.0717,  0.0461],
          [ 0.2887, -0.3248, -0.2124]]],
					...], requires_grad=True)
Parameter containing:
tensor([0.1525, 0.1894], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208,  0.0399],
          [-0.2201, -0.1703, -0.1215],
          [ 0.1487,  0.1382, -0.1045]],
         ...]])
Parameter containing:
tensor([ 0.0469, -0.2050])
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
          [ 0.2094,  0.1792, -0.2351],
          [ 0.0441, -0.0397, -0.0388]],
					...]])
Parameter containing:
tensor([0.1177])

通過對比可以發現,網路只更新了 head 層的引數,被凍結的 tail 層引數並沒有更新。

5.2 從優化器中設定更新的網路層
import torch.optim as optim
optimizer = optim.SGD(model.head.parameters(), lr = 0.001)	# 在優化器中只填入head層的引數
for para in model.parameters():	# 在訓練前輸出一下網路引數,與訓練後進行對比
    print(para)
    
for index in range(Iter_num):
    Output = model(Input)
    loss = criterion(Output, Label)
    loss.backward()
    optimizer.step()
    print("Iter:{}/{}\tloss:{}\tOutput:{}".format(index, Iter_num, loss.data, Output.data))
    
for para in model.parameters():	# 輸出訓練後的模型引數
    print(para)

訓練前的網路的部分引數:

Parameter containing:
tensor([[[[ 0.1211,  0.2768, -0.0686],
          [ 0.2494, -0.0537,  0.0353],
          [ 0.3018, -0.3092, -0.2098]]],
					...], requires_grad=True)
Parameter containing:
tensor([0.1487, 0.1616], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208,  0.0399],
          [-0.2201, -0.1703, -0.1215],
          [ 0.1487,  0.1382, -0.1045]],
					...]], requires_grad=True)
Parameter containing:
tensor([ 0.0469, -0.2050], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
          [ 0.2094,  0.1792, -0.2351],
          [ 0.0441, -0.0397, -0.0388]],
					...]], requires_grad=True)
Parameter containing:
tensor([0.1177], requires_grad=True)

訓練後的網路的部分引數:

Parameter containing:
tensor([[[[ 0.1256,  0.2754, -0.0720],
          [ 0.2429, -0.0717,  0.0461],
          [ 0.2887, -0.3248, -0.2124]]],
					...], requires_grad=True)
Parameter containing:
tensor([0.1525, 0.1894], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208,  0.0399],
          [-0.2201, -0.1703, -0.1215],
          [ 0.1487,  0.1382, -0.1045]],
					...]], requires_grad=True)
Parameter containing:
tensor([ 0.0469, -0.2050], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
          [ 0.2094,  0.1792, -0.2351],
          [ 0.0441, -0.0397, -0.0388]],
					...]], requires_grad=True)
Parameter containing:
tensor([0.1177], requires_grad=True)

對比這兩種方法都能夠實現網路某一層引數的凍結而不影響其它層的梯度更新,但是仔細觀察發現方法一中不更新引數的網路層的 requires_grad = False,而方法二中所有層的 requires_grad = True。由於個人知識水平有限,難免有錯誤的地方,還請不吝指正,相互學習,共同進步。

相關文章