以pytorch的forward hook為例探究hook機制

Niax23發表於2024-10-09

在看pytorch的nn.Module部分的原始碼的時候,看到了一堆"鉤子",也就是hook,然後去研究了一下這是啥玩意。

基本概念

在深度學習中,hook 是一種可以在模型的不同階段插入自定義程式碼的機制。透過自定義資料在透過模型的特定層的額外行為,可以用來監控狀態,協助除錯,獲得中間結果。

以前向hook為例

前向hook是模型在forward過程中會呼叫的hook,透過torch.nn.Module的register_forward_hook() 函式,將一個自定義的hook函式註冊給模型的一個層
該層在進行前向之後,根據其輸入和輸出會進行相應的行為。


import torch
import torch.nn as nn

# 定義模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)
        return self.fc2(x)

model = SimpleModel()

# 自定義的forward hook
def my_forward_hook(module, input, output):
    print(f"層: {module}")
    print(f"輸入: {input}")
    print(f"輸出: {output}")


# 為模型的fc1層註冊hook
hook = model.fc1.register_forward_hook(my_forward_hook)

# 移除這個hook
hook.remove()

介面

hook函式的格式

需要是一個接受三個特定引數,返回None的函式

def hook_function(module, input, output):
    # 自定義邏輯
    return None
  • module: 觸發鉤子的模型層,事實上是呼叫register_forward_hook的nn.Module例項
  • input: 傳遞給該層的輸入張量(可能是元組),是前向傳播時該層接收到的輸入。
  • output: 該層的輸出張量,是前向傳播時該層生成的輸出。
    函式內部可以做自定義行為,可以在函式內部對output進行修改,從而改變模型的輸出。

註冊hook

hook = model.fc1.register_forward_hook(my_forward_hook)

hook.remove()

對於定義好的hook函式,將其作為引數,呼叫需要註冊的模型層的註冊函式即可。
如果不再需要這個hook,呼叫remove函式。

簡單的原始碼討論

還是以forward hook為例。一個nn.Module具有成員_forward_hooks,這是一個有序字典,在__init__()函式呼叫的時候被初始化

self._forward_hooks = OrderedDict()

註冊鉤子的register函式。
每個hook對應一個RemovableHandle物件,以其id作為鍵註冊到hook字典中,利用其remove函式實現移除。

def register_forward_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        ],
        *,
        prepend: bool = False,
        with_kwargs: bool = False,
        always_call: bool = False,
    ) -> RemovableHandle:
        handle = RemovableHandle(
            self._forward_hooks,
            extra_dict=[
                self._forward_hooks_with_kwargs,
                self._forward_hooks_always_called,
            ],
        )
        self._forward_hooks[handle.id] = hook
        if with_kwargs:
            self._forward_hooks_with_kwargs[handle.id] = True
        if always_call:
            self._forward_hooks_always_called[handle.id] = True
        if prepend:
            self._forward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

#簡化版的handle類
class RemovableHandle:
    def __init__(self, hooks_dict, handle_id):
        self.hooks_dict = hooks_dict
        self.id = handle_id

    def remove(self):
        del self.hooks_dict[self.id]

相關文章