3D高斯損失函式(1)單純損失函式

MKT-porter發表於2024-06-28

在PyTorch中,定義一個損失函式並更新梯度通常涉及以下幾個步驟:

定義損失函式:可以使用PyTorch內建的損失函式或者自定義一個損失函式。
前向傳播:透過模型計算預測值。
計算損失:使用定義好的損失函式計算預測值與真實值之間的損失。
反向傳播:透過計算損失的梯度來更新模型引數。
下面是一個完整的示例,展示如何定義一個自定義損失函式並更新梯度:

步驟1:定義損失函式
可以使用內建的損失函式,如 nn.MSELoss,也可以定義一個自定義損失函式。

import torch
import torch.nn as nn

# 定義一個簡單的線性模型
model = nn.Linear(1, 1)

# 定義損失函式 (均方誤差)
loss_fn = nn.MSELoss()

  自定義損失函式

import torch.nn.functional as F

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, y_pred, y_true):
        loss = torch.mean((y_pred - y_true) ** 2)  # 簡單的均方誤差
        return loss

loss_fn = CustomLoss()

  

步驟2:前向傳播
透過模型計算預測值。

# 示例資料
x = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
y_true = torch.tensor([[2.0], [4.0], [6.0]])

# 計算預測值
y_pred = model(x)

  

步驟3:計算損失
使用定義好的損失函式計算預測值與真實值之間的損失。

# 計算損失
loss = loss_fn(y_pred, y_true)
print("Loss:", loss.item())

  

步驟4:反向傳播並更新梯度
透過計算損失的梯度來更

# 使用最佳化器 (如Adam或SGD)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 清零梯度
optimizer.zero_grad()

# 反向傳播計算梯度
loss.backward()

# 更新模型引數
optimizer.step()

  

完整示例程式碼
以下是一個完整的示例程式碼,展示瞭如何定義一個自定義損失函式並更新梯度:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 定義一個簡單的線性模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()

# 自定義損失函式
class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, y_pred, y_true):
        loss = torch.mean((y_pred - y_true) ** 2)  # 簡單的均方誤差
        return loss

loss_fn = CustomLoss()

# 示例資料
x = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
y_true = torch.tensor([[2.0], [4.0], [6.0]])

# 最佳化器
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 前向傳播
y_pred = model(x)

# 計算損失
loss = loss_fn(y_pred, y_true)
print("Loss:", loss.item())

# 清零梯度
optimizer.zero_grad()

# 反向傳播計算梯度
loss.backward()

# 更新模型引數
optimizer.step()

  

相關文章