【小白學PyTorch】17 TFrec檔案的建立與讀取

忽逢桃林發表於2020-10-03

【新聞】:機器學習煉丹術的粉絲的人工智慧交流群已經建立,目前有目標檢測、醫學影像、時間序列等多個目標為技術學習的分群和水群嘮嗑的總群,歡迎大家加煉丹兄為好友,加入煉丹協會。微信:cyx645016617.

參考目錄:

本文的程式碼已經上傳公眾號後臺,回覆【PyTorch】獲取。
第一次接觸到TFrec檔案,我也是比較矇蔽的其實:

可以看到檔案是.tfrec字尾的,而且先記住這個檔案是186.72MB大小的。

1 為什麼用tfrec檔案

正常情況下我們用於訓練的資料夾內部往往會存著成千上萬的圖片或文字等檔案,這些檔案通常被雜湊存放。這種儲存方式有一些缺點:

  • 佔用磁碟空間;
  • 一個一個讀取檔案消耗時間

而tfrec格式的檔案儲存形式會很合理的幫我們儲存資料,核心就是tfrec內部使用Protocol Buffer的二進位制資料編碼方案,這個方案可以極大的壓縮儲存空間

之前我們知道一個tfrec檔案100多M,這是因為這個tfrec檔案記憶體儲了很多的圖片,類似於壓縮,對tfrec解壓縮後可以獲取到一部分的資料集,當我們把全部的rfrec檔案都解壓縮後,可以獲取到全部的資料集。

值得一提的是,rfrec檔案內除了可以儲存圖片,還可以儲存其他的資料,比方說圖片的label。字串,float型別等都可以轉換成二進位制的方法,所以什麼資料型別基本上都可以儲存到rfrec檔案內,從而簡化讀取資料的過程。

2 tfrec檔案的內部結構

tfrec檔案時tensorflow的資料集儲存格式,tensorflow可以高效的讀取和處理這些資料集,因此我見過有的資料集因為是tfrec檔案,所以用TF讀取資料集,然後用pytorch訓練模型。

之前提到了tfrec檔案裡面是有多個樣本的,所以tfrec可以為是多個tf.train.Example檔案組成的序列(每一個example是一個樣本),然後每一個tf.train.Example又是由若干個tf.train.Features字典組成。這個Features可以理解為這個樣本的一些資訊,如果是圖片樣本,那麼肯定有一個Features是圖片畫素值資料,一個Features是圖片的標籤值;如果是預測任務,那麼這個Feature可能就是一些字串型別的特徵

3 製作tfrec檔案

import tensorflow as tf
import glob
# 先記錄一下要儲存的tfrec檔案的名字
tfrecord_file = './train.tfrec'
# 獲取指定目錄的所有以jpeg結尾的檔案list
images = glob.glob('./*.jpeg')
with tf.io.TFRecordWriter(tfrecord_file) as writer:
    for filename in images:
        image = open(filename, 'rb').read()  # 讀取資料集圖片到記憶體,image 為一個 Byte 型別的字串
        feature = {  # 建立 tf.train.Feature 字典
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),  # 圖片是一個 Bytes 物件
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
            'float':tf.train.Feature(float_list=tf.train.FloatList(value=[1.0,2.0])),
            'name':tf.train.Feature(bytes_list=tf.train.BytesList(value=[str.encode(filename)]))
        }
        # tf.train.Example 在 tf.train.Features 外面又多了一層封裝
        example = tf.train.Example(features=tf.train.Features(feature=feature))  # 通過字典建立 Example
        writer.write(example.SerializeToString())  # 將 Example 序列化並寫入 TFRecord 檔案

程式碼中我們需要注意的地方是:

  • 先讀取圖片,然後構建一個字典來作為這個example的格式;
  • 上面程式碼中,字典中有四個屬性,首先是image圖片本身的畫素值,然後有一個標籤,標籤是int型別,然後有一個float浮點型別,name是一個字串型別,這個string型別的需要轉換成byte位元組型別的才能進行儲存,所以這裡使用str.encode來把字串轉換成位元組;
  • 然後這個features再經過Example的封裝,再然後把這個example寫進這個tfrec檔案中。

這一段程式碼建議儲存下來,方便以後的直接參考和複製。構建tfrec檔案對於tensorflow處理圖片來說,應該是繞不過的一個步驟。

4 讀取tfrec檔案

現在,我們執行完上面的程式碼,應該生成了一個./train.tfrec檔案,下面我們再對這個檔案進行讀取。

import tensorflow as tf

dataset = tf.data.TFRecordDataset('./train.tfrec')

def decode(example):
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'float': tf.io.FixedLenFeature([1, 2], tf.float32),
        'name': tf.io.FixedLenFeature([], tf.string)
    }
    feature_dict = tf.io.parse_single_example(example, feature_description)
    feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])  # 解碼 JEPG 圖片
    return feature_dict

dataset = dataset.map(decode).batch(4)
for i in dataset.take(1):
    print(i['image'].shape)
    print(i['label'].shape)
    print(i['float'].shape)
    print(bytes.decode(i['name'][0].numpy()))
  • 首先使用專門用來讀取tfrec檔案的方法tf.data.TFRecordDataset,進行讀取,建立了一個dataset,但是這個dataset並不能直接使用,需要對tfrec中的example進行一些解碼;
  • 自己寫一個解碼函式decode,首先寫一個特徵描述,我們知道在儲存tfrec的時候每一個example有四個特徵,這裡需要對每一個特徵確定他的型別,是string還是int還是float這樣的。
  • 然後通過這個特徵描述和tf.io.parse_single_example方法,從example中提取到對應的特徵;
  • 因為image是一個圖片張量,而我們讀取的時候是讀取的tf.string的型別,所以使用tf.io.decode_jpeg()來把字串解碼成一個tensor張量。
  • 最後使用上節課講過的.batch(4)把資料集每一個batch包含四個樣本。

上面程式碼輸出的結果為:

需要注意的是這個如何把name轉換成string型別的,如果已經在本地跑完了上面的程式碼,可以自己看看i['name']是一個什麼型別的,然後自己試試如何轉換成string型別的。上面的程式碼是能成功轉換的。

下一次的內容就是如何構建模型,然後怎麼把資料集餵給模型。

相關文章