Pytorch相關(第二篇)

玥茹苟發表於2024-09-07

Pytorch自動梯度法,實現自定義向前 向後傳播方法

在 PyTorch 中,自定義自動求導的功能可以透過實現繼承自 torch.autograd.Function 的類來實現。這允許您定義自己的前向傳播(forward)和反向傳播(backward)邏輯。下面是如何自定義實現向前和向後傳播的詳細步驟和示例程式碼。

自定義 autograd 製作步驟

  1. 建立繼承自 torch.autograd.Function 的類。
  2. 實現 forward 方法:計算前向傳播並儲存任何需要在反向傳播中使用的張量。
  3. 實現 backward 方法:計算反向傳播,使用 grad_output 計算每個輸入的梯度。
  4. 使用自定義的 Function:在訓練或評估中使用您的自定義操作。

示例程式碼:自定義平方函式

下面的示例自定義了一個平方操作的前向和反向傳播:

import torch

class MySquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        """前向傳播"""
        ctx.save_for_backward(input)  # 儲存輸入張量以用於反向傳播
        return input ** 2  # 返回平方結果

    @staticmethod
    def backward(ctx, grad_output):
        """反向傳播"""
        input, = ctx.saved_tensors  # 獲取儲存的輸入張量
        grad_input = 2 * input * grad_output  # 計算關於輸入的梯度
        return grad_input  # 返回梯度

# 使用自定義的 Function
x = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)  # 建立輸入張量
square = MySquareFunction.apply  # 獲取自定義函式的引用

# 前向傳播
y = square(x)
print("Output:", y)  # 輸出: tensor([4.0, 9.0, 16.0], grad_fn=<MySquareFunctionBackward>)

# 計算損失並進行反向傳播
loss = y.sum()
loss.backward()  # 計算梯度

# 列印輸入的梯度
print("Gradient:", x.grad)  # 輸出: tensor([4.0, 6.0, 8.0])

程式碼解釋

  1. 自定義類 MySquareFunction

    • 繼承了 torch.autograd.Function
    • 實現了兩個靜態方法 forward()backward()
  2. 前向傳播:

    • forward 方法中,計算輸入的平方並將輸入張量儲存到上下文中,以便在反向傳播中使用。
    • 使用 ctx.save_for_backward(input) 儲存輸入。
  3. 反向傳播:

    • backward 方法中,從上下文中獲取儲存的輸入張量。
    • 計算關於輸入的梯度公式 d(output)d(input)=2⋅inputd(input)d(output)=2input。
    • 注意,grad_output 是由後續層傳播來的梯度。
  4. 使用自定義的 Function:

    • 建立一個需要計算梯度的張量 x,呼叫自定義的 square 函式進行前向傳播。
    • 計算損失並進行反向傳播,呼叫 loss.backward() 計算輸入的梯度。

總結

透過以上示例,您可以看到如何在 PyTorch 中自定義前向和反向傳播的邏輯。自定義 torch.autograd.Function 允許您實現複雜的操作和梯度計算,同時保留 PyTorch 的自動求導功能。這種方式在編寫新的模型或需要特定行為的操作時尤為有用。您可以根據具體需求修改上述示例,實現自己的自定義操作。