argparse是深度學習專案調參時常用的python標準庫,使用argparse後,我們在命令列輸入的引數就可以以這種形式python filename.py --lr 1e-4 --batch_size 32
來完成對常見超引數的設定。,一般使用時可以歸納為以下三個步驟
使用步驟:
- 建立
ArgumentParser()
物件 - 呼叫
add_argument()
方法新增引數 - 使用
parse_args()
解析引數 在接下來的內容中,我們將以實際操作來學習argparse的使用方法
import argparse
parser = argparse.ArgumentParser() # 建立一個解析物件
parser.add_argument() # 向該物件中新增你要關注的命令列引數和選項
args = parser.parse_args() # 呼叫parse_args()方法進行解析
常見規則
- 在命令列中輸入
python demo.py -h
或者python demo.py --help
可以檢視該python檔案引數說明 - arg字典類似python字典,比如arg字典
Namespace(integers='5')
可使用arg.引數名
來提取這個引數 parser.add_argument('integers', type=str, nargs='+',help='傳入的數字')
nargs是用來說明傳入的引數個數,'+' 表示傳入至少一個引數,'*' 表示引數可設定零個或多個,'?' 表示引數可設定零個或一個parser.add_argument('-n', '--name', type=str, required=True, default='', help='名')
required=True
表示必須引數, -n表示可以使用短選項使用該引數parser.add_argument("--test_action", default='False', action='store_true')
store_true 觸發時為真,不觸發則為假(test.py
,輸出為False
,test.py --test_action
,輸出為True
)
使用config檔案傳入超引數
為了使程式碼更加簡潔和模組化,可以將有關超引數的操作寫在config.py
,然後在train.py
或者其他檔案匯入就可以。具體的config.py
可以參考如下內容。
import argparse
def get_options(parser=argparse.ArgumentParser()):
parser.add_argument('--workers', type=int, default=0,
help='number of data loading workers, you had better put it '
'4 times of your gpu')
parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=64')
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')
parser.add_argument('--seed', type=int, default=118, help="random seed")
parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
parser.add_argument('--checkpoint_path',type=str,default='',
help='Path to load a previous trained model if not empty (default empty)')
parser.add_argument('--output',action='store_true',default=True,help="shows output")
opt = parser.parse_args()
if opt.output:
print(f'num_workers: {opt.workers}')
print(f'batch_size: {opt.batch_size}')
print(f'epochs (niters) : {opt.niter}')
print(f'learning rate : {opt.lr}')
print(f'manual_seed: {opt.seed}')
print(f'cuda enable: {opt.cuda}')
print(f'checkpoint_path: {opt.checkpoint_path}')
return opt
if __name__ == '__main__':
opt = get_options()
$ python config.py
num_workers: 0
batch_size: 4
epochs (niters) : 10
learning rate : 3e-05
manual_seed: 118
cuda enable: True
checkpoint_path:
隨後在train.py
等其他檔案,我們就可以使用下面的這樣的結構來呼叫引數。
# 匯入必要庫
...
import config
opt = config.get_options()
manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niters
checkpoint_path = opt.checkpoint_path
# 隨機數的設定,保證復現結果
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
...
if __name__ == '__main__':
set_seed(manual_seed)
for epoch in range(niters):
train(model,lr,batch_size,num_workers,checkpoint_path)
val(model,lr,batch_size,num_workers,checkpoint_path)
參考:
https://zhuanlan.zhihu.com/p/56922793
(14條訊息) python argparse中action的可選引數store_true的作用_元氣少女wuqh的部落格-CSDN部落格
[6.6 使用argparse進行調參 — 深入淺出PyTorch (datawhalechina.github.io)](https://datawhalechina.github.io/thorough-pytorch/第六章/6.6 使用argparse進行調參.html)