1、nn.Parameter（用來把張量註冊為模型可訓練參數的類別）
2、torch.mm 和 torch.matmul 的區別
兩者都是 PyTorch 中用於矩陣乘法的函式，但在使用上有細微的差別：torch.mm 只接受兩個二維矩陣，且不支援廣播（broadcasting）；torch.matmul 則可處理一維向量與帶有批次維度的張量，並支援廣播。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyLinear(nn.Module):
    """Fully connected layer followed by ReLU, built from raw ``nn.Parameter`` s.

    Computes ``relu(x @ weight + bias)``. Wrapping ``weight`` and ``bias`` in
    ``nn.Parameter`` and assigning them as attributes registers them with the
    module, so they show up in ``parameters()`` / ``state_dict()`` and receive
    gradients during ``backward()``.

    Args:
        in_units: Size of the last dimension of the input.
        out_units: Number of output features.
    """

    def __init__(self, in_units: int, out_units: int) -> None:
        super().__init__()
        # Stored as (in_units, out_units) so forward can compute x @ weight
        # directly; nn.Linear instead keeps (out_features, in_features) and
        # multiplies by the transpose.
        # Values are drawn from N(0, 1) with no fan-in scaling -- fine for a
        # demo, but wide layers would usually want a scaled initialisation.
        self.weight = nn.Parameter(torch.randn((in_units, out_units)))
        self.bias = nn.Parameter(torch.randn(out_units))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(x @ weight + bias)``.

        ``torch.matmul`` is used instead of ``torch.mm`` because ``mm`` accepts
        only two 2-D matrices, while ``matmul`` also handles a 1-D ``x`` and
        inputs with extra leading (batch) dimensions, broadcasting ``weight``
        across them. The result keeps ``x``'s leading dimensions, with the last
        dimension replaced by ``out_units``.
        """
        pre_activation = torch.matmul(x, self.weight) + self.bias
        return F.relu(pre_activation)
# Build a 5 -> 3 layer and show its freshly initialised weight matrix.
linear = MyLinear(5, 3)
print(linear.weight)

# Feed a batch of two random 5-feature samples through the layer and show the
# post-ReLU outputs (one row of 3 values per sample).
y = linear(torch.rand(2, 5))
print(y)