TFrecord是一個Google提供的用於深度學習的資料格式,個人覺得很方便規範,值得學習。本文主要講的是怎麼儲存array,別的資料儲存較為簡單,舉一反三就行。
在TFrecord中的資料都需要進行一個轉化的過程,這個轉化分成三種
- int64
- float
- bytes
一般來講我們的圖片讀進來以後是兩種形式,
- tf.image.decode_jpeg 解碼圖片讀取成 (width,height,channels)的矩陣,這個讀取的方式和cv2.imread以及ndimage.imread一樣
- tf.image.convert_image_dtype會將讀進來的上面的矩陣歸一化,一般來講我們都要進行這個歸一化的過程。歸一化的好處可以去查。
但是儲存在TFrecord裡面的不能是array的形式,所以我們需要利用tostring()將上面的矩陣轉化成字串再通過tf.train.BytesList轉化成可以儲存的形式。
下面給個例項程式碼,大家看看就懂了
adjust_pic.py : 作用就是轉化Image大小
# -*- coding: utf-8 -*-
import tensorflow as tf
def resize(img_data, width, high, method=0):
return tf.image.resize_images(img_data,[width, high], method)
pic2tfrecords.py :將圖片存成TFrecord
# -*- coding: utf-8 -*-
# 將圖片儲存成 TFRecord
import os.path
import matplotlib.image as mpimg
import tensorflow as tf
import adjust_pic as ap
from PIL import Image
SAVE_PATH = `data/dataset.tfrecords`
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def load_data(datafile, width, high, method=0, save=False):
train_list = open(datafile,`r`)
# 準備一個 writer 用來寫 TFRecord 檔案
writer = tf.python_io.TFRecordWriter(SAVE_PATH)
with tf.Session() as sess:
for line in train_list:
# 獲得圖片的路徑和型別
tmp = line.strip().split(` `)
img_path = tmp[0]
label = int(tmp[1])
# 讀取圖片
image = tf.gfile.FastGFile(img_path, `r`).read()
# 解碼圖片(如果是 png 格式就使用 decode_png)
image = tf.image.decode_jpeg(image)
# 轉換資料型別
# 因為為了將圖片資料能夠儲存到 TFRecord 結構體中,所以需要將其圖片矩陣轉換成 string,所以為了在使用時能夠轉換回來,這裡確定下資料格式為 tf.float32
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# 既然都將圖片儲存成 TFRecord 了,那就先把圖片轉換成希望的大小吧
image = ap.resize(image, width, high)
# 執行 op: image
image = sess.run(image)
# 將其圖片矩陣轉換成 string
image_raw = image.tostring()
# 將資料整理成 TFRecord 需要的資料結構
example = tf.train.Example(features=tf.train.Features(feature={
`image_raw`: _bytes_feature(image_raw),
`label`: _int64_feature(label),
}))
# 寫 TFRecord
writer.write(example.SerializeToString())
writer.close()
load_data(`train_list.txt_bak`, 224, 224)
tfrecords2data.py :讀取Tfrecord裡的內容
# -*- coding: utf-8 -*-
# 從 TFRecord 中讀取並儲存圖片
import tensorflow as tf
import numpy as np
SAVE_PATH = `data/dataset.tfrecords`
def load_data(width, high):
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer([SAVE_PATH])
# 從 TFRecord 讀取內容並儲存到 serialized_example 中
_, serialized_example = reader.read(filename_queue)
# 讀取 serialized_example 的格式
features = tf.parse_single_example(
serialized_example,
features={
`image_raw`: tf.FixedLenFeature([], tf.string),
`label`: tf.FixedLenFeature([], tf.int64),
})
# 解析從 serialized_example 讀取到的內容
images = tf.decode_raw(features[`image_raw`], tf.uint8)
labels = tf.cast(features[`label`], tf.int64)
with tf.Session() as sess:
# 啟動多執行緒
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 因為我這裡只有 2 張圖片,所以下面迴圈 2 次
for i in range(2):
# 獲取一張圖片和其對應的型別
label, image = sess.run([labels, images])
# 這裡特別說明下:
# 因為要想把圖片儲存成 TFRecord,那就必須先將圖片矩陣轉換成 string,即:
# pic2tfrecords.py 中 image_raw = image.tostring() 這行
# 所以這裡需要執行下面這行將 string 轉換回來,否則會無法 reshape 成圖片矩陣,請看下面的小例子:
# a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩陣
# b = a.tostring()
# # 下面這行的輸出是 32,即: 2*2 之後還要再乘 8
# # 如果 tostring 之後的長度是 2*2=4 的話,那可以將 b 直接 reshape([2, 2]),但現在的長度是 2*2*8 = 32,所以無法直接 reshape
# # 同理如果你的圖片是 500*500*3 的話,那 tostring() 之後的長度是 500*500*3 後再乘上一個數
# print len(b)
#
# 但在網上有很多提供的程式碼裡都沒有下面這一行,你們那真的能 reshape ?
image = np.fromstring(image, dtype=np.float32)
# reshape 成圖片矩陣
image = tf.reshape(image, [224, 224, 3])
# 因為要儲存圖片,所以將其轉換成 uint8
image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
# 按照 jpeg 格式編碼
image = tf.image.encode_jpeg(image)
# 儲存圖片
with tf.gfile.GFile(`pic_%d.jpg` % label, `wb`) as f:
f.write(sess.run(image))
load_data(224, 224)
以上程式碼摘自TFRecord 的使用,覺得挺好的,沒改原樣照搬,我自己做實驗時改了很多,因為我是在im2txt的基礎上寫的。