雲端TensorFlow讀取資料IO的高效方式

李博Garvin發表於2017-07-17

低效的IO方式

最近通過觀察PAI平臺上TensoFlow使用者的執行情況,發現大家在資料IO這方面還是有比較大的困惑,主要是因為很多同學沒有很好的理解本地執行TensorFlow程式碼和分散式雲端執行TensorFlow的區別。本地讀取資料是server端直接從client端獲得graph進行計算,而云端服務server在獲得graph之後還需要將計算下發到各個worker處理(具體原理可以參考視訊教程-Tensorflow高階篇:https://tianchi.aliyun.com/competition/new_articleDetail.html)。

本文通過讀取一個簡單的CSV檔案為例,幫助大家快速瞭解如何使用TensorFlow高效的讀取資料。CSV檔案如下:

1,1,1,1,1
2,2,2,2,2
3,3,3,3,3  

首先我們來看下大家容易產生問題的幾個地方。

1.不建議用python本地讀取檔案的方式

PAI支援python的自帶IO方式,但是需要將資料來源和程式碼打包上傳的方式使用,這種讀取方式是將資料寫入記憶體之後再計算,效率比較低,不建議使用。範例程式碼如下:

import csv
csv_reader=csv.reader(open('csvtest.csv'))
for row in csv_reader:
    print(row)  

2.儘量不要用第三方庫的讀取檔案方法

很多同學使用第三方庫的一些資料IO的方式進行資料讀取,比如TFLearn、Panda的資料IO方式,這些方法很多都是通過封裝PYTHON的讀取方式實現的,所以在PAI平臺使用的時候也會造成效率低下問題。

3.儘量不要用preload的方式讀取檔案

很多人在用PAI的服務的時候表示GPU並沒有比本地的CPU速度快的明顯,主要問題可能就出在資料IO這塊。preload的方式是先把資料全部都讀到記憶體中,然後再通過session計算,比如feed的讀取方式。這樣要先進行資料讀取,再計算,不同步造成效能浪費,同時因為記憶體限制也無法支援大資料量的計算。舉個例子:假設我們的硬碟中有一個圖片資料集0001.jpg,0002.jpg,0003.jpg……我們只需要把它們讀取到記憶體中,然後提供給GPU或是CPU進行計算就可以了。這聽起來很容易,但事實遠沒有那麼簡單。事實上,我們必須要把資料先讀入後才能進行計算,假設讀入用時0.1s,計算用時0.9s,那麼就意味著每過1s,GPU都會有0.1s無事可做,這就大大降低了運算的效率。

下面我們看下高效的讀取方式。

高效的IO方式

高效的TensorFlow讀取方式是將資料讀取轉換成OP,通過session run的方式拉去資料。另外,讀取執行緒源源不斷地將檔案系統中的圖片讀入到一個記憶體的佇列中,而負責計算的是另一個執行緒,計算需要資料時,直接從記憶體佇列中取就可以了。這樣就可以解決GPU因為IO而空閒的問題!

下面我們看下程式碼,如何在PAI平臺通過OP的方式讀取資料:

import argparse
import tensorflow as tf
import os
FLAGS=None
def main(_):
    dirname = os.path.join(FLAGS.buckets, "csvtest.csv")
    reader=tf.TextLineReader()
    filename_queue=tf.train.string_input_producer([dirname])
    key,value=reader.read(filename_queue)
    record_defaults=[[''],[''],[''],[''],['']]
    d1, d2, d3, d4, d5= tf.decode_csv(value, record_defaults, ',')

    init=tf.initialize_all_variables()

    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)
        for i in range(4):
            print(sess.run(d2))
        coord.request_stop()
        coord.join(threads)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--buckets', type=str, default='',
                        help='input data path')
    parser.add_argument('--checkpointDir', type=str, default='',
                        help='output model path')
    FLAGS, _ = parser.parse_known_args()
    tf.app.run(main=main)
  • dirname:OSS檔案路徑,可以是陣列,方便下一階段shuffle
  • reader:TF內建各種reader API,可以根據需求選用
  • tf.train.string_input_producer:將檔案生成佇列
  • tf.decode_csv:是一個splite功能的OP,可以拿到每一行的特定引數
  • 通過OP獲取資料,在session中需要tf.train.Coordinator()和tf.train.start_queue_runners(sess=sess,coord=coord)

在程式碼中,我們的輸入是3行5個欄位:

1,1,1,1,1
2,2,2,2,2
3,3,3,3,3  

我們迴圈輸出4次,列印出第2個欄位。結果如圖:

輸出結果也證明了資料結構是成佇列。

其它

  • 我的微信公眾號(長期分享機器學習乾貨):凡人機器學習
    這裡寫圖片描述

  • PAI notebook功能上線,支援線上修改程式碼並且內建各種深度學習框架,歡迎使用:https://data.aliyun.com/product/learn

  • 強烈推薦視訊教程:https://tianchi.aliyun.com/competition/new_articleDetail.html
  • 本文參考了網際網路上《十圖詳解TensorFlow資料讀取機制(附程式碼)》一文,關於圖片的讀取方式也可以參考這篇文章,感謝原作者。

相關文章