Pytorch Optimizer類使用小技巧總結
一、固定部分網路層引數
1. 將需要固定,不參與訓練層引數的requires_grad屬性設為False:
# 在nn.Modele子類內固定features層引數
for p in self.features.parameters():
p.requires_grad=False
2. 將參與訓練的層引數傳入Optimizer:
param_to_optim = []
for param in self.model.parameters():
if param.requires_grad == False:
continue
param_to_optim.append(param)
optimizer = torch.optim.SGD(param_to_optim, lr=0.001, momentum=0.9, weight_decay=1e-4)
或者:
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, momentum=0.9, weight_decay=1e-4)
二、自定義學習率衰減
def adjust_learning_rate(args, optimizer, epoch, gamma=0.1):
# 每訓練args.step_size個epochs,學習率衰減到gamma倍
lr = args.lr * (gamma ** (epoch // args.step_size))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
三、一個optimizer內多組引數或多個optimizer
# 一個optimizer內多組引數,可以為不同組引數設定不同的學習率
optimizer = optim.SGD([
{'params': net.conv_block1.parameters(), 'lr': 0.002},
{'params': net.classifier1.parameters(), 'lr': 0.002},
{'params': net.conv_block2.parameters(), 'lr': 0.002},
{'params': net.classifier2.parameters(), 'lr': 0.002},
{'params': net.conv_block3.parameters(), 'lr': 0.002},
{'params': net.classifier3.parameters(), 'lr': 0.002},
{'params': net.features.parameters(), 'lr': 0.0002}
], momentum=0.9, weight_decay=5e-4)
# 定義多個optimizer,訓練網路的不同模組
raw_optimizer = torch.optim.SGD(raw_parameters, lr=LR, momentum=0.9, weight_decay=WD)
concat_optimizer = torch.optim.SGD(concat_parameters, lr=LR, momentum=0.9, weight_decay=WD)
part_optimizer = torch.optim.SGD(part_parameters, lr=LR, momentum=0.9, weight_decay=WD)
partcls_optimizer = torch.optim.SGD(partcls_parameters, lr=LR, momentum=0.9, weight_decay=WD)
待續。。。
相關文章
- 【web前端】小技巧總結Web前端
- 小程式開發技巧總結
- Chrome 開發者工具的小技巧總結Chrome
- Eclipse使用技巧總結Eclipse
- 我開發中總結的小技巧
- 成為JavaScript開發者的小技巧總結JavaScript
- Vue的使用總結和技巧Vue
- JavaScript 中 this 的使用技巧總結JavaScript
- Android程式碼優化小技巧總結Android優化
- 前端從業兩年總結的一些js使用小技巧前端JS
- Altium Designer使用技巧總結(一)
- 幾年的Git使用技巧總結Git
- XGBoost類庫使用小結
- iOS 小技巧總結,絕對有你想要的iOS
- Mybatis使用小技巧-自定義結果集MyBatis
- Git 小技巧彙總Git
- R小技巧彙總
- Objective-C開發使用技巧總結Object
- SVN使用技巧和參考文件總結
- 面試技巧總結面試
- gulp技巧總結
- css技巧總結CSS
- word技巧總結
- JavaScript 技巧總結JavaScript
- 有關連結串列的小技巧,我都給你總結好了
- URLConnection類,HttpURLConnection類的使用和總結HTTP
- Windows使用小技巧Windows
- photoshop使用小技巧
- 不定時更新-工具類小技巧
- Maven 常用技巧總結Maven
- 【個人總結】常用技巧
- CSS技巧總結2CSS
- [原始碼解析] PyTorch 分散式(14) --使用 Distributed Autograd 和 Distributed Optimizer原始碼PyTorch分散式
- 使用mpvue開發github小程式總結VueGithub
- 總結十個Python 字典用法的使用技巧Python
- 【Tips】【UE】總結自己常用的UltraEdit使用技巧
- FastReport報表控制元件使用技巧總結AST控制元件
- 填坑總結:python記憶體洩漏排查小技巧Python記憶體