簡單介紹TensorFlow中關於tf.app.flags命令列引數解析模組

大雄45發表於2023-04-13
導讀 這篇文章主要介紹了TensorFlow中關於tf.app.flags 行引數解析模組,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
tf.app.flags 行引數解析模組

說道命令列引數解析,就不得不提到 python 的 argparse 模組,詳情可參考我之前的一篇文章:python argparse 模組命令列引數用法及說明。

在閱讀相關工程的原始碼時,很容易發現 tf.app.flags 模組的身影。其作用與 python 的 argparse 類似。

直接上程式碼例項,新建一個名為 test_flags.py 的檔案,內容如下:

#coding:utf-8
import tensorflow as tf
 
FLAGS = tf.app.flags.FLAGS
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string("train_data_path", "/home/feige", "training data dir")
tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir")
tf.app.flags.DEFINE_integer("train_batch_size", 128, "batch size of train data")
tf.app.flags.DEFINE_integer("test_batch_size", 64, "batch size of test data")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")
 
def main(unused_argv):
    train_data_path = FLAGS.train_data_path
    print("train_data_path", train_data_path)
    train_batch_size = FLAGS.train_batch_size
    print("train_batch_size", train_batch_size)
    test_batch_size = FLAGS.test_batch_size
    print("test_batch_size", test_batch_size)
    size_sum = tf.add(train_batch_size, test_batch_size)
    with tf.Session() as sess:
        sum_result = sess.run(size_sum)
        print("sum_result", sum_result)
 
# 使用這種方式保證了,如果此檔案被其他檔案 import的時候,不會執行main 函式
if __name__ == '__main__':
    tf.app.run()   # 解析命令列引數,呼叫main 函式 main(sys.argv)

上述程式碼已給出較為詳細的註釋,在此不再贅述。

該檔案的呼叫示例以及執行結果如下所示

簡單介紹TensorFlow中關於tf.app.flags命令列引數解析模組簡單介紹TensorFlow中關於tf.app.flags命令列引數解析模組

如果需要修改預設引數的值,則在命令列傳入自定義引數值即可,若全部使用預設引數值,則可直接在命令列執行該 python 檔案。

讀者可能會對 tf.app.run() 有些疑問,在上述註釋中也有所解釋,但要真正弄清楚其執行原理

還需查閱其原始碼
def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS
 
  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None
 
  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access
 
  main = main or sys.modules['__main__'].main
 
  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

flags_passthrough=f._parse_flags(args=args)這裡的_parse_flags就是我們tf.app.flags原始碼中用來解析命令列引數的函式。

所以這一行就是解析引數的功能;

下面兩行程式碼也就是 tf.app.run 的核心意思:執行程式中 main 函式,並解析命令列引數!

原文來自:


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69955379/viewspace-2945101/,如需轉載,請註明出處,否則將追究法律責任。

相關文章