Pytorch Optimizer類使用小技巧總結
一、固定部分網路層引數
1. 將需要固定,不參與訓練層引數的requires_grad屬性設為False:
# 在nn.Modele子類內固定features層引數
for p in self.features.parameters():
p.requires_grad=False
2. 將參與訓練的層引數傳入Optimizer:
param_to_optim = []
for param in self.model.parameters():
if param.requires_grad == False:
continue
param_to_optim.append(param)
optimizer = torch.optim.SGD(param_to_optim, lr=0.001, momentum=0.9, weight_decay=1e-4)
或者:
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, momentum=0.9, weight_decay=1e-4)
二、自定義學習率衰減
def adjust_learning_rate(args, optimizer, epoch, gamma=0.1):
# 每訓練args.step_size個epochs,學習率衰減到gamma倍
lr = args.lr * (gamma ** (epoch // args.step_size))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
三、一個optimizer內多組引數或多個optimizer
# 一個optimizer內多組引數,可以為不同組引數設定不同的學習率
optimizer = optim.SGD([
{'params': net.conv_block1.parameters(), 'lr': 0.002},
{'params': net.classifier1.parameters(), 'lr': 0.002},
{'params': net.conv_block2.parameters(), 'lr': 0.002},
{'params': net.classifier2.parameters(), 'lr': 0.002},
{'params': net.conv_block3.parameters(), 'lr': 0.002},
{'params': net.classifier3.parameters(), 'lr': 0.002},
{'params': net.features.parameters(), 'lr': 0.0002}
], momentum=0.9, weight_decay=5e-4)
# 定義多個optimizer,訓練網路的不同模組
raw_optimizer = torch.optim.SGD(raw_parameters, lr=LR, momentum=0.9, weight_decay=WD)
concat_optimizer = torch.optim.SGD(concat_parameters, lr=LR, momentum=0.9, weight_decay=WD)
part_optimizer = torch.optim.SGD(part_parameters, lr=LR, momentum=0.9, weight_decay=WD)
partcls_optimizer = torch.optim.SGD(partcls_parameters, lr=LR, momentum=0.9, weight_decay=WD)
待續。。。
相關文章
- 小程式開發技巧總結
- [原始碼解析] PyTorch 分散式(14) --使用 Distributed Autograd 和 Distributed Optimizer原始碼PyTorch分散式
- Vue的使用總結和技巧Vue
- XGBoost類庫使用小結
- PyTorch使用總覽PyTorch
- 前端從業兩年總結的一些js使用小技巧前端JS
- URLConnection類,HttpURLConnection類的使用和總結HTTP
- Mybatis使用小技巧-自定義結果集MyBatis
- gulp技巧總結
- Git 小技巧彙總Git
- R小技巧彙總
- 總結十個Python 字典用法的使用技巧Python
- 有關連結串列的小技巧,我都給你總結好了
- CSS技巧總結2CSS
- 面試技巧總結面試
- Python pyinstaller類庫使用學習總結Python
- Python pycryptodome類庫使用學習總結Python
- Python pymodbus類庫使用學習總結Python
- 填坑總結:python記憶體洩漏排查小技巧Python記憶體
- photoshop使用小技巧
- Windows使用小技巧Windows
- 不定時更新-工具類小技巧
- 使用mpvue開發github小程式總結VueGithub
- [技巧] 做題及考試技巧總結
- Pytorch中stack()方法的總結及理解PyTorch
- phpRedis函式使用總結【分類詳細】PHPRedis函式
- Deep Learning模型中常見的optimizer優化器演算法總結模型優化演算法
- JS傳參技巧總結JS
- Vue 開發技巧總結Vue
- 【個人總結】常用技巧
- VS Code 使用小技巧
- Android studio使用小技巧Android
- Postman 使用小技巧/指南Postman
- SVN小總結
- 小總結(1)
- 小總結吧
- Spring 小總結Spring
- 使用setInterval與clearInterval踩的小坑總結