Pytorch:利用torch.nn.Modules.parameters修改模型引數

orion發表於2022-05-20

1. 關於parameters()方法

Pytorch中繼承了torch.nn.Module的模型類具有named_parameters()/parameters()方法,這兩個方法都會返回一個用於迭代模型引數的迭代器(named_parameters還包括引數名字):

import torch

net = torch.nn.LSTM(input_size=512, hidden_size=64)
print(net.parameters())
print(net.named_parameters())
# <generator object Module.parameters at 0x12a4e9890>
# <generator object Module.named_parameters at 0x12a4e9890>

我們可以將net.parameters()迭代器和將net.named_parameters()轉化為列表型別,前者列表元素是模型引數,後者是包含引數名和模型引數的元組。

當然,我們更多的是對迭代器直接進行迭代:

for param in net.parameters():
    print(param.shape)
# torch.Size([256, 512])
# torch.Size([256, 64])
# torch.Size([256])
# torch.Size([256])
for name, param in net.named_parameters():
    print(name, param.shape)
# weight_ih_l0 torch.Size([256, 512])
# weight_hh_l0 torch.Size([256, 64])
# bias_ih_l0 torch.Size([256])
# bias_hh_l0 torch.Size([256])

我們知道,Pytorch在進行優化時需要給優化器傳入這個引數迭代器,如:

from torch.optim import RMSprop
optimizer = RMSprop(net.parameters(), lr=0.01)

2. 關於引數修改

那麼底層具體是怎麼對引數進行修改的呢?

我們在部落格《Python物件模型與序列迭代陷阱》中介紹過,Python序列中本身存放的就是物件的引用,而迭代器返回的是序列中的物件的二次引用,如果序列的引用指向基礎資料型別,則是不可以通過遍歷序列進行修改的,如:

my_list = [1, 2, 3, 4]
for x in my_list:
    x += 1
print(my_list) #[1, 2, 3, 4]

而序列中的引用指向複合資料型別,則可以通過遍歷序列來完成修改操作,如:

my_list = [[1, 2],[3, 4]]
for sub_list in my_list:
    sub_list[0] += 1
print(my_list)
# [1, 2, 3, 4]
# [[2, 2], [4, 4]]

具體原理可參照該篇部落格,此處我就不在贅述。這裡想提到的是,用net.parameters()/net.named_parameters()來迭代並修改引數,本質上就是上述第二種對複合資料型別序列的修改。我們可以如下寫:

for param in net.parameters():
    with torch.no_grad():
        param += 1

with torch.no_grad():表示將將所要修改的張量關閉梯度計算。所增加的1會廣播到param張量的中的每一個元素上。上述操作本質上為:

for param in net.parameters():
    with torch.no_grad():
        param += torch.ones(param.shape)

但是需要注意,如果我們想讓引數全部置為0,切不可像下列這樣寫:

for param in net.parameters():
    with torch.no_grad():
        param = torch.zeros(param.shape) 

param是二次引用,param=0操作再語義上會被解釋為讓param這個二次引用去指向新的全0張量物件,但是對引數張量本身並不會產生任何變動。該操作實際上類似下列這種操作:

list_1 = [1, 2]
list_2 = list_1
list_2 = [0, 0]
print(list_1) # [1, 2]

修改二次引用list_2自然不會影響到list_1引用的物件。

下面讓我們糾正這種錯誤,採用下列方法直接來將引數張量中的所有數值置0:

for param in net.parameters():
    with torch.no_grad():
        param[:] = 0 #張量型別自帶廣播操作,等效於param[:] = torch.zeros(param.shape) 

這時語義上就類似

list_1 = [1, 2]
list_2 = list_1
list_2[:] = [0, 0]
print(list_1) # [0, 0]

自然就能完成修改的操作了。

參考

相關文章