基於pytorch實現模型剪枝

嵌入式視覺發表於2023-02-23

一,剪枝分類

所謂模型剪枝,其實是一種從神經網路中移除"不必要"權重或偏差(weigths/bias)的模型壓縮技術。關於什麼引數才是“不必要的”,這是一個目前依然在研究的領域。

1.1,非結構化剪枝

非結構化剪枝(Unstructured Puning)是指修剪引數的單個元素,比如全連線層中的單個權重、卷積層中的單個卷積核引數元素或者自定義層中的浮點數(scaling floats)。其重點在於,剪枝權重物件是隨機的,沒有特定結構,因此被稱為非結構化剪枝

1.2,結構化剪枝

與非結構化剪枝相反,結構化剪枝會剪枝整個引數結構。比如,丟棄整行或整列的權重,或者在卷積層中丟棄整個過濾器(Filter)。

1.3,本地與全域性修剪

剪枝可以在每層(區域性)或多層/所有層(全域性)上進行。

二,PyTorch 的剪枝

目前 PyTorch 框架支援的權重剪枝方法有:

  • Random: 簡單地修剪隨機引數。
  • Magnitude: 修剪權重最小的引數(例如它們的 L2 範數)

以上兩種方法實現簡單、計算容易,且可以在沒有任何資料的情況下應用。

2.1,pytorch 剪枝工作原理

剪枝功能在 torch.nn.utils.prune 類中實現,程式碼在檔案 torch/nn/utils/prune.py 中,主要剪枝類如下圖所示。

pytorch_pruning_api_file.png

剪枝原理是基於張量(Tensor)的掩碼(Mask)實現。掩碼是一個與張量形狀相同的布林型別的張量,掩碼的值為 True 表示相應位置的權重需要保留,掩碼的值為 False 表示相應位置的權重可以被刪除。

Pytorch 將原始引數 <param> 複製到名為 <param>_original 的引數中,並建立一個緩衝區來儲存剪枝掩碼 <param>_mask。同時,其也會建立一個模組級的 forward_pre_hook 回撥函式(在模型前向傳播之前會被呼叫的回撥函式),將剪枝掩碼應用於原始權重。

pytorch 剪枝的 api 和教程比較混亂,我個人將做了如下表格,希望能將 api 和剪枝方法及分類總結好。

pytorch_pruning_api

pytorch 中進行模型剪枝的工作流程如下:

  1. 選擇剪枝方法(或者子類化 BasePruningMethod 實現自己的剪枝方法)。
  2. 指定剪枝模組和引數名稱。
  3. 設定剪枝方法的引數,比如剪枝比例等。

2.2,區域性剪枝

Pytorch 框架中的區域性剪枝有非結構化和結構化剪枝兩種型別,值得注意的是結構化剪枝只支援區域性不支援全域性。

2.2.1,區域性非結構化剪枝

1,區域性非結構化剪枝(Locall Unstructured Pruning)對應函式原型如下:

def random_unstructured(module, name, amount)

1,函式功能

用於對權重引數張量進行非結構化剪枝。該方法會在張量中隨機選擇一些權重或連線進行剪枝,剪枝率由使用者指定。

2,函式引數定義:

  • module (nn.Module): 需要剪枝的網路層/模組,例如 nn.Conv2d() 和 nn.Linear()。
  • name (str): 要剪枝的引數名稱,比如 "weight" 或 "bias"。
  • amount (int or float): 指定要剪枝的數量,如果是 0~1 之間的小數,則表示剪枝比例;如果是證照,則直接剪去引數的絕對數量。比如amount=0.2 ,表示將隨機選擇 20% 的元素進行剪枝。

3,下面是 random_unstructured 函式的使用示例。

import torch
import torch.nn.utils.prune as prune
conv = torch.nn.Conv2d(1, 1, 4)
prune.random_unstructured(conv, name="weight", amount=0.5)
conv.weight
"""
tensor([[[[-0.1703,  0.0000, -0.0000,  0.0690],
          [ 0.1411,  0.0000, -0.0000, -0.1031],
          [-0.0527,  0.0000,  0.0640,  0.1666],
          [ 0.0000, -0.0000, -0.0000,  0.2281]]]], grad_fn=<MulBackward0>)
"""

可以看書輸出的 conv 層中權重值有一半比例為 0

2.2.2,區域性結構化剪枝

區域性結構化剪枝(Locall Structured Pruning)有兩種函式,對應函式原型如下:

def random_structured(module, name, amount, dim)
def ln_structured(module, name, amount, n, dim, importance_scores=None)

1,函式功能

與非結構化移除的是連線權重不同,結構化剪枝移除的是整個通道權重。

2,引數定義

與區域性非結構化函式非常相似,唯一的區別是您必須定義 dim 引數(ln_structured 函式多了 n 引數)。

n 表示剪枝的範數,dim 表示剪枝的維度。

對於 torch.nn.Linear:

  • dim = 0: 移除一個神經元。

  • dim = 1:移除與一個輸入的所有連線。

對於 torch.nn.Conv2d:

  • dim = 0(Channels) : 通道 channels 剪枝/過濾器 filters 剪枝
  • dim = 1(Neurons): 二維卷積核 kernel 剪枝,即與輸入通道相連線的 kernel

2.2.3,區域性結構化剪枝示例程式碼

在寫示例程式碼之前,我們先需要理解 Conv2d 函式引數、卷積核 shape、軸以及張量的關係。

首先,Conv2d 函式原型如下;

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

而 pytorch 中常規卷積的卷積核權重 shape 都為(C_out, C_in, kernel_height, kernel_width),所以在程式碼中卷積層權重 shape[3, 2, 3, 3],dim = 0 對應的是 shape [3, 2, 3, 3] 中的 3。這裡我們 dim 設定了哪個軸,那自然剪枝之後權重張量對應的軸機會發生變換。

dim

理解了前面的關鍵概念,下面就可以實際使用了,dim=0 的示例如下所示。

conv = torch.nn.Conv2d(2, 3, 3)
norm1 = torch.norm(conv.weight, p=1, dim=[1,2,3])
print(norm1)
"""
tensor([1.9384, 2.3780, 1.8638], grad_fn=<NormBackward1>)
"""
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0)
print(conv.weight)
"""
tensor([[[[-0.0005,  0.1039,  0.0306],
          [ 0.1233,  0.1517,  0.0628],
          [ 0.1075, -0.0606,  0.1140]],

         [[ 0.2263, -0.0199,  0.1275],
          [-0.0455, -0.0639, -0.2153],
          [ 0.1587, -0.1928,  0.1338]]],


        [[[-0.2023,  0.0012,  0.1617],
          [-0.1089,  0.2102, -0.2222],
          [ 0.0645, -0.2333, -0.1211]],

         [[ 0.2138, -0.0325,  0.0246],
          [-0.0507,  0.1812, -0.2268],
          [-0.1902,  0.0798,  0.0531]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000]]]], grad_fn=<MulBackward0>)
"""

從執行結果可以明顯看出,卷積層引數的最後一個通道引數張量被移除了(為 0 張量),其解釋參見下圖。

dim_understand

dim = 1 的情況:

conv = torch.nn.Conv2d(2, 3, 3)
norm1 = torch.norm(conv.weight, p=1, dim=[0, 2,3])
print(norm1)
"""
tensor([3.1487, 3.9088], grad_fn=<NormBackward1>)
"""
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=1)
print(conv.weight)
"""
tensor([[[[ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000]],

         [[-0.2140,  0.1038,  0.1660],
          [ 0.1265, -0.1650, -0.2183],
          [-0.0680,  0.2280,  0.2128]]],


        [[[-0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]],

         [[-0.2087,  0.1275,  0.0228],
          [-0.1888, -0.1345,  0.1826],
          [-0.2312, -0.1456, -0.1085]]],


        [[[-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[-0.0891,  0.0946, -0.1724],
          [-0.2068,  0.0823,  0.0272],
          [-0.2256, -0.1260, -0.0323]]]], grad_fn=<MulBackward0>)
"""

很明顯,對於 dim=1的維度,其第一個張量的 L2 範數更小,所以shape 為 [2, 3, 3] 的張量中,第一個 [3, 3] 張量引數會被移除(即張量為 0 矩陣) 。

2.3,全域性非結構化剪枝

前文的 local 剪枝的物件是特定網路層,而 global 剪枝是將模型看作一個整體去移除指定比例(數量)的引數,同時 global 剪枝結果會導致模型中每層的稀疏比例是不一樣的。

全域性非結構化剪枝函式原型如下:

# v1.4.0 版本
def global_unstructured(parameters, pruning_method, **kwargs)
# v2.0.0-rc2版本
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):

1,函式功能

隨機選擇全域性所有引數(包括權重和偏置)的一部分進行剪枝,而不管它們屬於哪個層。

2,引數定義

  • parameters((Iterable of (module, name) tuples)): 修剪模型的引數列表,列表中的元素是 (module, name)。
  • pruning_method(function): 目前好像官方只支援 pruning_method=prune.L1Unstuctured,另外也可以是自己實現的非結構化剪枝方法函式。
  • importance_scores: 表示每個引數的重要性得分,如果為 None,則使用預設得分。
  • **kwargs: 表示傳遞給特定剪枝方法的額外引數。比如 amount 指定要剪枝的數量。

3,global_unstructured 函式的示例程式碼如下所示。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
# 計算卷積層和整個模型的稀疏度
# 其實呼叫的是 Tensor.numel 內內函式,返回輸入張量中元素的總數
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
# 程式執行結果
"""
Sparsity in conv1.weight: 3.70%
Global sparsity: 20.00%
"""

執行結果表明,雖然模型整體(全域性)的稀疏度是 20%,但每個網路層的稀疏度不一定是 20%。

三,總結

另外,pytorch 框架還提供了一些幫助函式:

  1. torch.nn.utils.prune.is_pruned(module): 判斷模組 是否被剪枝。
  2. torch.nn.utils.prune.remove(module, name): 用於將指定模組中指定引數上的剪枝操作移除,從而恢復該引數的原始形狀和數值。

雖然 PyTorch 提供了內建剪枝 API ,也支援了一些非結構化和結構化剪枝方法,但是 API 比較混亂,對應文件描述也不清晰,所以後面我還會結合微軟的開源 nni 工具來實現模型剪枝功能。

參考資料

  1. How to Prune Neural Networks with PyTorch
  2. PRUNING TUTORIAL
  3. PyTorch Pruning

相關文章