關於tensorflow 的資料讀取執行緒管理QueueRunner

__Sunny__發表於2017-05-02

轉自 http://blog.csdn.net/sunquan_ok/article/details/51832442

TensorFlow的Session物件是可以支援多執行緒的,因此多個執行緒可以很方便地使用同一個會話(Session)並且並行地執行操作。然而,在Python程式實現這樣的並行運算卻並不容易。所有執行緒都必須能被同步終止,異常必須能被正確捕獲並報告,回話終止的時候, 佇列必須能被正確地關閉。

所幸TensorFlow提供了兩個類來幫助多執行緒的實現:tf.Coordinator和 tf.QueueRunner。從設計上這兩個類必須被一起使用。Coordinator類可以用來同時停止多個工作執行緒並且向那個在等待所有工作執行緒終止的程式報告異常。QueueRunner類用來協調多個工作執行緒同時將多個張量推入同一個佇列中


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

以上為極客中國上tensorflow的官方文件翻譯中,對於執行緒和佇列的介紹。但是可能說的不太清楚。


Coordinator還比較好理解,可以理解為訊號量之類的東西。QueueRunner比較難理解,通篇看介紹文件,都沒有找到QueueRunner這個程式碼,後來終於發現一段文字:

建立執行緒並使用QueueRunner物件來預取

簡單來說:使用上面列出的許多tf.train函式新增QueueRunner到你的資料流圖中。在你執行任何訓練步驟之前,需要呼叫tf.train.start_queue_runners函式,否則資料流圖將一直掛起。tf.train.start_queue_runners 這個函式將會啟動輸入管道的執行緒,填充樣本到佇列中,以便出隊操作可以從佇列中拿到樣本。


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

也就是說,QueueRunner是一個不存在於程式碼中的東西,而是後臺運作的一個概念。由tf.train函式新增。


首先,我們先建立資料流圖,這個資料流圖由一些流水線的階段組成,階段間用佇列連線在一起。第一階段將生成檔名,我們讀取這些檔名並且把他們排到檔名佇列中。第二階段從檔案中讀取資料(使用Reader),產生樣本,而且把樣本放在一個樣本佇列中。根據你的設定,實際上也可以拷貝第二階段的樣本,使得他們相互獨立,這樣就可以從多個檔案中並行讀取。在第二階段的最後是一個排隊操作,就是入隊到佇列中去,在下一階段出隊。因為我們是要開始執行這些入隊操作的執行緒,所以我們的訓練迴圈會使得樣本佇列中的樣本不斷地出隊。




tf.train中要建立這些佇列和執行入隊操作,就要新增tf.train.QueueRunner到一個使用tf.train.add_queue_runner函式的資料流圖中。每個QueueRunner負責一個階段,處理那些需要線上程中執行的入隊操作的列表。一旦資料流圖構造成功,tf.train.start_queue_runners函式就會要求資料流圖中每個QueueRunner去開始它的執行緒執行入隊操作。


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

程式碼中根本沒有QueueRunner的啊


# Create the graph, etc.
init_op = tf.initialize_all_variables()

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (like the epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)

except tf.errors.OutOfRangeError:
    print 'Done training -- epoch limit reached'
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()


相關文章