ptorch常用程式碼梯度篇(梯度裁剪、梯度累積、凍結預訓練層等)

MapleTx發表於2022-05-07

梯度裁剪(Gradient Clipping)

在訓練比較深或者迴圈神經網路模型的過程中,我們有可能發生梯度爆炸的情況,這樣會導致我們模型訓練無法收斂。 我們可以採取一個簡單的策略來避免梯度的爆炸,那就是梯度截斷 Clip, 將梯度約束在某一個區間之內,在訓練的過程中,在優化器更新之前進行梯度截斷操作!!!!! 注意這個方法只在訓練的時候使用,在測試的時候驗證和測試的時候不用。

整個流程簡單總結如下:

  1. 載入訓練資料和標籤
  2. 模型輸入輸出
  3. 計算 loss 函式值
  4. loss 反向傳播
  5. 梯度截斷
  6. 優化器更新梯度引數
import torch.nn as nn
outputs = model(data)
loss= loss_fn(outputs, target)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()
optimizer.zero_grad()

nn.utils.clip_grad_norm_ 輸入是(NN 引數,最大梯度範數,範數型別 = 2) 一般預設為 L2 範數。

梯度累積

常規網路如下:

# 正常網路
optimizer.zero_grad()
for idx, (x, y) in enumerate(train_loader):
    pred = model(x)
    loss = criterion(pred, y)
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    if (idx+1) % eval_steps == 0:
        eval()

需要梯度累計時,每個 mini-batch 仍然正常前向傳播以及反向傳播,但是反向傳播之後並不進行梯度清零,因為 PyTorch 中的 loss.backward() 執行的是梯度累加的操作,所以當我們呼叫 4 次 loss.backward() 後,這 4 個 mini-batch 的梯度都會累加起來。但是,我們需要的是一個平均的梯度,或者說平均的損失,所以我們應該將每次計算得到的 loss除以 accum_steps

# 梯度累積

accum_steps = 4
optimizer.zero_grad()
for idx, (x, y) in enumerate(train_loader):
    pred = model(x)
    loss = criterion(pred, y)
    
    # normlize loss to account for batch accumulation
    loss = loss / accum_steps
    
    loss.backward()
    
    if (idx+1) % accum_steps == 0 or (idx+1) == len(train_loader):
        optimizer.step()
        optimizer.zero_grad()
        if (idx+1) % eval_steps:
            eval()

總的來說,梯度累加就是計算完每個 mini-batch 的梯度後不清零,而是做梯度的累加,當累加到一定的次數之後再更新網路引數,然後將梯度清零。通過這種延遲更新的手段,可以實現與採用大 batch_size 相近的效果

凍結某些層

在載入預訓練模型的時候,我們有時想凍結前面幾層,使其引數在訓練過程中不發生變化。

我們需要先知道每一層的名字,通過如下程式碼列印:

net = Network()  # 獲取自定義網路結構
for name, value in net.named_parameters():
    print('name: {0},\t grad: {1}'.format(name, value.requires_grad))

假設前幾層資訊如下:

name: cnn.VGG_16.convolution1_1.weight,   grad: True
name: cnn.VGG_16.convolution1_1.bias,   grad: True
name: cnn.VGG_16.convolution1_2.weight,   grad: True
name: cnn.VGG_16.convolution1_2.bias,   grad: True
name: cnn.VGG_16.convolution2_1.weight,   grad: True
name: cnn.VGG_16.convolution2_1.bias,   grad: True
name: cnn.VGG_16.convolution2_2.weight,   grad: True
name: cnn.VGG_16.convolution2_2.bias,   grad: True

後面的 True 表示該層的引數可訓練,然後我們定義一個要凍結的層的列表:

no_grad = [
    'cnn.VGG_16.convolution1_1.weight',
    'cnn.VGG_16.convolution1_1.bias',
    'cnn.VGG_16.convolution1_2.weight',
    'cnn.VGG_16.convolution1_2.bias'
]

凍結方法如下:

# net = Net.CTPN()  # 獲取網路結構
net = Network() 
for name, value in net.named_parameters():
    if name in no_grad:
        value.requires_grad = False
    else:
        value.requires_grad = True

凍結後我們再列印每層的資訊:

name: cnn.VGG_16.convolution1_1.weight,   grad: False
name: cnn.VGG_16.convolution1_1.bias,   grad: False
name: cnn.VGG_16.convolution1_2.weight,   grad: False
name: cnn.VGG_16.convolution1_2.bias,   grad: False
name: cnn.VGG_16.convolution2_1.weight,   grad: True
name: cnn.VGG_16.convolution2_1.bias,   grad: True
name: cnn.VGG_16.convolution2_2.weight,   grad: True
name: cnn.VGG_16.convolution2_2.bias,   grad: True

可以看到前兩層的 weight 和 bias 的 requires_grad 都為 False,表示它們不可訓練。

最後在定義優化器時,只對 requires_grad 為 True 的層的引數進行更新。(這裡用filter篩選只傳入了requires_grad為True的引數,但如果直接傳入全部引數應該也可以達到只訓練未凍結層引數的效果)

optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01)

其他注意事項

  1. with torch.no_grad()或者@torch.no_grad()中的資料不需要計算梯度,也不會進行反向傳播。不需要計算梯度的程式碼塊(如驗證測試)用 with torch.no_grad() 包含起來,節省視訊記憶體
model.eval()
with torch.no_grad():
   pass
@torch.no_grad()
def eval():
	...
  1. model.eval() 和 torch.no_grad() 的區別在於,model.eval() 是將網路切換為測試狀態,例如 BN 和dropout在訓練和測試階段使用不同的計算方法。torch.no_grad() 是關閉 PyTorch 張量的自動求導機制,以減少儲存使用和加速計算,得到的結果無法進行 loss.backward()。

  2. model.zero_grad()會把整個模型的引數的梯度都歸零, 而optimizer.zero_grad()只會把傳入其中的引數的梯度歸零.

  3. loss.backward() 前用 optimizer.zero_grad() 清除累積梯度。如果在迴圈裡需要把optimizer.zero_grad()寫在後面,那應該在迴圈外需要先呼叫一次optimizer.zero_grad()

  4. 檢視網路中的梯度

params = list(model.named_parameters())
(name, param) = params[28]
print(name)
print(param.grad)
print('-------------------------------------------------')

相關文章