[PyTorch 學習筆記] 5.2 Hook 函式與 CAM 演算法

張賢同學發表於2020-09-07

本章程式碼:

這篇文章主要介紹瞭如何使用 Hook 函式提取網路中的特徵圖進行視覺化,和 CAM(class activation map, 類啟用圖)

Hook 函式概念

Hook 函式是在不改變主體的情況下,實現額外功能。由於 PyTorch 是基於動態圖實現的,因此在一次迭代運算結束後,一些中間變數如非葉子節點的梯度和特徵圖,會被釋放掉。在這種情況下想要提取和記錄這些中間變數,就需要使用 Hook 函式。

PyTorch 提供了 4 種 Hook 函式。

torch.Tensor.register_hook(hook)

功能:註冊一個反向傳播 hook 函式,僅輸入一個引數,為張量的梯度。

hook函式:

hook(grad)

引數:

  • grad:張量的梯度

程式碼如下:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

# 儲存梯度的 list
a_grad = list()

# 定義 hook 函式,把梯度新增到 list 中
def grad_hook(grad):
	a_grad.append(grad)

# 一個張量註冊 hook 函式
handle = a.register_hook(grad_hook)

y.backward()

# 檢視梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
# 檢視在 hook 函式裡 list 記錄的梯度
print("a_grad[0]: ", a_grad[0])
handle.remove()

結果如下:

gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]:  tensor([2.])

在反向傳播結束後,非葉子節點張量的梯度被清空了。而通過hook函式記錄的梯度仍然可以檢視。

hook函式裡面可以修改梯度的值,無需返回也可以作為新的梯度賦值給原來的梯度。程式碼如下:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
    grad *= 2
    return grad*3

handle = w.register_hook(grad_hook)

y.backward()

# 檢視梯度
print("w.grad: ", w.grad)
handle.remove()

結果是:

w.grad:  tensor([30.])

torch.nn.Module.register_forward_hook(hook)

功能:註冊 module 的前向傳播hook函式,可用於獲取中間的 feature map。

hook函式:

hook(module, input, output)

引數:

  • module:當前網路層
  • input:當前網路層輸入資料
  • output:當前網路層輸出資料

下面程式碼執行的功能是 $3 \times 3$ 的卷積和 $2 \times 2$ 的池化。我們使用register_forward_hook()記錄中間卷積層輸入和輸出的 feature map。

[PyTorch 學習筆記] 5.2 Hook 函式與 CAM 演算法

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    # 初始化網路
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # 註冊hook
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)


    # 觀察
    print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

輸出如下:

output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],
         [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9.,  9.],
          [ 9.,  9.]],
         [[18., 18.],
          [18., 18.]]]], grad_fn=<ThnnConv2DBackward>)
input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)

torch.Tensor.register_forward_pre_hook()

功能:註冊 module 的前向傳播前的hook函式,可用於獲取輸入資料。

hook函式:

hook(module, input)

引數:

  • module:當前網路層
  • input:當前網路層輸入資料

torch.Tensor.register_backward_hook()

功能:註冊 module 的反向傳播的hook函式,可用於獲取梯度。

hook函式:

hook(module, grad_input, grad_output)

引數:

  • module:當前網路層
  • input:當前網路層輸入的梯度資料
  • output:當前網路層輸出的梯度資料

程式碼如下:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # 初始化網路
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # 註冊hook
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    net.conv1.register_forward_pre_hook(forward_pre_hook)
    net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)

    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(target, output)
    loss.backward()

輸出如下:

forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],
        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
          [0.0000, 0.0000]],
         [[0.5000, 0.0000],
          [0.0000, 0.0000]]]]),)

hook函式實現機制

hook函式實現的原理是在module__call()__函式進行攔截,__call()__函式可以分為 4 個部分:

  • 第 1 部分是實現 _forward_pre_hooks
  • 第 2 部分是實現 forward 前向傳播
  • 第 3 部分是實現 _forward_hooks
  • 第 4 部分是實現 _backward_hooks

由於卷積層也是一個module,因此可以記錄_forward_hooks

    def __call__(self, *input, **kwargs):
    	# 第 1 部分是實現 _forward_pre_hooks
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result

        # 第 2 部分是實現 forward 前向傳播
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)

        # 第 3 部分是實現 _forward_hooks
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result

        # 第 4 部分是實現 _backward_hooks
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

Hook 函式提取網路的特徵圖

下面通過hook函式獲取 AlexNet 每個卷積層的所有卷積核引數,以形狀作為 key,value 對應該層多個卷積核的 list。然後取出每層的第一個卷積核,形狀是 [1, in_channle, h, w],轉換為 [in_channle, 1, h, w],使用 TensorBoard 進行視覺化,程式碼如下:

    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # 資料
    path_img = "imgs/lena.png"     # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]

    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    if img_transforms is not None:
        img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)    # chw --> bchw

    # 模型
    alexnet = models.alexnet(pretrained=True)

    # 註冊hook
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():

        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())
            # 由於AlexNet 使用 nn.Sequantial 包裝,所以 name 的形式是:features.0  features.1
            n1, n2 = name.split(".")

            def hook_func(m, i, o):
                key_name = str(m.weight.shape)
                fmap_dict[key_name].append(o)

            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # forward
    output = alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():
        fmap = fmap_list[0]# 取出第一個卷積核的引數
        fmap.transpose_(0, 1) # 把 BCHW 轉換為 CBHW

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

使用 TensorBoard 進行視覺化如下:

[PyTorch 學習筆記] 5.2 Hook 函式與 CAM 演算法

CAM(class activation map, 類啟用圖)

暫未完成。列出兩個參考文章。

參考資料


如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章。

相關文章