模型剪枝:剪枝粒度、剪枝標準、剪枝時機、剪枝頻率

passion2021發表於2024-11-15

模型剪枝

模型剪枝:將模型中不重要的權重和分支裁剪掉。將權重矩陣中一部分元素變為零元素。

image-20241113084517922

減去不重要的突觸(Synapses)或神經元(Neurons)。

剪枝型別

非結構化剪枝

非結構化剪枝:破壞了原有模型的結構。

怎麼做:
非結構化剪枝並不關心權重在網路中的位置,只是根據某種標準(例如,權重的絕對值大小)來決定是否移除這個權重。移除權重後,剩下的權重分佈是稀疏的,即大多數權重為零。

實際情況:
非結構化剪枝能極大降低模型的引數量和理論計算量,但是現有硬體架構的計算方式無法對其進行加速,通常需要特殊的硬體或軟體支援來有效利用結果模型的稀疏性。所以在實際執行速度上得不到提升,需要設計特定的硬體才可能加速。

結構化剪枝

結構化剪枝則更加關注模型的組織結構,這種剪枝方法可能涉及到移除整個神經元、卷積核、層或者更復雜的結構。

通常以filter或者整個網路層為基本單位進行剪枝。

一個filter被剪枝,那麼其前一個特徵圖和下一個特徵圖都會發生相應的變化,但是模型的結構卻沒有被破壞,仍然能夠透過 GPU 或其他硬體來加速。

半結構化剪枝

這種剪枝方法可能涉及到移除整個神經元或過濾器的一部分,而不是全部。

通常的做法是按某種規則對結構中的一部分進行剪枝,比如在某個維度上做非結構化剪枝,而在其他維度上保持結構化。

剪枝範圍

區域性剪枝:關注的是模型中的單個權重或引數。這種剪枝方法通常針對模型中的每個權重進行評估,然後決定是否將其設定為零。

全域性剪枝:全域性剪枝則考慮模型的整體結構和效能。這種剪枝方法可能會移除整個神經元、卷積核、層或者更復雜的結構,如卷積核組。全域性剪枝通常需要對模型的整體結構有深入的理解,並且可能涉及到模型架構的重設計。這種方法可能會對模型的最終效能產生更大的影響,因為它改變了模型的整體特徵提取能力。

剪枝粒度

按照剪枝粒度進行劃分,剪枝可分為細粒度剪枝(Fine-grained Pruning)、基於模式的剪枝(Pattern-based Pruning)、向量級剪枝(Vector-level Pruning)、核心級剪枝(Kernel-level Pruning)與通道級剪枝(Channel-level Pruning)。

如下圖所示,展示了從細粒度剪枝到通道級的剪枝,剪枝越來越規則和結構化。

image-20241113092055218

細粒度剪枝

import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time

plt.rcParams['font.sans-serif'] = ['SimHei']  # 解決中文亂碼
# plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']

def timing_decorator(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        execution_time = end_time - start_time
        print("{} 函式的執行時間為:{:.8f} 秒".format(func.__name__, execution_time))
        return result
    return wrapper


# 建立一個視覺化2維矩陣函式,將值為0的元素與其他區分開(用於顯示剪枝效果)
def plot_tensor(tensor, title):
    # 建立一個新的影像和軸
    fig, ax = plt.subplots()

    # 使用 CPU 上的資料,轉換為 numpy 陣列,並檢查相等條件,設定顏色對映
    ax.imshow(tensor.cpu().numpy() == 0, vmin=0, vmax=1, cmap='tab20c')
    ax.set_title(title)
    ax.set_yticklabels([])
    ax.set_xticklabels([])

    # 遍歷矩陣中的每個元素並新增文字標籤
    for i in range(tensor.shape[1]):
        for j in range(tensor.shape[0]):
            text = ax.text(j, i, f'{tensor[i, j].item():.2f}', ha="center", va="center", color="k")

    # 顯示影像
    plt.show()


def test_plot_tensor():
    weight = torch.tensor([[-0.46, -0.40, 0.39, 0.19, 0.37],
                           [0.00, 0.40, 0.17, -0.15, 0.16],
                           [-0.20, -0.23, 0.36, 0.25, 0.03],
                           [0.24, 0.41, 0.07, 0.00, -0.15],
                           [0.48, -0.09, -0.36, 0.12, 0.45]])
    plot_tensor(weight, 'weight')


# 細粒度剪枝方法1
@timing_decorator
def _fine_grained_prune(tensor: torch.Tensor, threshold: float) -> torch.Tensor:
    """
    遍歷矩陣中每個元素,如果元素值小於閾值,則將其設定為0。
    引數太大的話,遍歷會影響到速度,下面將介紹在剪枝中常用的一種方法,即使用mask掩碼矩陣來實現。
    :param tensor: 輸入張量,包含需要剪枝的權重。
    :param threshold: 閾值,用於判斷權重的大小。
    :return: 剪枝後的張量。
    """
    for i in range(tensor.shape[1]):
        for j in range(tensor.shape[0]):
            if tensor[i, j] < threshold:
                tensor[i][j] = 0
    return tensor


# 細粒度剪枝方法2
@timing_decorator
def fine_grained_prune(tensor: torch.Tensor, threshold: float) -> torch.Tensor:
    """
    建立一個掩碼張量,指示哪些權重不應被剪枝(應保持非零)。
    :param tensor: 輸入張量,待剪枝的權重。
    :param threshold: 閾值,用於判斷權重的大小。
    :return: 剪枝後的張量。
    """
    mask = torch.gt(tensor, threshold)
    tensor.mul_(mask)
    return tensor


if __name__ == '__main__':
    # 建立一個矩陣weight
    weight = torch.rand(8, 8)
    plot_tensor(weight, '剪枝前weight')
    pruned_weight1 = _fine_grained_prune(weight, 0.5)
    plot_tensor(weight, '細粒度剪枝後weight1')
    pruned_weight2 = fine_grained_prune(weight, 0.5)
    plot_tensor(pruned_weight2, '細粒度剪枝後weight2')

在掩碼剪枝中,一旦生成了掩碼矩陣(通常是一個與權重矩陣同形狀的二進位制矩陣),你可以直接使用掩碼與權重進行元素級別的運算,而無需再遍歷整個矩陣。

這使得剪枝的過程可以透過向量化操作來加速,尤其是在使用 GPU 時,向量化和矩陣操作比逐元素遍歷更高效。

基於模式的剪枝

import torch
import matplotlib.pyplot as plt
from itertools import permutations

plt.rcParams['font.sans-serif'] = ['SimHei']  # 解決中文亂碼


# 建立一個視覺化2維矩陣函式,將值為0的元素與其他區分開(用於顯示剪枝效果)
def plot_tensor(tensor, title):
    # 建立一個新的影像和軸
    fig, ax = plt.subplots()

    # 使用 CPU 上的資料,轉換為 numpy 陣列,並檢查相等條件,設定顏色對映
    ax.imshow(tensor.cpu().numpy() == 0, vmin=0, vmax=1, cmap='tab20c')
    ax.set_title(title)
    ax.set_yticklabels([])
    ax.set_xticklabels([])

    # 遍歷矩陣中的每個元素並新增文字標籤
    for i in range(tensor.shape[1]):
        for j in range(tensor.shape[0]):
            text = ax.text(j, i, f'{tensor[i, j].item():.2f}', ha="center", va="center", color="k")

    # 顯示影像
    plt.show()


def reshape_1d(tensor, m):
    # 轉換成列為m的格式,若不能整除m則填充0
    if tensor.shape[1] % m > 0:
        mat = torch.FloatTensor(tensor.shape[0], tensor.shape[1] + (m - tensor.shape[1] % m)).fill_(0)
        mat[:, : tensor.shape[1]] = tensor
        return mat.view(-1, m)
    else:
        return tensor.view(-1, m)


def compute_valid_1d_patterns(m, n):
    patterns = torch.zeros(m)
    patterns[:n] = 1
    valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
    return valid_patterns


def compute_mask(tensor, m, n):
    # tensor={tensor(8,8)}
    # 計算所有可能的模式  patterns={tensor(6,4)}
    patterns = compute_valid_1d_patterns(m, n)
    # 找到m:n最好的模式
    # mask={tensor(16,4)}
    mask = torch.IntTensor(tensor.shape).fill_(1).view(-1, m)  # 使用 -1 讓 PyTorch 自動推導某一維的大小
    # mat={tensor(16,4)}
    mat = reshape_1d(tensor, m)
    # pmax={tensor(16,)} 16x4 4x6 = 16x6 -> argmax = 16
    pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
    mask[:] = patterns[pmax[:]]  # 選取最好的模式
    mask = mask.view(tensor.shape)  # 得到8x8掩碼矩陣
    return mask


def pattern_pruning(tensor, m, n):
    mask = compute_mask(weight, m, n)
    tensor.mul_(mask)
    return tensor


if __name__ == '__main__':
    # 建立一個矩陣weight
    weight = torch.rand(8, 8)
    plot_tensor(weight, '剪枝前weight')
    pruned_weight = pattern_pruning(weight, 4, 2)
    plot_tensor(pruned_weight, '剪枝後weight')

基於模式的剪枝(Pattern-based Pruning) 是一種透過預定義的模式來決定剪枝的權重的剪枝方法。在這種方法中,剪枝不再是基於單個權重的大小或者梯度,而是基於一組預定義的剪枝模式,模式決定了哪些權重需要被剪枝,哪些需要保留。

1. 概念解釋

NVIDIA 4:2 剪枝 為例,假設我們有一個由 4 個權重組成的單元(例如,4 個過濾器、4 個神經元等),我們選擇其中 2 個權重進行剪枝,也就是說,將 2 個權重置為 0,而保留剩餘的 2 個權重。

  • 模式(Pattern):我們可以定義 6 種可能的剪枝模式,表示從 4 個權重中選擇 2 個權重為 0 的方式。例如,如果我們用 1 表示保留的權重,用 0 表示被剪枝的權重,那麼 6 種可能的模式如下:
    • 1100
    • 1010
    • 1001
    • 0110
    • 0101
    • 0011

每一種模式都表示剪枝過程中保留的權重和被剪枝的權重的組合。

2. 權重矩陣轉換與模式匹配

為了應用這些剪枝模式,我們首先需要將權重矩陣變換為一個適合進行模式匹配的格式:

  1. 將權重矩陣變換為 nx4 形狀:假設原始的權重矩陣是一個 n x 4 的矩陣,其中 n 表示樣本數量或特徵維度,而 4 表示每個樣本的 4 個權重。

  2. 應用模式:為了與預定義的 6 種模式進行匹配,我們需要計算每個樣本在這 4 個權重中符合哪一種模式。計算的結果是一個 n x 6 的矩陣,表示每個樣本與每種模式的匹配程度(例如,可以是權重的總和、或者其他一些指標,如均值、方差等)。

  3. 選擇最佳模式:對於每個樣本,我們透過 argmax 操作,在 n 維度上選擇最大值的索引,表示該樣本與某一種模式最匹配。得到的索引對應於 6 種模式之一。

  4. 構建掩碼(Mask)矩陣:最後,根據選擇的模式索引,我們將這些索引對映到對應的模式上,構建一個掩碼矩陣。該掩碼矩陣會告訴我們哪些權重應該被保留,哪些應該被剪枝。

3. 詳細步驟解釋

讓我們透過一個具體的例子來詳細理解這個過程:

假設我們有一個 n x 4 的權重矩陣 W,每行是一個 4 維的權重向量:

W = [
    [0.5, 0.2, 0.3, 0.8],  # 第一個樣本的4個權重
    [0.4, 0.1, 0.7, 0.6],  # 第二個樣本的4個權重
    [0.6, 0.5, 0.4, 0.3]   # 第三個樣本的4個權重
]

然後,我們定義了 6 種剪枝模式,如下:

Pattern 1: 1100 (保留第 1 和第 2 個權重)
Pattern 2: 1010 (保留第 1 和第 3 個權重)
Pattern 3: 1001 (保留第 1 和第 4 個權重)
Pattern 4: 0110 (保留第 2 和第 3 個權重)
Pattern 5: 0101 (保留第 2 和第 4 個權重)
Pattern 6: 0011 (保留第 3 和第 4 個權重)
  1. 計算與模式匹配:我們可以透過計算每個樣本在 4 個權重中的值與每種模式的相似性來得出一個 n x 6 的矩陣。例如,計算每個樣本的權重和每種模式的匹配度,可能採用簡單的加和或者其他複雜的指標。

    假設我們對每種模式計算權重的總和,結果如下:

    match_matrix = [
        [1.0, 0.8, 0.7, 1.0, 0.9, 0.6],  # 第一個樣本與每個模式的匹配度
        [0.9, 0.7, 1.1, 0.9, 1.2, 0.5],  # 第二個樣本與每個模式的匹配度
        [1.1, 1.0, 0.9, 1.0, 1.0, 1.1]   # 第三個樣本與每個模式的匹配度
    ]
    
  2. 選擇最佳模式:透過對 match_matrix 進行 argmax 操作,我們可以選擇每個樣本與哪一種模式最匹配:

    best_pattern_indices = [0, 4, 5]  # 對應樣本 1 最匹配模式 1,樣本 2 最匹配模式 5,樣本 3 最匹配模式 6
    
  3. 填充掩碼(Mask)矩陣:根據每個樣本選擇的模式,我們填充掩碼矩陣。例如,樣本 1 選擇了模式 1(即 1100),樣本 2 選擇了模式 5(即 0101),樣本 3 選擇了模式 6(即 0011)。

    最終得到的掩碼矩陣 mask 就是:

    mask = [
        [1, 1, 0, 0],  # 樣本 1 對應模式 1
        [0, 1, 0, 1],  # 樣本 2 對應模式 5
        [0, 0, 1, 1]   # 樣本 3 對應模式 6
    ]
    
  4. 應用掩碼到權重矩陣:將這個掩碼矩陣與權重矩陣進行逐元素相乘,就完成了剪枝操作。

4. 總結

基於模式的剪枝透過以下步驟提升了效率:

  1. 預定義模式:定義剪枝模式,而不是針對每個權重進行逐一選擇。
  2. 模式匹配:透過計算每個樣本與模式的匹配度,並選擇最佳匹配的模式。
  3. 掩碼應用:透過掩碼矩陣直接將剪枝資訊應用到權重矩陣中,避免了頻繁的元素遍歷和修改操作。

相比於逐個權重剪枝,基於模式的剪枝能夠更高效地處理剪枝任務,特別是在大規模的模型中。

向量級別剪枝

import torch
import matplotlib.pyplot as plt
from itertools import permutations

plt.rcParams['font.sans-serif'] = ['SimHei']  # 解決中文亂碼


# 建立一個視覺化2維矩陣函式,將值為0的元素與其他區分開(用於顯示剪枝效果)
def plot_tensor(tensor, title):
    # 建立一個新的影像和軸
    fig, ax = plt.subplots()

    # 使用 CPU 上的資料,轉換為 numpy 陣列,並檢查相等條件,設定顏色對映
    ax.imshow(tensor.cpu().numpy() == 0, vmin=0, vmax=1, cmap='tab20c')
    ax.set_title(title)
    ax.set_yticklabels([])
    ax.set_xticklabels([])

    # 遍歷矩陣中的每個元素並新增文字標籤
    for i in range(tensor.shape[1]):
        for j in range(tensor.shape[0]):
            text = ax.text(j, i, f'{tensor[i, j].item():.2f}', ha="center", va="center", color="k")

    # 顯示影像
    plt.show()
# 剪枝某個點所在的行與列
def vector_pruning(weight, point):
    row, col = point
    prune_weight = weight.clone()
    prune_weight[row, :] = 0
    prune_weight[:, col] = 0
    return prune_weight
if __name__ == '__main__':
    weight = torch.rand(8, 8)
    point = (1, 1)
    prune_weight = vector_pruning(weight, point)
    plot_tensor(prune_weight, '向量級剪枝後weight')

卷積核級別剪枝

tensor = torch.rand((3, 10, 4, 5))  # 3 batch size, 10 channels, 4 height, 5 width
image-20241113131650800

10個通道則1個過濾器有10個卷積核。

image-20241113132059624

紅色的部分代表從中去掉一個卷積核。

import torch
import matplotlib.pyplot as plt
from itertools import permutations

plt.rcParams['font.sans-serif'] = ['SimHei']  # 解決中文亂碼


# 定義視覺化4維張量的函式
def visualize_tensor(tensor, title, batch_spacing=3):
    fig = plt.figure()  # 建立一個新的matplotlib圖形
    ax = fig.add_subplot(111, projection='3d')  # 向圖形中新增一個3D子圖

    # 遍歷張量的批次維度
    for batch in range(tensor.shape[0]):
        # 遍歷張量的通道維度
        for channel in range(tensor.shape[1]):
            # 遍歷張量的高度維度
            for i in range(tensor.shape[2]):
                # 遍歷張量的寬度維度
                for j in range(tensor.shape[3]):
                    # 計算條形的x位置,考慮到不同批次間的間隔
                    x = j + (batch * (tensor.shape[3] + batch_spacing))
                    y = i  # 條形的y位置,即張量的高度維度
                    z = channel  # 條形的z位置,即張量的通道維度
                    # 如果張量在當前位置的值為0,則設定條形顏色為紅色,否則為綠色
                    color = 'red' if tensor[batch, channel, i, j] == 0 else 'green'
                    # 繪製單個3D條形
                    ax.bar3d(x, y, z, 1, 1, 1, shade=True, color=color, edgecolor='black', alpha=0.9)

    ax.set_title(title)  # 設定3D圖形的標題
    ax.set_xlabel('Width')  # 設定x軸標籤,對應張量的寬度維度
    ax.set_ylabel('Height')  # 設定y軸標籤,對應張量的高度維度
    ax.set_zlabel('Channel')  # 設定z軸標籤,對於張量的通道維度
    ax.set_zlim(ax.get_zlim()[::-1])  # 反轉z軸方向
    ax.zaxis.labelpad = 15  # 調整z軸標籤的填充

    plt.show()  # 顯示圖形


def prune_conv_layer(conv_layer, title, percentile=0.2, ):
    prune_layer = conv_layer.clone()

    # 計算每個kernel的L2範數
    l2_norm = torch.norm(prune_layer, p=2, dim=(-2, -1), keepdim=True)
    threshold = torch.quantile(l2_norm, percentile)
    mask = l2_norm > threshold
    prune_layer = prune_layer * mask.float()

    visualize_tensor(prune_layer, title=title)


if __name__ == '__main__':
    # 使用PyTorch建立一個張量
    tensor = torch.rand((3, 10, 4, 5))  # 3 batch size, 10 channels, 4 height, 5 width
    # 呼叫函式進行剪枝
    pruned_tensor = prune_conv_layer(tensor, 'Kernel級別剪枝')

過濾器級別剪枝

image-20241113132441778

相當於這一組卷積核的結果都不要了。

import torch
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # 解決中文亂碼


# 定義視覺化4維張量的函式
def visualize_tensor(tensor, title, batch_spacing=3):
    fig = plt.figure()  # 建立一個新的matplotlib圖形
    ax = fig.add_subplot(111, projection='3d')  # 向圖形中新增一個3D子圖

    # 遍歷張量的批次維度
    for batch in range(tensor.shape[0]):
        # 遍歷張量的通道維度
        for channel in range(tensor.shape[1]):
            # 遍歷張量的高度維度
            for i in range(tensor.shape[2]):
                # 遍歷張量的寬度維度
                for j in range(tensor.shape[3]):
                    # 計算條形的x位置,考慮到不同批次間的間隔
                    x = j + (batch * (tensor.shape[3] + batch_spacing))
                    y = i  # 條形的y位置,即張量的高度維度
                    z = channel  # 條形的z位置,即張量的通道維度
                    # 如果張量在當前位置的值為0,則設定條形顏色為紅色,否則為綠色
                    color = 'red' if tensor[batch, channel, i, j] == 0 else 'green'
                    # 繪製單個3D條形
                    ax.bar3d(x, y, z, 1, 1, 1, shade=True, color=color, edgecolor='black', alpha=0.9)

    ax.set_title(title)  # 設定3D圖形的標題
    ax.set_xlabel('Width')  # 設定x軸標籤,對應張量的寬度維度
    ax.set_ylabel('Height')  # 設定y軸標籤,對應張量的高度維度
    ax.set_zlabel('Channel')  # 設定z軸標籤,對於張量的通道維度
    ax.set_zlim(ax.get_zlim()[::-1])  # 反轉z軸方向
    ax.zaxis.labelpad = 15  # 調整z軸標籤的填充

    plt.show()  # 顯示圖形


def prune_conv_layer(conv_layer, prune_method, title="", percentile=0.2, vis=True):
    prune_layer = conv_layer.clone()

    l2_norm = None
    mask = None

    # 計算每個Filter的L2範數
    l2_norm = torch.norm(prune_layer, p=2, dim=(1, 2, 3), keepdim=True)
    threshold = torch.quantile(l2_norm, percentile)
    mask = l2_norm > threshold
    prune_layer = prune_layer * mask.float()

    visualize_tensor(prune_layer, title=prune_method)

if __name__ == '__main__':
    # 使用PyTorch建立一個張量
    tensor = torch.rand((3, 10, 4, 5))

    # 呼叫函式進行剪枝

    pruned_tensor = prune_conv_layer(tensor, 'Filter級別剪枝', vis=True)

通道級別剪枝

image-20241113132703072
import torch
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # 解決中文亂碼


# 定義視覺化4維張量的函式
def visualize_tensor(tensor, title, batch_spacing=3):
    fig = plt.figure()  # 建立一個新的matplotlib圖形
    ax = fig.add_subplot(111, projection='3d')  # 向圖形中新增一個3D子圖

    # 遍歷張量的批次維度
    for batch in range(tensor.shape[0]):
        # 遍歷張量的通道維度
        for channel in range(tensor.shape[1]):
            # 遍歷張量的高度維度
            for i in range(tensor.shape[2]):
                # 遍歷張量的寬度維度
                for j in range(tensor.shape[3]):
                    # 計算條形的x位置,考慮到不同批次間的間隔
                    x = j + (batch * (tensor.shape[3] + batch_spacing))
                    y = i  # 條形的y位置,即張量的高度維度
                    z = channel  # 條形的z位置,即張量的通道維度
                    # 如果張量在當前位置的值為0,則設定條形顏色為紅色,否則為綠色
                    color = 'red' if tensor[batch, channel, i, j] == 0 else 'green'
                    # 繪製單個3D條形
                    ax.bar3d(x, y, z, 1, 1, 1, shade=True, color=color, edgecolor='black', alpha=0.9)

    ax.set_title(title)  # 設定3D圖形的標題
    ax.set_xlabel('Width')  # 設定x軸標籤,對應張量的寬度維度
    ax.set_ylabel('Height')  # 設定y軸標籤,對應張量的高度維度
    ax.set_zlabel('Channel')  # 設定z軸標籤,對於張量的通道維度
    ax.set_zlim(ax.get_zlim()[::-1])  # 反轉z軸方向
    ax.zaxis.labelpad = 15  # 調整z軸標籤的填充

    plt.show()  # 顯示圖形


def prune_conv_layer(conv_layer, prune_method, title="", percentile=0.2, vis=True):
    prune_layer = conv_layer.clone()

    l2_norm = None
    mask = None

    # 計算每個channel的L2範數
    l2_norm = torch.norm(prune_layer, p=2, dim=(0, 2, 3), keepdim=True)
    threshold = torch.quantile(l2_norm, percentile)
    mask = l2_norm > threshold
    prune_layer = prune_layer * mask.float()

    visualize_tensor(prune_layer, title=prune_method)


# 使用PyTorch建立一個張量
tensor = torch.rand((3, 10, 4, 5))

# 呼叫函式進行剪枝

pruned_tensor = prune_conv_layer(tensor, 'Channel級別剪枝', vis=True)

所有級別剪枝對比:

import torch
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # 解決中文亂碼


# 建立一個視覺化2維矩陣函式,將值為0的元素與其他區分開(用於顯示剪枝效果)
def plot_tensor(tensor, title):
    # 建立一個新的影像和軸
    fig, ax = plt.subplots()

    # 使用 CPU 上的資料,轉換為 numpy 陣列,並檢查相等條件,設定顏色對映
    ax.imshow(tensor.cpu().numpy() == 0, vmin=0, vmax=1, cmap='tab20c')
    ax.set_title(title)
    ax.set_yticklabels([])
    ax.set_xticklabels([])

    # 遍歷矩陣中的每個元素並新增文字標籤
    for i in range(tensor.shape[1]):
        for j in range(tensor.shape[0]):
            text = ax.text(j, i, f'{tensor[i, j].item():.2f}', ha="center", va="center", color="k")

    # 顯示影像
    plt.show()


# 剪枝某個點所在的行與列
def vector_pruning(weight, point):
    row, col = point
    prune_weight = weight.clone()
    prune_weight[row, :] = 0
    prune_weight[:, col] = 0
    return prune_weight


if __name__ == '__main__':
    weight = torch.rand(8, 8)
    point = (1, 1)
    prune_weight = vector_pruning(weight, point)
    plot_tensor(prune_weight, '向量級剪枝後weight')

剪枝標準

模型剪枝之所以有效,主要是因為它能夠識別並移除那些對模型效能影響較小的引數,從而減少模型的複雜性和計算成本。

其背後的理論依據主要集中在以下幾個方面:

  • 彩票假說:該假說認為,在隨機初始化的大型神經網路中,存在一個子網路,如果獨立訓練,可以達到與完整網路相似的效能。這表明網路中並非所有部分都對最終效能至關重要,從而為剪枝提供了理論支援。
  • 網路稀疏性:研究發現,許多深度神經網路引數呈現出稀疏性,即大部分引數值接近於零。這種稀疏性啟發了剪枝技術,即透過移除這些非顯著的引數來簡化模型。
  • 剪枝的一個重要理論來源是正則化,特別是L1正則化,它鼓勵網路學習稀疏的引數分佈。稀疏化的模型更容易進行剪枝,因為許多權重接近於零,可以安全移除。
  • 權重的重要性:剪枝演算法通常基於權重的重要性來決定是否剪枝。權重的重要性可以透過多種方式評估,例如權重的大小權重對損失函式的梯度、或者權重對輸入的啟用情況等。

怎麼確定要減掉哪些呢?這就涉及到剪枝標準。

基於權重大小

這種剪枝方法基於一個假設,即權重的絕對值越小,該權重對模型的輸出影響越小,因此移除它們對模型效能的影響也會較小。

image-20241113133840952

這裡也就是計算每個格子中權重的絕對值,絕對值大的保留,小的移除。

L1和L2正則化是機器學習中常用的正則化技術,它們透過在損失函式中新增額外的懲罰項來防止模型過擬合。

L1和L2正則化

深入理解L1、L2正則化 - ZingpLiu - 部落格園

正則化是機器學習中對原始損失函式引入額外資訊,以便防止過擬合和提高模型泛化效能的一類方法的統稱。也就是目標函式變成了原始損失函式+額外項,常用的額外項一般有兩種,英文稱作ℓ1−normℓ1−norm和ℓ2−normℓ2−norm,中文稱作L1正則化L2正則化,或者L1範數和L2範數(實際是L2範數的平方)。

正則化技術(如L1和L2)透過限制模型的權重來控制模型的複雜度,避免模型過擬合。對於一個包含多個特徵的模型,如果所有特徵的權重都很大,說明模型可能對每個特徵都高度依賴,這樣容易在訓練集上過擬合。

我們將L1或L2正則化加入到損失函式中,目的是懲罰那些過大的權重。懲罰項的作用是增加模型訓練時的成本,從而迫使模型儘可能避免使用過大的權重值。

  • 懲罰表示當模型的權重過大時,正則化項會增加損失函式的值,使得模型更傾向於選擇較小的權重。這就像給模型設定了一種懲罰規則,避免它在訓練過程中“過度自信”地依賴某些特徵。

  • 控制複雜度:懲罰項的加入,限制了模型引數的大小,減少了模型對訓練資料的過擬合。

在沒有正則化的情況下,模型僅僅關注最小化預測誤差(即損失函式),它可能會透過對某些特徵賦予很大的權重來達到最小化損失,這會導致過擬合。加入正則化項後,損失函式不僅考慮預測誤差,還會考慮模型的複雜度,這樣就能夠找到一個平衡點,避免模型過度擬合。

L1 正則化

image-20241115103931172

L1正則化的加入項是絕對值之和,這意味著它可以產生稀疏解——有些權重會被壓縮為零,導致對應的特徵完全被剔除。這樣做的好處是,模型變得更加簡潔和可解釋,同時可以進行特徵選擇,僅保留那些最重要的特徵。

L2 正則化

image-20241115103945543
L2正則化傾向於使得權重變小,但不會將權重壓縮為零。它的作用是讓模型更穩定,減少對某些特徵的過度依賴,但不會像L1正則化那樣進行特徵選擇。

L1、L2正則化剪枝

L1和L2正則化基本思想是以行為單位,計算每行的重要性,移除權重中那些重要性較小的行。

L1行剪枝:

image-20241115104213213

L2行剪枝:

image-20241115104246680

LeNet

# 定義一個LeNet網路
class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))

        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
  • 卷積層 (conv1)

    • nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
    • 輸入的影像通道數為 1(灰度影像),輸出 6 個特徵圖,每個特徵圖大小為 28x28(5x5 卷積核,影像尺寸會變小)。
  • 卷積層 (conv2)

    • nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
    • 輸入 6 個特徵圖,輸出 16 個特徵圖。每個特徵圖大小為 10x10(再次進行 5x5 卷積)。
  • 池化層 (maxpool)

    • nn.MaxPool2d(kernel_size=2, stride=2)
    • 2x2 的最大池化操作,步長為 2,這會將每個特徵圖的尺寸縮小一半。
  • 全連線層 (fc1, fc2, fc3)

    • nn.Linear(in_features=16 * 4 * 4, out_features=120)
    • 第一個全連線層,將 16 個 4x4 的特徵圖展平為 1D 向量,輸入 256 個特徵,輸出 120 個神經元。
    • nn.Linear(in_features=120, out_features=84)
    • 第二個全連線層,輸入 120 個神經元,輸出 84 個神經元。
    • nn.Linear(in_features=84, out_features=num_classes)
    • 第三個全連線層,輸出最終的分類結果,這裡 num_classes=10 對應 MNIST 資料集的 10 個數字類別。

forward 方法

  • 該方法定義了模型的前向傳播過程。

  • 第一層卷積和池化

    • x = self.maxpool(F.relu(self.conv1(x)))
    • 對輸入 x 進行卷積(conv1),然後透過 ReLU 啟用函式,再透過最大池化層(maxpool)。
  • 第二層卷積和池化

    • x = self.maxpool(F.relu(self.conv2(x)))
    • 同樣,對卷積(conv2)的輸出進行 ReLU 啟用和池化。
  • 展平

    • x = x.view(x.size()[0], -1)
    • 將經過卷積和池化後的輸出展平為 1D 向量,為進入全連線層做準備。x.size()[0] 表示批次大小,-1 表示自動計算其餘維度。
  • 全連線層

    • x = F.relu(self.fc1(x))
    • x = F.relu(self.fc2(x))
    • x = self.fc3(x)
    • 使用 ReLU 啟用函式處理全連線層的輸出,並最終得到分類結果。

基於L1權重大小的剪枝

@torch.no_grad()
def prune_l1(weight, percentile=0.5):
    # 計算權重個數 2400=16*6*5*5
    num_elements = weight.numel()

    # 計算值為0的數量 num_zeros=200
    num_zeros = round(num_elements * percentile)
    # 計算weight的重要性 tensor{(16,6,5,5)}
    importance = weight.abs()
    # 計算裁剪閾值 tensor(0.0451, device='cuda:0')
    threshold = importance.view(-1).kthvalue(num_zeros).values
    # 計算mask (小於閾值的設定為False,大於閾值的設定為True)
    mask = torch.gt(importance, threshold)

    # 計算mask後的weight
    weight.mul_(mask)
    return weight

這段程式碼是一個 L1 正則化剪枝(pruning) 函式,目的是透過 裁剪 (prune)掉網路中一些不重要的權重,以減小模型的複雜度,通常用於模型壓縮和加速推理過程。

  • @torch.no_grad()
    這個裝飾器告訴 PyTorch 在該函式執行時不計算梯度。即使在該函式內部做了修改(如 weight.mul_(mask)),也不會追蹤這些操作的梯度。這通常用於推理或一些不需要梯度計算的操作,避免額外的記憶體開銷。

引數

  • weight
    這是模型某層的權重張量(tensor),通常是一個二維張量,對應於卷積層或全連線層的權重矩陣。

  • percentile
    這是一個介於 0 到 1 之間的浮動值,表示要裁剪掉的權重的比例。例如,percentile=0.5 表示剪掉最小的一半權重。

詳細步驟

  1. 計算權重的元素數量
    num_elements = weight.numel()
    這行程式碼計算 weight 張量中元素的總數量(即權重的個數)。

  2. 計算需要剪去的權重數量
    num_zeros = round(num_elements * percentile)
    這裡計算需要剪去的權重數量。percentile 決定了要剪去的權重佔比,num_zeros 是該佔比對應的權重數量。

  3. 計算權重的“重要性”
    importance = weight.abs()
    這一步透過對權重取 絕對值 來衡量其“重要性”。一般來說,L1 範數(絕對值)越小的權重,對模型的影響越小,因此可以認為它們較不重要。

  4. 計算裁剪的閾值
    threshold = importance.view(-1).kthvalue(num_zeros).values
    importance 展平為一維向量(view(-1)),然後透過 kthvalue 函式找到第 num_zeros 小的值。這個值即為裁剪閾值,表示剪去比這個值小的權重。

  5. 計算掩碼(Mask)
    mask = torch.gt(importance, threshold)
    這行程式碼生成一個布林值的掩碼(mask),其中 True 表示該權重的重要性大於閾值,False 表示該權重的重要性小於閾值。torch.gt 是“大於”的意思。

  6. 應用掩碼進行剪枝
    weight.mul_(mask)
    使用 mask 來篩選權重,True 的位置保持原值,False 的位置會被設為零。mul_ 是對 weight 進行原地(in-place)乘法操作,即在原始權重張量上直接進行修改。

  7. 返回剪枝後的權重
    return weight
    最終返回經過剪枝後的權重。

總結

這個函式的核心思路是:

  1. 計算每個權重的“重要性”,透過其絕對值(L1 範數)衡量。
  2. 根據設定的 percentile 引數,裁剪掉最不重要的權重。
  3. 使用一個布林掩碼(mask)將不重要的權重置為零,從而實現模型的稀疏化。

剪枝後分布:

image-20241115152630415 image-20241115152614186
  • x 軸代表 權重值的大小,表示模型中每個權重引數的數值範圍。
  • y 軸表示 權重值的密度(density),即單位區間內權重的數量。

減少了一半權重引數:

image-20241115153138672

基於L2權重大小的剪枝

@torch.no_grad()
def prune_l2(weight, percentile=0.5):
    num_elements = weight.numel()

    # 計算值為0的數量
    num_zeros = round(num_elements * percentile)
    # 計算weight的重要性(使用L2範數,即各元素的平方)
    importance = weight.pow(2)  # 這裡和上面不同
    # 計算裁剪閾值
    threshold = importance.view(-1).kthvalue(num_zeros).values
    # 計算mask
    mask = torch.gt(importance, threshold)
    
    # 計算mask後的weight
    weight.mul_(mask)
    return weight

# 裁剪fc1層(全連線)
weight_pruned = prune_l2(model.fc1.weight, percentile=0.4)  # 裁剪40%
# 替換原有model層
model.fc1.weight.data = weight_pruned
# 列出weight直方圖
plot_weight_distribution(model)

裁剪後分布 :

image-20241115154244048 image-20241115154222146

減少了40%引數:

image-20241115154741661

基於梯度大小

核心思想:在模型訓練過程中,權重的梯度反映了權重對輸出損失的影響程度,較大的梯度表示權重對輸出損失的影響較大,因此較重要;較小的梯度表示權重對輸出損失的影響較小,因此較不重要。透過去除較小梯度的權重,可以減少模型的規模,同時保持模型的準確性。

對比以權值大小為重要性依據的剪枝演算法:以人臉識別為例,在人臉的諸多特徵中,眼睛的細微變化如顏色、大小、形狀,對於人臉識別的結果有很大影響。對應到深度網路中的權值,即使權值本身很小,但是它的細微變化對結果也將產生很大的影響,這類權值是不應該被剪掉的。梯度是計算損失函式對權值的偏導數,反映了損失對權值的敏感程度。基於梯度大小的剪枝演算法是一種透過分析模型中權重梯度的方法,來判斷權重的重要性,並去除較小梯度的權重的剪裁方法。

import copy
import math
import random
import time

import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import torch.nn.functional as F

# 設定 matplotlib 使用支援負號的字型
plt.rcParams['font.family'] = 'DejaVu Sans'


# 繪製權重分佈圖
def plot_weight_distribution(model, bins=256, count_nonzero_only=False):
    fig, axes = plt.subplots(2, 3, figsize=(10, 6))

    # 刪除多餘的子圖
    fig.delaxes(axes[1][2])

    axes = axes.ravel()
    plot_index = 0
    for name, param in model.named_parameters():
        if param.dim() > 1:
            ax = axes[plot_index]
            if count_nonzero_only:
                param_cpu = param.detach().view(-1).cpu()
                param_cpu = param_cpu[param_cpu != 0].view(-1)
                ax.hist(param_cpu, bins=bins, density=True,
                        color='green', alpha=0.5)
            else:
                ax.hist(param.detach().view(-1).cpu(), bins=bins, density=True,
                        color='green', alpha=0.5)
            ax.set_xlabel(name)
            ax.set_ylabel('density')
            plot_index += 1
    fig.suptitle('Histogram of Weights')
    fig.tight_layout()
    fig.subplots_adjust(top=0.925)
    plt.show()


# 為避免前面的操作影響後續結果,重新定義一個LeNet網路,和前面一致
class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))

        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device)

# 載入梯度資訊
gradients = torch.load('./model_gradients.pt')
# 載入引數資訊
checkpoint = torch.load('./model.pt')
# 載入狀態字典到模型
model.load_state_dict(checkpoint)


# 修剪整個模型的權重,傳入整個模型
def gradient_magnitude_pruning(model, percentile):
    for name, param in model.named_parameters():
        if 'weight' in name:
            # 當梯度的絕對值大於或等於這個閾值時,權重會被保留。
            mask = torch.abs(gradients[name]) >= percentile
            param.data *= mask.float()


# 修剪區域性模型權重,傳入某一層的權重
@torch.no_grad()
def gradient_magnitude_pruning(weight, gradient, percentile=0.5):
    num_elements = weight.numel()
    # 計算值為0的數量
    num_zeros = round(num_elements * percentile)
    # 計算weight的重要性(使用L1範數)
    importance = gradient.abs()
    # 計算裁剪閾值
    threshold = importance.view(-1).kthvalue(num_zeros).values
    # 計算mask
    mask = torch.gt(importance, threshold)
    # 確保mask和weight在同一裝置上
    mask = mask.to(weight.device)
    # 計算mask後的weight
    weight.mul_(mask)
    return weight


if __name__ == '__main__':
    # 使用示例,這裡以fc2層的權重為例
    percentile = 0.5
    gradient_magnitude_pruning(model.fc2.weight, gradients['fc2.weight'], percentile)
    # 列出weight直方圖
    plot_weight_distribution(model)
image-20241115160423227 image-20241115161534838

基於尺度

通俗理解 Batch Normalization(含程式碼) - 知乎

Network Slimming提出了一種基於尺度(Scaling-based)的剪枝方法。這種方法:剪枝整個通道
識別並剪枝那些對模型輸出影響不大的整個通道(即一組特徵對映),而不是單個權重。

在標準的CNN訓練中,批歸一化(BN)層通常用於加速訓練並提高模型的泛化能力。該方法利用BN層中的縮放因子(γ)來實現稀疏性。這些縮放因子原本用於調節BN層輸出的尺度,但在該方法中,它們被用來指示每個通道的重要性。在訓練過程中,透過在損失函式中新增一個L1正則化項來鼓勵通道的縮放因子趨向於零。這樣,不重要的通道的縮放因子將變得非常小,從而可以被識別並剪枝。

基於二階

基於二階(Second-Order-based)的剪枝方法中最具代表性的是最優腦損傷(Optimal Brain Damage,OBD)。OBD透過最小化由於剪枝突觸引入的損失函式誤差,利用二階導數資訊來評估網路中每個權重的重要性,然後根據這些評估結果來決定哪些權重可以被剪枝。

​ 首先,計算網路損失函式相對於權重的Hessian矩陣。Hessian矩陣是一個方陣,其元素是損失函式相對於網路引數的二階偏導數。它提供了關於引數空間中曲線曲率的資訊,可以用來判斷權重的敏感度。其次,透過分析Hessian矩陣的特徵值,可以確定網路引數的重要性。通常,與較大特徵值相對應的權重被認為是更重要的,因為它們對損失函式的曲率貢獻更大。

image-20241115110319249

從最後的公式可以看出,OBD方法最後只需要考慮矩陣對角線元素,詳細的公式推導過程參考OBD公式推導

剪枝頻率

迭代剪枝

迭代剪枝是一種漸進式的模型剪枝方法,它涉及多個迴圈的剪枝和微調步驟。這個過程逐步削減模型中的權重,而不是一次性剪除大量的權重。迭代剪枝的基本思想是,透過逐步移除權重,可以更細緻地評估每一次剪枝對模型效能的影響,並允許模型有機會調整其餘權重來補償被剪除的權重

迭代剪枝通常遵循以下步驟:

  • 訓練模型:首先訓練一個完整的、未剪枝的模型,使其在訓練資料上達到一個良好的效能水平。
  • 剪枝:使用一個預定的剪枝策略(例如基於權重大小)來輕微剪枝網路,移除一小部分權重。
  • 微調:對剪枝後的模型進行微調,這通常涉及使用原始訓練資料集重新訓練模型,以恢復由於剪枝引起的效能損失。
  • 評估:在驗證集上評估剪枝後模型的效能,確保模型仍然能夠維持良好的效能。
  • 重複:重複步驟2到步驟4,每次迭代剪掉更多的權重,並進行微調,直到達到一個預定的效能標準或剪枝比例。

單次剪枝

  • 定義:在訓練完成後對模型進行一次性的剪枝操作。
  • 優點:這種剪枝方法的特點是高效且直接,它不需要在剪枝和再訓練之間進行多次迭代。
  • 步驟:在One-shot剪枝中,模型首先被訓練到收斂,然後根據某種剪枝標準(如權重的絕對值大小)來確定哪些引數可以被移除。這些引數通常是那些對模型輸出影響較小的引數。
  • 對比迭代式剪枝:單次剪枝會極大地受到噪聲的影響,而迭代式剪枝方法則會好很多,因為它在每次迭代之後只會刪除掉少量的權重,然後週而復始地進行其他輪的評估和刪除,這就能夠在一定程度上減少噪聲對於整個剪枝過程的影響。但對於大模型來說,由於微調的成本太高,所以更傾向於使用單次剪枝方法。

剪枝時機

訓練後剪枝

訓練後剪枝基本思想是先訓練一個模型 ,然後對模型進行剪枝,最後對剪枝後模型進行微調。其核心思想是對模型進行一次訓練,以瞭解哪些神經連線實際上很重要,修剪那些不重要(權重較低)的神經連線,然後再次訓練以瞭解權重的最終值。以下是詳細步驟:

  • 初始訓練:首先,使用標準的反向傳播演算法訓練神經網路。在這個過程中,網路學習到權重(即連線的強度)和網路結構。
  • 識別重要連線:在訓練完成後,網路已經學習到了哪些連線對模型的輸出有顯著影響。通常,權重較大的連線被認為是重要的。
  • 設定閾值:選擇一個閾值,這個閾值用於確定哪些連線是重要的。所有權重低於這個閾值的連線將被視為不重要。
  • 剪枝:移除所有權重低於閾值的連線。這通常涉及到將全連線層轉換為稀疏層,因為大部分連線都被移除了。
  • 重新訓練:在剪枝後,網路的容量減小了,為了補償這種變化,需要重新訓練網路。在這個過程中,網路會調整剩餘連線的權重,以便在保持準確性的同時適應新的結構。
  • 迭代剪枝:剪枝和重新訓練的過程可以迭代進行。每次迭代都會移除更多的連線,直到達到一個平衡點,即在不顯著損失準確性的情況下儘可能減少連線。

訓練時剪枝

訓練時剪枝基本思想是直接在模型訓練過程中進行剪枝,最後對剪枝後模型進行微調。與訓練後剪枝相比,連線在訓練期間根據其重要性動態停用,但允許權重適應並可能重新啟用。訓練時剪枝可以產生更有效的模型,因為不必要的連線會盡早修剪,從而可能減少訓練期間的記憶體和計算需求。然而,它需要小心處理,以避免網路結構的突然變化和過度修剪的風險,這可能會損害效能。深度學習中常用到的Dropout其實就是一種訓練時剪枝方法,在訓練過程中,隨機神經元以一定的機率被“dropout”或設定為零。訓練時剪枝的訓練過程包括以下幾個詳細步驟,以CNN網路為例:

  • 初始化模型引數:首先,使用標準的初始化方法初始化神經網路的權重。
  • 訓練迴圈:在每個訓練週期(epoch)開始時,使用完整的模型引數對訓練資料進行前向傳播和反向傳播,以更新模型權重。
  • 計算重要性:在每個訓練週期結束時,計算每個卷積層中所有過濾器的重要性。
  • 選擇過濾器進行修剪:根據一個預先設定的修剪率,選擇重要性最小的過濾器進行修剪。這些過濾器被認為是不重要的,因為它們對模型輸出的貢獻較小。
  • 修剪過濾器:將選擇的過濾器的權重設定為零,從而在後續的前向傳播中不計算這些過濾器的貢獻。
  • 重建模型:在修剪過濾器之後,繼續進行一個訓練週期。在這個階段,透過反向傳播,允許之前被修剪的過濾器的權重更新,從而恢復模型的容量。
  • 迭代過程:重複上述步驟,直到達到預定的訓練週期數或者模型收斂。

訓練前剪枝

訓練前剪枝基本思想是在模型訓練前進行剪枝,然後從頭訓練剪枝後的模型。這裡就要提及到彩票假設,即任何隨機初始化的稠密的前饋網路都包含具有如下性質的子網路——在獨立進行訓練時,初始化後的子網路在至多經過與原始網路相同的迭代次數後,能夠達到跟原始網路相近的測試準確率。在彩票假設中,剪枝後的網路不是需要進行微調,而是將“中獎”的子網路重置為網路最初的權重後重新訓練,最後得到的結果可以追上甚至超過原始的稠密網路。總結成一句話:隨機初始化的密集神經網路包含一個子網路,該子網路經過初始化,以便在單獨訓練時,在訓練最多相同次數的迭代後,它可以與原始網路的測試精度相匹配。

一開始,神經網路是使用預定義的架構和隨機初始化的權重建立的。這構成了剪枝的起點。基於某些標準或啟發法,確定特定的連線或權重以進行修剪。那麼有個問題,我們還沒有開始訓練模型,那麼我們如何知道哪些連線不重要呢?

目前常用的方式一般是在初始化階段採用隨機剪枝的方法。隨機選擇的連線被修剪,並且該過程重複多次以建立各種稀疏網路架構。這背後的想法是,如果在訓練之前以多種方式進行修剪,可能就能夠跳過尋找彩票的過程。

剪枝時機總結

訓練後剪枝(靜態稀疏性): 初始訓練階段後的修剪涉及在單獨的後處理步驟中從訓練模型中刪除連線或過濾器。這使得模型能夠在訓練過程中完全收斂而不會出現任何中斷,從而確保學習到的表示得到很好的建立。剪枝後,可以進一步微調模型,以從剪枝過程引起的任何潛在效能下降中恢復過來。訓練後的剪枝一般比較穩定,不太可能造成過擬合。適用於針對特定任務微調預訓練模型的場景。

訓練時剪枝(動態稀疏): 在這種方法中,剪枝作為附加正則化技術整合到最佳化過程中。在訓練迭代期間,根據某些標準或啟發方法動態刪除或修剪不太重要的連線。這使得模型能夠探索不同級別的稀疏性並在整個訓練過程中調整其架構。動態稀疏性可以帶來更高效的模型,因為不重要的連線會被儘早修剪,從而可能減少記憶體和計算需求。然而,它需要小心處理,以避免網路結構的突然變化和過度修剪的風險,這可能會損害效能。

訓練前剪枝: 訓練前剪枝涉及在訓練過程開始之前從神經網路中剪枝某些連線或權重。優點在於可以更快地進行訓練,因為初始模型大小減小了,並且網路可以更快地收斂。然而,它需要仔細選擇修剪標準,以避免過於積極地刪除重要連線。

剪枝比例

假設一個模型有很多層,給定一個全域性的剪枝比例,那麼應該怎麼分配每層的剪枝率呢?主要可以分為兩種方法:均勻分層剪枝和非均勻分層剪枝。

  • 均勻分層剪枝(Uniform Layer-Wise Pruning)是指在神經網路的每一層中都應用相同的剪枝率。具體來說,就是對網路的所有層按照統一的標準進行剪枝,無論每一層的權重重要性或梯度如何分佈。這種方法實現簡單,剪枝率容易控制,但它忽略了每一層對模型整體效能的重要性差異。
  • 非均勻分層剪枝(Non-Uniform Layer-Wise Pruning)則根據每一層的不同特點來分配不同的剪枝率。例如,可以根據梯度資訊、權重的大小、或者其他指標(如資訊熵、Hessian矩陣等)來確定每一層的剪枝率。層越重要,保留的引數越多;不重要的層則可以被更大程度地剪枝。如下圖3-9所示,非均勻剪枝往往比均勻剪枝的效能更好。

程式碼

  • 剪枝粒度實踐
  • 剪枝標準實踐
  • 剪枝時機實踐
  • torch中的剪枝演算法實踐

相關文章