tensorflow載入資料的三種方式

馬衛飛發表於2018-06-18

Tensorflow資料讀取有三種方式:

  • Preloaded data: 預載入資料
  • Feeding: Python產生資料,再把資料餵給後端。
  • Reading from file: 從檔案中直接讀取

這三種有讀取方式有什麼區別呢? 我們首先要知道TensorFlow(TF)是怎麼樣工作的。

TF的核心是用C++寫的,這樣的好處是執行快,缺點是呼叫不靈活。而Python恰好相反,所以結合兩種語言的優勢。涉及計算的核心運算元和執行框架是用C++寫的,並提供API給Python。Python呼叫這些API,設計訓練模型(Graph),再將設計好的Graph給後端去執行。簡而言之,Python的角色是Design,C++是Run。

一、預載入資料:

[python] view plain copy
  1. import tensorflow as tf  
  2. # 設計Graph  
  3. x1 = tf.constant([234])  
  4. x2 = tf.constant([401])  
  5. y = tf.add(x1, x2)  
  6. # 開啟一個session --> 計算y  
  7. with tf.Session() as sess:  
  8.     print sess.run(y)  

二、python產生資料,再將資料餵給後端

[python] view plain copy
  1. import tensorflow as tf  
  2. # 設計Graph  
  3. x1 = tf.placeholder(tf.int16)  
  4. x2 = tf.placeholder(tf.int16)  
  5. y = tf.add(x1, x2)  
  6. # 用Python產生資料  
  7. li1 = [234]  
  8. li2 = [401]  
  9. # 開啟一個session --> 喂資料 --> 計算y  
  10. with tf.Session() as sess:  
  11.     print sess.run(y, feed_dict={x1: li1, x2: li2})  
說明:在這裡x1, x2只是佔位符,沒有具體的值,那麼執行的時候去哪取值呢?這時候就要用到sess.run()中的feed_dict引數,將Python產生的資料餵給後端,並計算y。
這兩種方案的缺點:

1、預載入:將資料直接內嵌到Graph中,再把Graph傳入Session中執行。當資料量比較大時,Graph的傳輸會遇到效率問題

2、用佔位符替代資料,待執行的時候填充資料。

前兩種方法很方便,但是遇到大型資料的時候就會很吃力,即使是Feeding,中間環節的增加也是不小的開銷,比如資料型別轉換等等。最優的方案就是在Graph定義好檔案讀取的方法,讓TF自己去從檔案中讀取資料,並解碼成可使用的樣本集。

三、從檔案中讀取,簡單來說就是將資料讀取模組的圖搭好


1、準備資料,構造三個檔案,A.csv,B.csv,C.csv

[python] view plain copy
  1. $ echo -e "Alpha1,A1\nAlpha2,A2\nAlpha3,A3" > A.csv  
  2. $ echo -e "Bee1,B1\nBee2,B2\nBee3,B3" > B.csv  
  3. $ echo -e "Sea1,C1\nSea2,C2\nSea3,C3" > C.csv  

2、單個Reader,單個樣本

[python] view plain copy
  1. #-*- coding:utf-8 -*-  
  2. import tensorflow as tf  
  3. # 生成一個先入先出佇列和一個QueueRunner,生成檔名佇列  
  4. filenames = ['A.csv''B.csv''C.csv']  
  5. filename_queue = tf.train.string_input_producer(filenames, shuffle=False)  
  6. # 定義Reader  
  7. reader = tf.TextLineReader()  
  8. key, value = reader.read(filename_queue)  
  9. # 定義Decoder  
  10. example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])  
  11. #example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2)  
  12. # 執行Graph  
  13. with tf.Session() as sess:  
  14.     coord = tf.train.Coordinator()  #建立一個協調器,管理執行緒  
  15.     threads = tf.train.start_queue_runners(coord=coord)  #啟動QueueRunner, 此時檔名佇列已經進隊。  
  16.     for i in range(10):  
  17.         print example.eval(),label.eval()  
  18.     coord.request_stop()  
  19.     coord.join(threads)  
說明:這裡沒有使用tf.train.shuffle_batch,會導致生成的樣本和label之間對應不上,亂序了。生成結果如下:

Alpha1 A2
Alpha3 B1
Bee2 B3
Sea1 C2
Sea3 A1
Alpha2 A3
Bee1 B2
Bee3 C1
Sea2 C3
Alpha1 A2

解決方案:用tf.train.shuffle_batch,那麼生成的結果就能夠對應上。

[python] view plain copy
  1. #-*- coding:utf-8 -*-  
  2. import tensorflow as tf  
  3. # 生成一個先入先出佇列和一個QueueRunner,生成檔名佇列  
  4. filenames = ['A.csv''B.csv''C.csv']  
  5. filename_queue = tf.train.string_input_producer(filenames, shuffle=False)  
  6. # 定義Reader  
  7. reader = tf.TextLineReader()  
  8. key, value = reader.read(filename_queue)  
  9. # 定義Decoder  
  10. example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])  
  11. example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2)  
  12. # 執行Graph  
  13. with tf.Session() as sess:  
  14.     coord = tf.train.Coordinator()  #建立一個協調器,管理執行緒  
  15.     threads = tf.train.start_queue_runners(coord=coord)  #啟動QueueRunner, 此時檔名佇列已經進隊。  
  16.     for i in range(10):  
  17.         e_val,l_val = sess.run([example_batch, label_batch])  
  18.         print e_val,l_val  
  19.     coord.request_stop()  
  20.     coord.join(threads)  

3、單個Reader,多個樣本,主要也是通過tf.train.shuffle_batch來實現

[python] view plain copy
  1. #-*- coding:utf-8 -*-  
  2. import tensorflow as tf  
  3. filenames = ['A.csv''B.csv''C.csv']  
  4. filename_queue = tf.train.string_input_producer(filenames, shuffle=False)  
  5. reader = tf.TextLineReader()  
  6. key, value = reader.read(filename_queue)  
  7. example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])  
  8. # 使用tf.train.batch()會多加了一個樣本佇列和一個QueueRunner。  
  9. #Decoder解後資料會進入這個佇列,再批量出隊。  
  10. # 雖然這裡只有一個Reader,但可以設定多執行緒,相應增加執行緒數會提高讀取速度,但並不是執行緒越多越好。  
  11. example_batch, label_batch = tf.train.batch(  
  12.       [example, label], batch_size=5)  
  13. with tf.Session() as sess:  
  14.     coord = tf.train.Coordinator()  
  15.     threads = tf.train.start_queue_runners(coord=coord)  
  16.     for i in range(10):  
  17.         e_val,l_val = sess.run([example_batch,label_batch])  
  18.         print e_val,l_val  
  19.     coord.request_stop()  
  20.     coord.join(threads)  

說明:下面這種寫法,提取出來的batch_size個樣本,特徵和label之間也是不同步的

[python] view plain copy
  1. #-*- coding:utf-8 -*-  
  2. import tensorflow as tf  
  3. filenames = ['A.csv''B.csv''C.csv']  
  4. filename_queue = tf.train.string_input_producer(filenames, shuffle=False)  
  5. reader = tf.TextLineReader()  
  6. key, value = reader.read(filename_queue)  
  7. example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])  
  8. # 使用tf.train.batch()會多加了一個樣本佇列和一個QueueRunner。  
  9. #Decoder解後資料會進入這個佇列,再批量出隊。  
  10. # 雖然這裡只有一個Reader,但可以設定多執行緒,相應增加執行緒數會提高讀取速度,但並不是執行緒越多越好。  
  11. example_batch, label_batch = tf.train.batch(  
  12.       [example, label], batch_size=5)  
  13. with tf.Session() as sess:  
  14.     coord = tf.train.Coordinator()  
  15.     threads = tf.train.start_queue_runners(coord=coord)  
  16.     for i in range(10):  
  17.         print example_batch.eval(), label_batch.eval()  
  18.     coord.request_stop()  
  19.     coord.join(threads)  
說明:輸出結果如下:可以看出feature和label之間是不對應的

['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2'] ['B3' 'C1' 'C2' 'C3' 'A1']
['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3'] ['C1' 'C2' 'C3' 'A1' 'A2']
['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1'] ['C2' 'C3' 'A1' 'A2' 'A3']

4、多個reader,多個樣本

[python] view plain copy
  1. #-*- coding:utf-8 -*-  
  2. import tensorflow as tf  
  3. filenames = ['A.csv''B.csv''C.csv']  
  4. filename_queue = tf.train.string_input_producer(filenames, shuffle=False)  
  5. reader = tf.TextLineReader()  
  6. key, value = reader.read(filename_queue)  
  7. record_defaults = [['null'], ['null']]  
  8. #定義了多種解碼器,每個解碼器跟一個reader相連  
  9. example_list = [tf.decode_csv(value, record_defaults=record_defaults)  
  10.                   for _ in range(2)]  # Reader設定為2  
  11. # 使用tf.train.batch_join(),可以使用多個reader,並行讀取資料。每個Reader使用一個執行緒。  
  12. example_batch, label_batch = tf.train.batch_join(  
  13.       example_list, batch_size=5)  
  14. with tf.Session() as sess:  
  15.     coord = tf.train.Coordinator()  
  16.     threads = tf.train.start_queue_runners(coord=coord)  
  17.     for i in range(10):  
  18.         e_val,l_val = sess.run([example_batch,label_batch])  
  19.         print e_val,l_val  
  20.     coord.request_stop()  
  21.     coord.join(threads)  

tf.train.batchtf.train.shuffle_batch函式是單個Reader讀取,但是可以多執行緒。tf.train.batch_jointf.train.shuffle_batch_join可設定多Reader讀取,每個Reader使用一個執行緒。至於兩種方法的效率,單Reader時,2個執行緒就達到了速度的極限。多Reader時,2個Reader就達到了極限。所以並不是執行緒越多越快,甚至更多的執行緒反而會使效率下降。

5、迭代控制,設定epoch引數,指定我們的樣本在訓練的時候只能被用多少輪

[python] view plain copy
  1. #-*- coding:utf-8 -*-  
  2. import tensorflow as tf  
  3. filenames = ['A.csv''B.csv''C.csv']  
  4. #num_epoch: 設定迭代數  
  5. filename_queue = tf.train.string_input_producer(filenames, shuffle=False,num_epochs=3)  
  6. reader = tf.TextLineReader()  
  7. key, value = reader.read(filename_queue)  
  8. record_defaults = [['null'], ['null']]  
  9. #定義了多種解碼器,每個解碼器跟一個reader相連  
  10. example_list = [tf.decode_csv(value, record_defaults=record_defaults)  
  11.                   for _ in range(2)]  # Reader設定為2  
  12. # 使用tf.train.batch_join(),可以使用多個reader,並行讀取資料。每個Reader使用一個執行緒。  
  13. example_batch, label_batch = tf.train.batch_join(  
  14.       example_list, batch_size=1)  
  15. #初始化本地變數  
  16. init_local_op = tf.initialize_local_variables()  
  17. with tf.Session() as sess:  
  18.     sess.run(init_local_op)  
  19.     coord = tf.train.Coordinator()  
  20.     threads = tf.train.start_queue_runners(coord=coord)  
  21.     try:  
  22.         while not coord.should_stop():  
  23.             e_val,l_val = sess.run([example_batch,label_batch])  
  24.             print e_val,l_val  
  25.     except tf.errors.OutOfRangeError:  
  26.             print('Epochs Complete!')  
  27.     finally:  
  28.             coord.request_stop()  
  29.     coord.join(threads)  
  30.     coord.request_stop()  
  31.     coord.join(threads)  

在迭代控制中,記得新增tf.initialize_local_variables(),官網教程沒有說明,但是如果不初始化,執行就會報錯。

=========================================================================================對於傳統的機器學習而言,比方說分類問題,[x1 x2 x3]是feature。對於二分類問題,label經過one-hot編碼之後就會是[0,1]或者[1,0]。一般情況下,我們會考慮將資料組織在csv檔案中,一行代表一個sample。然後使用佇列的方式去讀取資料


說明:對於該資料,前三列代表的是feature,因為是分類問題,後兩列就是經過one-hot編碼之後得到的label

使用佇列讀取該csv檔案的程式碼如下:

[python] view plain copy
  1. #-*- coding:utf-8 -*-  
  2. import tensorflow as tf  
  3. # 生成一個先入先出佇列和一個QueueRunner,生成檔名佇列  
  4. filenames = ['A.csv']  
  5. filename_queue = tf.train.string_input_producer(filenames, shuffle=False)  
  6. # 定義Reader  
  7. reader = tf.TextLineReader()  
  8. key, value = reader.read(filename_queue)  
  9. # 定義Decoder  
  10. record_defaults = [[1], [1], [1], [1], [1]]  
  11. col1, col2, col3, col4, col5 = tf.decode_csv(value,record_defaults=record_defaults)  
  12. features = tf.pack([col1, col2, col3])  
  13. label = tf.pack([col4,col5])  
  14. example_batch, label_batch = tf.train.shuffle_batch([features,label], batch_size=2, capacity=200, min_after_dequeue=100, num_threads=2)  
  15. # 執行Graph  
  16. with tf.Session() as sess:  
  17.     coord = tf.train.Coordinator()  #建立一個協調器,管理執行緒  
  18.     threads = tf.train.start_queue_runners(coord=coord)  #啟動QueueRunner, 此時檔名佇列已經進隊。  
  19.     for i in range(10):  
  20.         e_val,l_val = sess.run([example_batch, label_batch])  
  21.         print e_val,l_val  
  22.     coord.request_stop()  
  23.     coord.join(threads)  

輸出結果如下:


說明:

record_defaults = [[1], [1], [1], [1], [1]]
代表解析的模板,每個樣本有5列,在資料中是預設用‘,’隔開的,然後解析的標準是[1],也即每一列的數值都解析為整型。[1.0]就是解析為浮點,['null']解析為string型別

相關文章