深度學習(模型引數直方圖)

Dsp Tian發表於2024-10-03

模型引數直方圖可以展示模型引數在訓練過程中的分佈情況。

透過直方圖,可以瞭解模型的學習狀態,識別過擬合或欠擬合問題,從而進行模型調優。

下面以ResNet18為例,顯示了不同層的引數直方圖。

import torchvision
from matplotlib import pyplot as plt
import torch

model = torchvision.models.resnet18(pretrained=True)

num = 1
# 遍歷模型的每一層
for name, module in model.named_modules():
    # 判斷是否為卷積層
    if isinstance(module, torch.nn.Conv2d):
        # 輸出卷積層名稱和權重
        print(f"layer {name} : {module.weight.data.shape}")
        Oc,Ic,H,W = module.weight.data.shape
        data = module.weight.data.view(Oc*Ic*H*W).numpy()            
        plt.subplot(5,4,num)
        plt.hist(data,bins=50)
        num +=1

plt.show()           

結果如下:

相關文章