【pytorch_5】Implementing Linear Regression

Posted by MinSuga on 2020-10-03

Steps:

1. Prepare the data

import torch

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]]) 
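
Both tensors have shape (3, 1): three samples, each with a single feature. torch.nn.Linear expects this (batch, features) layout, which is why the values are nested lists rather than a flat [1.0, 2.0, 3.0]. A quick standalone check:

import torch

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
print(x_data.shape)   # torch.Size([3, 1]): 3 samples, 1 feature each
print(x_data.dim())   # 2, i.e. a matrix, not a flat vector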

2. Design the model:

class LinearModel(torch.nn.Module):     # inherit from nn.Module
    def __init__(self):                 # constructor: initialize the object
        super(LinearModel, self).__init__()     # call the parent class constructor
        self.linear = torch.nn.Linear(1, 1)     # args are (in_features, out_features); the weight and bias are created internally

    def forward(self, x):
        y_pred = self.linear(x)        # the layer is a callable object
        return y_pred

model = LinearModel()   # instantiate the model
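
As an aside, a model this small does not strictly need a custom class; a minimal equivalent sketch using torch.nn.Sequential:

import torch

# Same single linear layer, no subclassing; calling it runs the same forward pass.
model = torch.nn.Sequential(torch.nn.Linear(1, 1))
print(model(torch.Tensor([[4.0]])))

Subclassing nn.Module pays off once forward() does more than chain layers in order.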

3. Construct the loss function and optimizer

With the reduction set to sum, MSELoss computes the total squared error over the batch:

MSE: loss = Σ_n (ŷ_n − y_n)²

SGD then moves each parameter a small step against its gradient, scaled by the learning rate lr:

SGD: w ← w − lr · ∂loss/∂w

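For intuition, plain SGD (no momentum or weight decay) applies exactly this update to every parameter. A standalone sketch of roughly what the optimizer.step() call below does, not the actual torch.optim internals:

import torch

model = torch.nn.Linear(1, 1)
x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[2.0], [4.0], [6.0]])

loss = torch.nn.functional.mse_loss(model(x), y, reduction='sum')
loss.backward()

lr = 0.01
with torch.no_grad():                # update weights without tracking gradients
    for p in model.parameters():
        p -= lr * p.grad             # w = w - lr * dloss/dw
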
criterion = torch.nn.MSELoss(reduction='sum')    # reduction='sum' replaces the deprecated size_average=False
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)    # lr: learning rate
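
A quick check that this criterion matches the MSE formula above; the example numbers are made up:

import torch

y_pred = torch.tensor([[1.5], [3.5], [5.5]])
y_true = torch.tensor([[2.0], [4.0], [6.0]])

criterion = torch.nn.MSELoss(reduction='sum')
print(criterion(y_pred, y_true))         # tensor(0.7500)
print(((y_pred - y_true) ** 2).sum())    # tensor(0.7500), the same summed squared error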

4. Train

for epoch in range(100):
    y_pred = model(x_data)              # forward pass
    loss = criterion(y_pred, y_data)
    print(epoch, loss)                  # loss.item() would print just the number

    optimizer.zero_grad()               # zero the accumulated gradients
    loss.backward()                     # backward pass: compute gradients
    optimizer.step()                    # update the parameters
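
The order inside the loop matters because PyTorch accumulates gradients: without zero_grad(), each backward() adds new gradients on top of the old ones. A standalone sketch of that behavior:

import torch

w = torch.ones(1, requires_grad=True)

(2 * w).sum().backward()
print(w.grad)        # tensor([2.])

(2 * w).sum().backward()
print(w.grad)        # tensor([4.]), accumulated rather than replaced

w.grad.zero_()       # what optimizer.zero_grad() does for every parameter
print(w.grad)        # tensor([0.])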

Complete implementation:

import torch

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])      # prepare the data

class LinearModel(torch.nn.Module):     # inherit from nn.Module
    def __init__(self):                 # constructor: initialize the object
        super(LinearModel, self).__init__()     # call the parent class constructor
        self.linear = torch.nn.Linear(1, 1)     # (in_features, out_features)

    def forward(self, x):
        y_pred = self.linear(x)        # the layer is a callable object
        return y_pred

model = LinearModel()   # instantiate the model

criterion = torch.nn.MSELoss(reduction='sum')    # size_average=False in older PyTorch versions
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)    # lr: learning rate

for epoch in range(100):
    y_pred = model(x_data)              # forward pass
    loss = criterion(y_pred, y_data)
    print(epoch, loss)

    optimizer.zero_grad()               # zero the accumulated gradients
    loss.backward()                     # backward pass: compute gradients
    optimizer.step()                    # update the parameters


print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)     # .data detaches from the graph; y_test.detach() is the modern spelling


Sample output:

0 tensor(113.5261, grad_fn=<MseLossBackward>)
1 tensor(50.8733, grad_fn=<MseLossBackward>)
2 tensor(22.9772, grad_fn=<MseLossBackward>)
3 tensor(10.5539, grad_fn=<MseLossBackward>)
4 tensor(5.0188, grad_fn=<MseLossBackward>)
5 tensor(2.5501, grad_fn=<MseLossBackward>)
6 tensor(1.4465, grad_fn=<MseLossBackward>)
7 tensor(0.9508, grad_fn=<MseLossBackward>)
8 tensor(0.7257, grad_fn=<MseLossBackward>)
9 tensor(0.6211, grad_fn=<MseLossBackward>)
...(epochs 10 to 94 omitted; the loss keeps decreasing steadily)...
95 tensor(0.1565, grad_fn=<MseLossBackward>)
96 tensor(0.1543, grad_fn=<MseLossBackward>)
97 tensor(0.1521, grad_fn=<MseLossBackward>)
98 tensor(0.1499, grad_fn=<MseLossBackward>)
99 tensor(0.1477, grad_fn=<MseLossBackward>)
w =  1.7441235780715942
b =  0.5816670656204224
y_pred =  tensor([[7.5582]])
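
After 100 epochs, w and b are still some distance from the exact least-squares solution (w = 2, b = 0 fits this data perfectly); more epochs would bring them closer. A closed-form sanity check, assuming a PyTorch version that provides torch.linalg.lstsq:

import torch

x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[2.0], [4.0], [6.0]])

A = torch.cat([x, torch.ones_like(x)], dim=1)   # append a ones column so the bias is solved for too
print(torch.linalg.lstsq(A, y).solution)        # ~[[2.], [0.]], i.e. w = 2, b = 0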
