點選檢視程式碼
import torch
import torch.nn as nn
from torchsummary import summary
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.layer1 = nn.Linear(300,200)
self.layer2 = nn.Linear(200,200)
self.out = nn.Linear(200,10)
nn.init.kaiming_normal_(self.layer1.weight)
nn.init.xavier_normal_(self.layer2.weight)
nn.init.kaiming_uniform_(self.out.weight)
def forward(self, x):
x = self.layer1(x)
x = torch.relu(x)
x = self.layer2(x)
x = torch.sigmoid(x)
x = self.out(x)
x = torch.softmax(x, dim=-1)
return x
if __name__ == '__main__':
device = torch.device('cuda:0')
net = Model().to(device)
data = torch.randn(500,300).to(device)
y = net(data)
print('y:',y)
summary(model=net,input_size=(300,),batch_size=500)
print("======檢視模型引數w和b======")
for name, parameter in net.named_parameters():
print(name, parameter)
可以看到我們直接使用了net來接收例項,然後在進行對標籤的計算時,直接使用了y=net(x)。
實際上,這是由於nn.Module自帶的__call__魔法方法在起作用。這裡,當我們呼叫net時,__call__會幫我們自動呼叫forward方法。所以,我們不管是使用net(x),net.forward(x)還是net.__call__(x),得到的結果都是對x進行了一輪計算。
以上,就是今天的pytorch小知識。