網路frozen

csjywu1發表於2024-04-27

參考網頁
https://blog.csdn.net/Orientliu96/article/details/104705912
device = torch.device("cuda" )

#Try to load models

model = DGCNN(args)
print(str(model))
model = model.to(device)

save_model = torch.load('model.t7')
model_dict = model.state_dict()

更新模型的引數,因為自己的網路比pretrain的模型更復雜

state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

for name,p in model.named_parameters():
if name.startswith('conv1'): p.requires_grad = False
if name.startswith('conv2'): p.requires_grad = False
if name.startswith('conv3'): p.requires_grad = False
if name.startswith('conv4'): p.requires_grad = False
if name.startswith('bn1'): p.requires_grad = False
if name.startswith('bn2'): p.requires_grad = False
if name.startswith('bn3'): p.requires_grad = False
if name.startswith('bn4'): p.requires_grad = False

opt = optim.SGD(filter(lambda x: x.requires_grad is not False ,model.parameters()), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)

解讀

這段程式碼使用了PyTorch框架,並展示瞭如何在一個卷積神經網路(DGCNN)上執行模型引數的載入、更新、部分凍結引數,以及設定最佳化器的過程。下面是每部分程式碼的具體解讀:

  1. 設定裝置

    device = torch.device("cuda")
    

    這行程式碼設定了使用CUDA裝置,也就是GPU來加速模型的計算。

  2. 初始化和列印模型

    model = DGCNN(args)
    print(str(model))
    model = model.to(device)
    

    這裡建立了一個DGCNN模型的例項,其中args可能包含了模型的一些配置引數(如層數、啟用函式等)。列印模型的結構後,將模型遷移到之前指定的CUDA裝置上。

  3. 載入預訓練模型並更新當前模型的引數

    save_model = torch.load('model.t7')
    model_dict = model.state_dict()
    state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
    print(state_dict.keys())
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
    

    首先從model.t7檔案中載入一個預訓練模型的狀態字典。然後獲取當前模型的狀態字典,並建立一個新的狀態字典,只包含預訓練模型中存在於當前模型中的引數。更新當前模型的狀態字典,並將其載入回模型中。

  4. 凍結特定層的引數

    for name, p in model.named_parameters():
        if name.startswith(('conv1', 'conv2', 'conv3', 'conv4', 'bn1', 'bn2', 'bn3', 'bn4')):
            p.requires_grad = False
    

    這部分遍歷模型的所有引數,並根據引數名字決定是否凍結引數。這裡凍結了以conv1, conv2, conv3, conv4, bn1, bn2, bn3, bn4開頭的所有層的引數。requires_grad = False意味著在訓練過程中這些引數不會被更新。

  5. 設定最佳化器

    opt = optim.SGD(filter(lambda x: x.requires_grad is not False, model.parameters()), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
    

    最後,設定了一個隨機梯度下降(SGD)最佳化器,僅針對那些未被凍結(requires_grad 不是 False)的引數。這裡還對學習率進行了調整(可能是為了適應不同的訓練需求),並設定了動量和權重衰減,這些都有助於最佳化訓練過程。

整體上,這段程式碼展示瞭如何利用預訓練模型來初始化一個更復雜的網路模型,並透過凍結部分層的引數來加速訓練過程,同時專注於訓練模型的特定部分。

對與預訓練中重合的key做交集。再將做交集的內容網路值返回到現有的模型當中

  1. torch.load('model.t7'):這行程式碼從model.t7檔案載入模型權重。在PyTorch中,模型的狀態通常儲存在一個名為.t7.pth的檔案中,這個檔案包含了模型引數的字典。

  2. model.state_dict():每個PyTorch模型都有一個state_dict方法,它返回模型當前的狀態,即其所有引數的字典。

  3. state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}:這段程式碼遍歷save_model字典中的每一項(鍵值對)。它建立一個新的字典state_dict,這個字典只包括那些鍵同時存在於save_model和當前模型狀態字典model_dict中的項。這意味著,只有當預訓練模型的某個權重與當前模型中的一個權重具有相同的鍵名時,這個權重才會被新增到新的字典中。

  4. print(state_dict.keys()):列印出state_dict中所有鍵的列表,這通常用於驗證哪些引數被載入了。

  5. model_dict.update(state_dict):這行程式碼將state_dict中的項更新到model_dict中。如果model_dict中已經有相同的鍵,則這些鍵對應的值會被state_dict中的值覆蓋。

  6. model.load_state_dict(model_dict):最後,更新後的model_dict被載入回模型中,這樣模型就具有了部分預訓練的權重。

這個過程允許模型只載入那些已知的、匹配的權重,這對於模型微調是很有用的,尤其是當你有一個預訓練的模型,想要將它適配到一個新的、稍有不同的任務或網路架構時。這樣做可以確保只有對應的權重被載入和更新,從而保持模型中其他自定義部分的引數不變。

相關文章