視覺化卷積核引數對理解卷積神經網路的工作原理、最佳化模型效能、提高模型泛化能力有一定幫助作用。
下面以resnet18為例,視覺化了部分卷積核引數。
import torchvision from matplotlib import pyplot as plt import torch model = torchvision.models.resnet18(pretrained=True) #model = torchvision.models.efficientnet_b0(pretrained=True) num = 1 # 遍歷模型的每一層 for name, module in model.named_modules(): # 判斷是否為卷積層 if isinstance(module, torch.nn.Conv2d): # 輸出卷積層名稱和權重 print(f"layer {name} : {module.weight.data.shape}") _,_,H,W = module.weight.data.shape if H >=3 and W >=3: plt.subplot(5,4,num) data = module.weight.data.numpy() plt.imshow(data[0,0,:,:]) #太多了,只顯示一個卷積核 num+=1 plt.show()
結果如下: