pytorch小知識(01):forward方法

小丑与锁鸟發表於2024-06-08
我們都知道torch.nn.Module在被繼承時要求我們改寫兩個方法,一個是__init__,一個是forward。前者用於定義層,後者用於定義前向計算的流程。但是當我們在實際使用一個網路時,我們不會使用forward這個方法進行計算,而是進行如下的操作:
點選檢視程式碼
import torch
import torch.nn as nn
from torchsummary import summary

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer1 = nn.Linear(300,200)
        self.layer2 = nn.Linear(200,200)
        self.out = nn.Linear(200,10)
        nn.init.kaiming_normal_(self.layer1.weight)
        nn.init.xavier_normal_(self.layer2.weight)
        nn.init.kaiming_uniform_(self.out.weight)
    def forward(self, x):
        x = self.layer1(x)
        x = torch.relu(x)
        x = self.layer2(x)
        x = torch.sigmoid(x)
        x = self.out(x)
        x = torch.softmax(x, dim=-1)
        return x

if __name__ == '__main__':
    device = torch.device('cuda:0')
    net = Model().to(device)
    data = torch.randn(500,300).to(device)
    y = net(data)
    print('y:',y)
    summary(model=net,input_size=(300,),batch_size=500)
    print("======檢視模型引數w和b======")
    for name, parameter in net.named_parameters():
        print(name, parameter)

可以看到我們直接使用了net來接收例項,然後在進行對標籤的計算時,直接使用了y=net(x)。 實際上,這是由於nn.Module自帶的__call__魔法方法在起作用。這裡,當我們呼叫net時,__call__會幫我們自動呼叫forward方法。所以,我們不管是使用net(x),net.forward(x)還是net.__call__(x),得到的結果都是對x進行了一輪計算。 以上,就是今天的pytorch小知識。