TensorFlow讀寫資料

Java3y發表於2019-03-16

前言

只有光頭才能變強。

文字已收錄至我的GitHub倉庫,歡迎Star:github.com/ZhongFuChen…

回顧前面:

眾所周知,要訓練出一個模型,首先我們得有資料。我們第一個例子中,直接使用dataset的api去載入mnist的資料。(minst的資料要麼我們是提前下載好,放在對應的目錄上,要麼就根據他給的url直接從網上下載)。

一般來說,我們使用TensorFlow是從TFRecord檔案中讀取資料的。

TFRecord 檔案格式是一種面向記錄的簡單二進位制格式,很多 TensorFlow 應用採用此格式來訓練資料

所以,這篇文章來聊聊怎麼讀取TFRecord檔案的資料。

一、入門對資料集的資料進行讀和寫

首先,我們來體驗一下怎麼造一個TFRecord檔案,怎麼從TFRecord檔案中讀取資料,遍歷(消費)這些資料。

1.1 造一個TFRecord檔案

現在,我們還沒有TFRecord檔案,我們可以自己簡單寫一個:

def write_sample_to_tfrecord():
    gmv_values = np.arange(10)
    click_values = np.arange(10)
    label_values = np.arange(10)

    with tf.python_io.TFRecordWriter("/Users/zhongfucheng/data/fashin/demo.tfrecord", options=None) as writer:
        for _ in range(10):
            feature_internal = {
                "gmv": tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])),
                "click": tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]]))
            }
            features_extern = tf.train.Features(feature=feature_internal)

            # 使用tf.train.Example將features編碼資料封裝成特定的PB協議格式
            # example = tf.train.Example(features=tf.train.Features(feature=features_extern))
            example = tf.train.Example(features=features_extern)

            # 將example資料系列化為字串
            example_str = example.SerializeToString()

            # 將系列化為字串的example資料寫入協議緩衝區
            writer.write(example_str)


if __name__ == '__main__':
    write_sample_to_tfrecord()
複製程式碼

我相信大家程式碼應該是能夠看得懂的,其實就是分了幾步:

  • 生成TFRecord Writer
  • tf.train.Feature生成協議資訊
  • 使用tf.train.Example將features編碼資料封裝成特定的PB協議格式
  • 將example資料系列化為字串
  • 將系列化為字串的example資料寫入協議緩衝區

參考資料:

ok,現在我們就有了一個TFRecord檔案啦。

1.2 讀取TFRecord檔案

  • 其實就是通過tf.data.TFRecordDataset這個api來讀取到TFRecord檔案,生成處dataset物件

  • 對dataset進行處理(shape處理,格式處理...等等)

  • 使用迭代器對dataset進行消費(遍歷)

demo程式碼如下:

import tensorflow as tf


def read_tensorflow_tfrecord_files():
    # 定義消費緩衝區協議的parser,作為dataset.map()方法中傳入的lambda:
    def _parse_function(single_sample):
        features = {
            "gmv": tf.FixedLenFeature([1], tf.float32),
            "click": tf.FixedLenFeature([1], tf.int64),  # ()或者[]沒啥影響
            "label": tf.FixedLenFeature([1], tf.int64)
        }
        parsed_features = tf.parse_single_example(single_sample, features=features)

        # 對parsed 之後的值進行cast.
        gmv = tf.cast(parsed_features["gmv"], tf.float64)
        click = tf.cast(parsed_features["click"], tf.float64)
        label = tf.cast(parsed_features["label"], tf.float64)

        return gmv, click, label

    # 開始定義dataset以及解析tfrecord格式
    filenames = tf.placeholder(tf.string, shape=[None])

    # 定義dataset 和 一些列trasformation method
    dataset = tf.data.TFRecordDataset(filenames)
    parsed_dataset = dataset.map(_parse_function)  # 消費緩衝區需要定義在dataset 的map 函式中
    batchd_dataset = parsed_dataset.batch(3)

    # 建立Iterator
    sample_iter = batchd_dataset.make_initializable_iterator()
    # 獲取next_sample
    gmv, click, label = sample_iter.get_next()
    training_filenames = [
        "/Users/zhongfucheng/data/fashin/demo.tfrecord"]
    with tf.Session() as session:
        # 初始化帶引數的Iterator
        session.run(sample_iter.initializer, feed_dict={filenames: training_filenames})
        # 讀取檔案
        print(session.run(gmv))


if __name__ == '__main__':
    read_tensorflow_tfrecord_files()

複製程式碼

無意外的話,我們可以輸出這樣的結果:

[[0.]
 [1.]
 [2.]]
複製程式碼

ok,現在我們已經大概知道怎麼寫一個TFRecord檔案,以及怎麼讀取TFRecord檔案的資料,並且消費這些資料了。

二、epoch和batchSize術語解釋

我在學習TensorFlow翻閱資料時,經常看到一些機器學習的術語,由於自己沒啥機器學習的基礎,所以很多時候看到一些專業名詞就開始懵逼了。

2.1epoch

當一個完整的資料集通過了神經網路一次並且返回了一次,這個過程稱為一個epoch

這可能使我們跟dataset.repeat()方法聯絡起來,這個方法可以使當前資料集重複一遍。比如說,原有的資料集是[1,2,3,4,5],如果我呼叫dataset.repeat(2)的話,那麼我們的資料集就變成了[1,2,3,4,5],[1,2,3,4,5]

  • 所以會有個說法:假設原先的資料是一個epoch,使用repeat(5)就可以將之變成5個epoch

2.2batchSize

一般來說我們的資料集都是比較大的,無法一次性將整個資料集的資料喂進神經網路中,所以我們會將資料集分成好幾個部分。每次喂多少條樣本進神經網路,這個叫做batchSize。

在TensorFlow也提供了方法給我們設定:dataset.batch(),在API中是這樣介紹batchSize的:

representing the number of consecutive elements of this dataset to combine in a single batch
複製程式碼

我們一般在每次訓練之前,會將整個資料集的順序打亂,提高我們模型訓練的效果。這裡我們用到的api是:dataset.shffle();

三、再來聊聊dataset

我從官網的介紹中截了一個dataset的方法圖(部分):

dataset的方法圖

dataset的功能主要有以下三種:

  • 建立dataset例項
    • 通過檔案建立(比如TFRecord)
    • 通過記憶體建立
  • 對資料集的資料進行變換
    • 比如上面的batch(),常見的map(),flat_map(),zip(),repeat()等等
    • 文件中一般都有給出例子,跑一下一般就知道對應的意思了。
  • 建立迭代器,遍歷資料集的資料

3.1 聊聊迭代器

迭代器可以分為四種:

  • 單次。對資料集進行一次迭代,不支援引數化
  • 可初始化迭代
    • 使用前需要進行初始化,支援傳入引數。面向的是同一個DataSet
  • 可重新初始化:同一個Iterator從不同的DataSet中讀取資料
    • DataSet的物件具有相同的結構,可以使用tf.data.Iterator.from_structure來進行初始化
    • 問題:每次 Iterator 切換時,資料都從頭開始列印了
  • 可饋送(也是通過物件相同的結果來建立的迭代器)
    • 可讓您在兩個資料集之間切換的可饋送迭代器
    • 通過一個string handler來實現。
    • 可饋送的 Iterator 在不同的 Iterator 切換的時候,可以做到不從頭開始

簡單總結:

  • 1、 單次 Iterator ,它最簡單,但無法重用,無法處理資料集引數化的要求。
  • 2、 可以初始化的 Iterator ,它可以滿足 Dataset 重複載入資料,滿足了引數化要求。
  • 3、可重新初始化的 Iterator,它可以對接不同的 Dataset,也就是可以從不同的 Dataset 中讀取資料。
  • 4、可饋送的 Iterator,它可以通過 feeding 的方式,讓程式在執行時候選擇正確的 Iterator,它和可重新初始化的 Iterator 不同的地方就是它的資料在不同的 Iterator 切換時,可以做到不重頭開始讀取資料

string handler(可饋送的 Iterator)這種方式是最常使用的,我當時也寫了一個Demo來使用了一下,程式碼如下:

def read_tensorflow_tfrecord_files():
    # 開始定義dataset以及解析tfrecord格式.
    train_filenames = tf.placeholder(tf.string, shape=[None])
    vali_filenames = tf.placeholder(tf.string, shape=[None])

    # 載入train_dataset   batch_inputs這個方法每個人都不一樣的,這個方法我就不給了。
    train_dataset = batch_inputs([
        train_filenames], batch_size=5, type=False,
        num_epochs=2, num_preprocess_threads=3)
    # 載入validation_dataset  batch_inputs這個方法每個人都不一樣的,這個方法我就不給了。
    validation_dataset = batch_inputs([vali_filenames
                                       ], batch_size=5, type=False,
                                      num_epochs=2, num_preprocess_threads=3)

    # 建立出string_handler()的迭代器(通過相同資料結構的dataset來構建)
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes)

    # 有了迭代器就可以呼叫next方法了。
    itemid = iterator.get_next()

    # 指定哪種具體的迭代器,有單次迭代的,有初始化的。
    training_iterator = train_dataset.make_initializable_iterator()
    validation_iterator = validation_dataset.make_initializable_iterator()

    # 定義出placeholder的值
    training_filenames = [
        "/Users/zhongfucheng/tfrecord_test/data01aa"]
    validation_filenames = ["/Users/zhongfucheng/tfrecord_validation/part-r-00766"]

    with tf.Session() as sess:
        # 初始化迭代器
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())

        for _ in range(2):
            sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames})
            print("this is training iterator ----")

            for _ in range(5):
                print(sess.run(itemid, feed_dict={handle: training_handle}))

            sess.run(validation_iterator.initializer,
                     feed_dict={vali_filenames: validation_filenames})

            print("this is validation iterator ")
            for _ in range(5):
                print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle}))


if __name__ == '__main__':
    read_tensorflow_tfrecord_files()

複製程式碼

參考資料:

3.2 dataset參考資料

在翻閱資料時,發現寫得不錯的一些部落格:

最後

樂於輸出乾貨的Java技術公眾號:Java3y。公眾號內有200多篇原創技術文章、海量視訊資源、精美腦圖,不妨來關注一下!

下一篇文章打算講講如何理解axis~

帥的人都關注了

覺得我的文章寫得不錯,不妨點一下

相關文章