18-神經網路-自定義帶引數的層

不是孩子了發表於2024-08-25

1、nn.Parameter函式


2、torch.mm 和torch.matmul區別
都是 PyTorch 中用於矩陣乘法的函式,但它們在使用上有細微的差別

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLinear(nn.Module):
    def __init__(self, in_units, out_units):
        super(MyLinear, self).__init__()
        self.weight = nn.Parameter(torch.randn((in_units, out_units)))
        self.bias = nn.Parameter(torch.randn(out_units))

    def forward(self, x):
        linear = torch.matmul(x, self.weight) + self.bias
        return F.relu(linear)

linear = MyLinear(5, 3)
print(linear.weight)

y = linear(torch.rand((2, 5)))
print(y)

相關文章