在看pytorch的nn.Module部分的原始碼的時候,看到了一堆"鉤子",也就是hook,然後去研究了一下這是啥玩意。
基本概念
在深度學習中,hook 是一種可以在模型的不同階段插入自定義程式碼的機制。透過自定義資料在透過模型的特定層的額外行為,可以用來監控狀態,協助除錯,獲得中間結果。
以前向hook為例
前向hook是模型在forward過程中會呼叫的hook,透過torch.nn.Module的register_forward_hook() 函式,將一個自定義的hook函式註冊給模型的一個層
該層在進行前向之後,根據其輸入和輸出會進行相應的行為。
import torch
import torch.nn as nn
# 定義模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = self.fc1(x)
return self.fc2(x)
model = SimpleModel()
# 自定義的forward hook
def my_forward_hook(module, input, output):
print(f"層: {module}")
print(f"輸入: {input}")
print(f"輸出: {output}")
# 為模型的fc1層註冊hook
hook = model.fc1.register_forward_hook(my_forward_hook)
# 移除這個hook
hook.remove()
介面
hook函式的格式
需要是一個接受三個特定引數,返回None的函式
def hook_function(module, input, output):
# 自定義邏輯
return None
- module: 觸發鉤子的模型層,事實上是呼叫register_forward_hook的nn.Module例項
- input: 傳遞給該層的輸入張量(可能是元組),是前向傳播時該層接收到的輸入。
- output: 該層的輸出張量,是前向傳播時該層生成的輸出。
函式內部可以做自定義行為,可以在函式內部對output進行修改,從而改變模型的輸出。
註冊hook
hook = model.fc1.register_forward_hook(my_forward_hook)
hook.remove()
對於定義好的hook函式,將其作為引數,呼叫需要註冊的模型層的註冊函式即可。
如果不再需要這個hook,呼叫remove函式。
簡單的原始碼討論
還是以forward hook為例。一個nn.Module具有成員_forward_hooks,這是一個有序字典,在__init__()函式呼叫的時候被初始化
self._forward_hooks = OrderedDict()
註冊鉤子的register函式。
每個hook對應一個RemovableHandle物件,以其id作為鍵註冊到hook字典中,利用其remove函式實現移除。
def register_forward_hook(
self,
hook: Union[
Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
],
*,
prepend: bool = False,
with_kwargs: bool = False,
always_call: bool = False,
) -> RemovableHandle:
handle = RemovableHandle(
self._forward_hooks,
extra_dict=[
self._forward_hooks_with_kwargs,
self._forward_hooks_always_called,
],
)
self._forward_hooks[handle.id] = hook
if with_kwargs:
self._forward_hooks_with_kwargs[handle.id] = True
if always_call:
self._forward_hooks_always_called[handle.id] = True
if prepend:
self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
return handle
#簡化版的handle類
class RemovableHandle:
def __init__(self, hooks_dict, handle_id):
self.hooks_dict = hooks_dict
self.id = handle_id
def remove(self):
del self.hooks_dict[self.id]