TensorFlow高效讀取資料的方法——TFRecord的學習

戰爭熱誠發表於2019-07-20

關於TensorFlow讀取資料,官網給出了三種方法:

  • 供給資料(Feeding):在TensorFlow程式執行的每一步,讓python程式碼來供給資料。
  • 從檔案讀取資料:在TensorFlow圖的起始,讓一個輸入管線從檔案中讀取資料。
  • 預載入資料:在TensorFlow圖中定義常量或變數來儲存所有資料(僅適用於資料量比較小的情況)。

  對於資料量較小而言,可能一般選擇直接將資料載入進記憶體,然後再分batch輸入網路進行訓練(tip:使用這種方法時,結合yeild 使用更為簡潔)。但是如果資料量較大,這樣的方法就不適用了。因為太耗記憶體,所以這時最好使用TensorFlow提供的佇列queue,也就是第二種方法:從檔案讀取資料。對於一些特定的讀取,比如csv檔案格式,官網有相關的描述,在這裡我們學習一種比較通用的,高效的讀取方法,即使用TensorFlow內定標準格式——TFRecords。

1,什麼是TFRecords?

  一種儲存記錄的方法可以允許你講任意的資料轉換為TensorFlow所支援的格式,這種方法可以使TensorFlow的資料集更容易與網路應用架構相匹配。這種建議的方法就是使用TFRecords檔案。

  TFRecord是谷歌推薦的一種二進位制檔案格式,理論上它可以儲存任何格式的資訊。下面是Tensorflow的官網給出的文件結構,整個檔案由檔案長度資訊,長度校驗碼,資料,資料校驗碼組成。

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

  但是對於我們普通開發者而言,我們並不需要關心這些,TensorFlow提供了豐富的API可以幫助我們輕鬆地讀寫TFRecord檔案。

  TFRecord支援寫入三種格式的資料:string,int64,float32,以列表的形式分別通過tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 寫入 tf.train.Feature,如下所示:

#feature一般是多維陣列,要先轉為list
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()])) 

#tostring函式後feature的形狀資訊會丟失,把shape也寫入
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))  

tf.train.Feature(float_list=tf.train.FloatList(value=[label]))

  通過上述操作,我們以dict的形式把要寫入的資料彙總,並構建 tf.train.Features,然後構建 tf.train.Example。如下:

def get_tfrecords_example(feature, label):
	tfrecords_features = {}
	feat_shape = feature.shape
	tfrecords_features['feature'] = tf.train.Feature(bytes_list=
                                              tf.train.BytesList(value=[feature.tostring()]))
	tfrecords_features['shape'] = tf.train.Feature(int64_list=
                                              tf.train.Int64List(value=list(feat_shape)))
	tfrecords_features['label'] = tf.train.Feature(float_list=
                                              tf.train.FloatList(value=label))

	return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

  把建立的tf.train.Example序列化下,便可以通過 tf.python_io.TFRecordWriter 寫入 tfrecord檔案中,如下:

#建立tfrecord的writer,檔名為xxx
tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')  
#把資料寫入Example
exmp = get_tfrecords_example(feats[inx], labels[inx])  
#Example序列化
exmp_serial = exmp.SerializeToString()   
#寫入tfrecord檔案 
tfrecord_wrt.write(exmp_serial)   
#寫完後關閉tfrecord的writer
tfrecord_wrt.close()    

  TFRecord 的核心內容在於內部有一系列的Example,Example 是protocolbuf 協議(protocolbuf 是通用的協議格式,對主流的程式語言都適用。所以這些 List對應到Python語言當中是列表。而對於Java 或者 C/C++來說他們就是陣列)下的訊息體。

  一個Example訊息體包含了一系列的feature屬性。每一個feature是一個map,也就是 key-value 的鍵值對。key 取值是String型別。而value是Feature型別的訊息體。

  舉個例子,一個BytesList 可以儲存Byte 陣列,因此像字串,圖片,視訊等等都可以容納進去。所以TFRecord 可以儲存幾乎任何格式的資訊。

2,為什麼要用TFRecord?

  TFRerecord也不是非用不可,但確實是谷歌官網推薦的檔案格式。

  • 1,它特別適合於TensorFlow,或者說就是為TensorFlow量身打造的。
  • 2,因為TensorFlow開發者眾多,統一訓練的資料檔案格式是一件很有意義的事情,也有助於降低學習成本和遷移成本。

  TFRecords 其實是一種二進位制檔案,雖然它不如其他格式好理解,但是它能更好的利用記憶體,更方便賦值和移動,並且不需要單獨的標籤檔案,理論上,它能儲存所有的資訊。總而言之,這樣的檔案格式好處多多,所以讓我們利用起來。

3,為什麼要生成自己的圖片資料集TFrecords?

  使用TensorFlow進行網格訓練時,為了提高讀取資料的效率,一般建議將訓練資料轉化為TFrecords格式。

  使用tensorflow官網例子練習,我們會發現基本都是MNIST,CIFAR_10這種做好的資料集說事。所以對於我們這些初學者,完全不知道圖片該如何輸入。這時候學習自己製作資料集就非常有必要了。

4,如何將一張圖片和一個TFRecord 檔案相互轉化

  我們可以使用TFWriter輕鬆的完成這個任務。但是製作之前,我們要明確自己的目的。我們必須要想清楚,需要把什麼資訊儲存到TFRecord 檔案當中,這其實是最重要的。

  下面我們將一張圖片轉化為TFRecord,然後讀取一張TFRecord檔案,並展示為圖片。

4.1  將一張圖片轉化成TFRecord 檔案

  下面舉例說明嘗試把圖片轉化成TFRecord 檔案。  

  首先定義Example 訊息體。

Example Message {
    Features{
        feature{
            key:"name"
            value:{
                bytes_list:{
                    value:"cat"
                }
            }
        }
        feature{
            key:"shape"
            value:{
                int64_list:{
                    value:689
                    value:720
                    value:3
                }
            }
        }
        feature{
            key:"data"
            value:{
                bytes_list:{
                    value:0xbe
                    value:0xb2
                    ...
                    value:0x3
                }
            }
        }
    }
}

  上面的Example表示,要將一張 cat 圖片資訊寫進了 TFRecord 當中。而圖片資訊包含了圖片的名字,圖片的維度資訊還有圖片的資料,分別對應了 name,shape,content 3個feature。

  下面我們嘗試使用程式碼實現:

# _*_coding:utf-8_*_
import tensorflow as tf

def write_test(input, output):
    # 藉助於TFRecordWriter 才能將資訊寫入TFRecord 檔案
    writer = tf.python_io.TFRecordWriter(output)

    # 讀取圖片並進行解碼
    image = tf.read_file(input)
    image = tf.image.decode_jpeg(image)

    with tf.Session() as sess:
        image = sess.run(image)
        shape = image.shape
        # 將圖片轉換成string
        image_data = image.tostring()
        print(type(image))
        print(len(image_data))
        name = bytes('cat', encoding='utf-8')
        print(type(name))
        # 建立Example物件,並將Feature一一對應填充進去
        example = tf.train.Example(features=tf.train.Features(feature={
             'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
             'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
             'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
        }
        ))
        # 將example序列化成string 型別,然後寫入。
        writer.write(example.SerializeToString())
    writer.close()


if __name__ == '__main__':
    input_photo = 'cat.jpg'
    output_file = 'cat.tfrecord'
    write_test(input_photo, output_file)

  上述程式碼註釋比較詳細,所以我們就重點說一下下面三點:

  • 1,將圖片解碼,然後轉化成string資料,然後填充進去。
  • 2,Feature 的value 是列表,所以記得加上 []
  • 3,example需要呼叫 SerializetoString() 進行序列化後才行

4.2  TFRecord 檔案讀取為圖片

  我們將圖片的資訊寫入到一個tfrecord檔案當中。現在我們需要檢驗它是否正確。這就需要用到如何讀取TFRecord 檔案的知識點了。

  程式碼如下:

# _*_coding:utf-8_*_
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def _parse_record(example_photo):
    features = {
        'name': tf.FixedLenFeature((), tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature((), tf.string)
    }
    parsed_features = tf.parse_single_example(example_photo,features=features)
    return parsed_features

def read_test(input_file):
    # 用dataset讀取TFRecords檔案
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    iterator = dataset.make_one_shot_iterator()

    with tf.Session() as sess:
        features = sess.run(iterator.get_next())
        name = features['name']
        name = name.decode()
        img_data = features['data']
        shape = features['shape']
        print("==============")
        print(type(shape))
        print(len(img_data))

        # 從bytes陣列中載入圖片原始資料,並重新reshape,它的結果是 ndarray 陣列
        img_data = np.fromstring(img_data, dtype=np.uint8)
        image_data = np.reshape(img_data, shape)

        plt.figure()
        # 顯示圖片
        plt.imshow(image_data)
        plt.show()

        # 將資料重新編碼成jpg圖片並儲存
        img = tf.image.encode_jpeg(image_data)
        tf.gfile.GFile('cat_encode.jpg', 'wb').write(img.eval())

if __name__ == '__main__':
    read_test("cat.tfrecord")

  下面解釋一下程式碼:

1,首先使用dataset去讀取tfrecord檔案

2,在解析example 的時候,用現成的API:tf.parse_single_example

3,用 np.fromstring() 方法就可以獲取解析後的string資料,記得把資料還原成 np.uint8

4,用 tf.image.encode_jepg() 方法可以將圖片資料編碼成 jpeg 格式

5,用 tf.gfile.GFile 物件可以把圖片資料儲存到本地

6,因為將圖片 shape 寫入了example 中,所以解析的時候必須指定維度,在這裡 [3],不然程式會報錯。

  執行程式後,可以看到圖片顯示如下:

 

5,如何將一個資料夾下多張圖片和一個TFRecord 檔案相互轉化

  下面我們將一個資料夾的圖片轉化為TFRecord,然後再將TFRecord讀取為圖片。

5.1 將一個資料夾下多張圖片轉化為一個TFRecord檔案

   下面舉例說明嘗試把圖片轉化成TFRecord 檔案。

# _*_coding:utf-8_*_
# 將圖片儲存成TFRecords
import os
import tensorflow as tf
from PIL import Image
import random
import cv2
import numpy as np


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# 生成字串型的屬性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# 生成實數型的屬性
def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def read_image(filename, resize_height, resize_width, normalization=False):
    '''
        讀取圖片資料,預設返回的是uint8, [0, 255]
        :param filename:
        :param resize_height:
        :param resize_width:
        :param normalization:  是否歸一化到 [0.0, 1.0]
        :return:  返回的圖片資料
        '''
    bgr_image = cv2.imread(filename)
    # print(type(bgr_image))
    # 若是灰度圖則轉化為三通道
    if len(bgr_image.shape) == 2:
        print("Warning:gray image", filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
    # 將BGR轉化為RGB
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    # show_image(filename, rgb_image)
    # rgb_image=Image.open(filename)
    if resize_width > 0 and resize_height > 0:
        rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        rgb_image = rgb_image / 255.0
    return rgb_image


def load_labels_file(filename, labels_num=1, shuffle=False):
    '''
        載圖txt檔案,檔案中每行為一個圖片資訊,且以空格隔開,影象路徑 標籤1  標籤2
        如  test_image/1.jpg 0 2
        :param filename:
        :param labels_num:  labels個數
        :param shuffle: 是否打亂順序
        :return:  images type-> list
        :return:labels type->lis\t
        '''
    images = []
    labels = []
    with open(filename) as f:
        lines_list = f.readlines()
        # print(lines_list)  # ['plane\\0499.jpg 4\n', 'plane\\0500.jpg 4\n']
        if shuffle:
            random.shuffle(lines_list)
        for lines in lines_list:
            line = lines.rstrip().split(" ")  # rstrip 刪除 string 字串末尾的空格.  ['plane\\0006.jpg', '4']
            label = []
            for i in range(labels_num):  # labels_num 1      0 1所以i只能取1
                label.append(int(line[i + 1]))  # 確保讀取的是列表的第二個元素
            # print(label)
            images.append(line[0])
            # labels.append(line[1])  # ['0', '4']
            labels.append(label)
    # print(images)
    # print(labels)
    return images, labels


def create_records(image_dir, file, output_record_dir, resize_height, resize_width, shuffle, log=5):
    '''
    實現將影象原始資料,label,長,寬等資訊儲存為record檔案
    注意:讀取的影象資料預設是uint8,再轉為tf的字串型BytesList儲存,解析請需要根據需要轉換型別
    :param image_dir:原始影象的目錄
    :param file:輸入儲存圖片資訊的txt檔案(image_dir+file構成圖片的路徑)
    :param output_record_dir:儲存record檔案的路徑
    :param resize_height:
    :param resize_width:
    PS:當resize_height或者resize_width=0是,不執行resize
    :param shuffle:是否打亂順序
    :param log:log資訊列印間隔
    '''
    # 載入檔案,僅獲取一個label
    images_list, labels_list = load_labels_file(file, 1, shuffle)

    writer = tf.python_io.TFRecordWriter(output_record_dir)
    for i, [image_name, labels] in enumerate(zip(images_list, labels_list)):
        image_path = os.path.join(image_dir, images_list[i])
        if not os.path.exists(image_path):
            print("Error:no image", image_path)
            continue
        image = read_image(image_path, resize_height, resize_width)
        image_raw = image.tostring()
        if i % log == 0 or i == len(images_list) - 1:
            print("-----------processing:%d--th------------" % (i))
            print('current image_path=%s' % (image_path), 'shape:{}'.format(image.shape),
                  'labels:{}'.format(labels))
        # 這裡僅儲存一個label,多label適當增加"'label': _int64_feature(label)"項
        label = labels[0]
        example = tf.train.Example(features=tf.train.Features(feature={
            'image_raw': _bytes_feature(image_raw),
            'height': _int64_feature(image.shape[0]),
            'width': _int64_feature(image.shape[1]),
            'depth': _int64_feature(image.shape[2]),
            'label': _int64_feature(label)
        }))
        writer.write(example.SerializeToString())
    writer.close()

def get_example_nums(tf_records_filenames):
    '''
    統計tf_records影象的個數(example)個數
    :param tf_records_filenames: tf_records檔案路徑
    :return:
    '''
    nums = 0
    for record in tf.python_io.tf_record_iterator(tf_records_filenames):
        nums += 1
    return nums

if __name__ == '__main__':
    resize_height = 224  # 指定儲存圖片高度
    resize_width = 224  # 指定儲存圖片寬度
    shuffle = True
    log = 5

    image_dir = 'dataset/train'
    train_labels = 'dataset/train.txt'
    train_record_output = 'train.tfrecord'
    create_records(image_dir, train_labels, train_record_output, resize_height, resize_width, shuffle, log)
    train_nums = get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))

  

 5.2  將一個TFRecord檔案轉化為圖片顯示

  因為圖片太多,所以我們這裡只展示每個資料夾中第一張圖片即可。

  程式碼如下:

# _*_coding:utf-8_*_
# 將圖片儲存成TFRecords
import os
import tensorflow as tf
from PIL import Image
import random
import cv2
import numpy as np
import matplotlib.pyplot as plt

def read_records(filename,resize_height, resize_width,type=None):
    '''
    解析record檔案:原始檔的影象資料是RGB,uint8,[0,255],一般作為訓練資料時,需要歸一化到[0,1]
    :param filename:
    :param resize_height:
    :param resize_width:
    :param type:選擇影象資料的返回型別
         None:預設將uint8-[0,255]轉為float32-[0,255]
         normalization:歸一化float32-[0,1]
         centralization:歸一化float32-[0,1],再減均值中心化
    :return:
    '''
    # 建立檔案佇列,不限讀取的數量
    filename_queue = tf.train.string_input_producer([filename])
    # 為檔案佇列建立一個閱讀區
    reader = tf.TFRecordReader()
    # reader從檔案佇列中讀入一個序列化的樣本
    _, serialized_example = reader.read(filename_queue)

    # 解析符號化的樣本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    # 獲得影象原始的資料
    tf_image = tf.decode_raw(features["image_raw"], tf.uint8)

    tf_height = features['height']
    tf_width = features['width']
    tf_depth = features['depth']
    tf_label = tf.cast(features['label'], tf.int32)

    #PS 回覆原始影象 reshpe的大小必須與儲存之前的影象shape一致,否則報錯
    # 設定影象的維度
    tf_image = tf.reshape(tf_image, [resize_height, resize_width, 3])

    # 恢復資料後,才可以對影象進行resize_images:輸入 uint 輸出 float32
    # tf_image = tf.image.resize_images(tf_image, [224, 224])

    # 儲存的影象型別為 uint8 tensorflow訓練資料必須是tf.float32
    if type is None:
        tf_image = tf.cast(tf_image, tf.float32)
    # 【1】 若需要歸一化的話請使用
    elif type == 'normalization':
        # 僅當輸入資料是 uint8,才會歸一化 [0 , 255]
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0)
    elif type=='centralization':
        # 若需要歸一化,且中心化,假設均值為0.5 請使用
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0) - 0.5

    # 這裡僅僅返回影象和標籤
    return tf_image, tf_label


def show_image(title, image):
    '''
    顯示圖片
    :param title:  影象標題
    :param image:  影象的資料
    :return:
    '''
    plt.imshow(image)
    plt.axis('on')   # 關掉座標軸 為  off
    plt.title(title)  # 影象題目
    plt.show()


def disp_records(record_file,resize_height, resize_width,show_nums=4):
    '''
    解析record檔案,並顯示show_nums張圖片,主要用於驗證生成record檔案是否成功
    :param tfrecord_file: record檔案路徑
    :return:
    '''
    # 讀取record 函式
    tf_image, tf_label = read_records(record_file, resize_height, resize_width, type='normalization')
    # 顯示前4個圖片
    init_op = tf.global_variables_initializer()
    # init_op = tf.initialize_all_variables()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(show_nums):  # 在會話中取出image和label
            image, label = sess.run([tf_image, tf_label])
            # image = tf_image.eval()
            # 直接從record解析的image是一個向量,需要reshape顯示
            # image = image.reshape([height,width,depth])
            print('shape:{},tpye:{},labels:{}'.format(image.shape, image.dtype, label))
            # pilimg = Image.fromarray(np.asarray(image_eval_reshape))
            # pilimg.show()
            show_image("image:%d"%(label), image)
        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    resize_height = 224  # 指定儲存圖片高度
    resize_width = 224  # 指定儲存圖片寬度
    shuffle = True
    log = 5

    image_dir = 'dataset/train'
    train_labels = 'dataset/train.txt'
    train_record_output = 'train.tfrecord'


    # 測試顯示函式
    disp_records(train_record_output, resize_height, resize_width)

  部分程式碼解析:

5.3,加入佇列

with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
# 啟動佇列 threads = tf.train.start_queue_runners(sess=sess, coord=coord) for i in range(show_nums): # 在會話中取出image和label image, label = sess.run([tf_image, tf_label])

  注意,啟動佇列那條code不能忘記,不然會卡死,這樣加入後,就可以做到和tensorflow官網一樣的二進位制資料集了。

6,生成分割多個record檔案

  當圖片資料很多時候,會導致單個record檔案超級巨大的情況,解決方法就是,將資料分成多個record檔案儲存,讀取時,只需要將多個record檔案的路徑列表交給“tf.train.string_input_producer”,

完整程式碼如下:(此處來自 此部落格

# -*-coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import os
import cv2
import math
import matplotlib.pyplot as plt
import random
from PIL import Image
 
 
##########################################################################
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 生成字串型的屬性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 生成實數型的屬性
def float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
def show_image(title,image):
    '''
    顯示圖片
    :param title: 影象標題
    :param image: 影象的資料
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')    # 關掉座標軸為 off
    plt.title(title)  # 影象題目
    plt.show()
 
def load_labels_file(filename,labels_num=1):
    '''
    載圖txt檔案,檔案中每行為一個圖片資訊,且以空格隔開:影象路徑 標籤1 標籤2,如:test_image/1.jpg 0 2
    :param filename:
    :param labels_num :labels個數
    :return:images type->list
    :return:labels type->list
    '''
    images=[]
    labels=[]
    with open(filename) as f:
        for lines in f.readlines():
            line=lines.rstrip().split(' ')
            label=[]
            for i in range(labels_num):
                label.append(int(line[i+1]))
            images.append(line[0])
            labels.append(label)
    return images,labels
 
def read_image(filename, resize_height, resize_width):
    '''
    讀取圖片資料,預設返回的是uint8,[0,255]
    :param filename:
    :param resize_height:
    :param resize_width:
    :return: 返回的圖片資料是uint8,[0,255]
    '''
 
    bgr_image = cv2.imread(filename)
    if len(bgr_image.shape)==2:#若是灰度圖則轉為三通道
        print("Warning:gray image",filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
 
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#將BGR轉為RGB
    # show_image(filename,rgb_image)
    # rgb_image=Image.open(filename)
    if resize_height>0 and resize_width>0:
        rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))
    rgb_image=np.asanyarray(rgb_image)
    # show_image("src resize image",image)
 
    return rgb_image
 
 
def create_records(image_dir,file, record_txt_path, batchSize,resize_height, resize_width):
    '''
    實現將影象原始資料,label,長,寬等資訊儲存為record檔案
    注意:讀取的影象資料預設是uint8,再轉為tf的字串型BytesList儲存,解析請需要根據需要轉換型別
    :param image_dir:原始影象的目錄
    :param file:輸入儲存圖片資訊的txt檔案(image_dir+file構成圖片的路徑)
    :param output_record_txt_dir:儲存record檔案的路徑
    :param batchSize: 每batchSize個圖片儲存一個*.tfrecords,避免單個檔案過大
    :param resize_height:
    :param resize_width:
    PS:當resize_height或者resize_width=0是,不執行resize
    '''
    if os.path.exists(record_txt_path):
        os.remove(record_txt_path)
 
    setname, ext = record_txt_path.split('.')
 
    # 載入檔案,僅獲取一個label
    images_list, labels_list=load_labels_file(file,1)
    sample_num = len(images_list)
    # 打亂樣本的資料
    # random.shuffle(labels_list)
    batchNum = int(math.ceil(1.0 * sample_num / batchSize))
 
    for i in range(batchNum):
        start = i * batchSize
        end = min((i + 1) * batchSize, sample_num)
        batch_images = images_list[start:end]
        batch_labels = labels_list[start:end]
        # 逐個儲存*.tfrecords檔案
        filename = setname + '{0}.tfrecords'.format(i)
        print('save:%s' % (filename))
 
        writer = tf.python_io.TFRecordWriter(filename)
        for i, [image_name, labels] in enumerate(zip(batch_images, batch_labels)):
            image_path=os.path.join(image_dir,batch_images[i])
            if not os.path.exists(image_path):
                print('Err:no image',image_path)
                continue
            image = read_image(image_path, resize_height, resize_width)
            image_raw = image.tostring()
            print('image_path=%s,shape:( %d, %d, %d)' % (image_path,image.shape[0], image.shape[1], image.shape[2]),'labels:',labels)
            # 這裡僅儲存一個label,多label適當增加"'label': _int64_feature(label)"項
            label=labels[0]
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'height': _int64_feature(image.shape[0]),
                'width': _int64_feature(image.shape[1]),
                'depth': _int64_feature(image.shape[2]),
                'label': _int64_feature(label)
            }))
            writer.write(example.SerializeToString())
        writer.close()
 
        # 用txt儲存*.tfrecords檔案列表
        # record_list='{}.txt'.format(setname)
        with open(record_txt_path, 'a') as f:
            f.write(filename + '\n')
 
def read_records(filename,resize_height, resize_width):
    '''
    解析record檔案
    :param filename:儲存*.tfrecords檔案的txt檔案路徑
    :return:
    '''
    # 讀取txt中所有*.tfrecords檔案
    with open(filename, 'r') as f:
        lines = f.readlines()
        files_list=[]
        for line in lines:
            files_list.append(line.rstrip())
 
    # 建立檔案佇列,不限讀取的數量
    filename_queue = tf.train.string_input_producer(files_list,shuffle=False)
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader從檔案佇列中讀入一個序列化的樣本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符號化的樣本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#獲得影象原始的資料
 
    tf_height = features['height']
    tf_width = features['width']
    tf_depth = features['depth']
    tf_label = tf.cast(features['label'], tf.int32)
    # tf_image=tf.reshape(tf_image, [-1])    # 轉換為行向量
    tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 設定影象的維度
    # 儲存的影象型別為uint8,這裡需要將型別轉為tf.float32
    # tf_image = tf.cast(tf_image, tf.float32)
    # [1]若需要歸一化請使用:
    tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)# 歸一化
    # tf_image = tf.cast(tf_image, tf.float32) * (1. / 255)  # 歸一化
    # [2]若需要歸一化,且中心化,假設均值為0.5,請使用:
    # tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 #中心化
    return tf_image, tf_height,tf_width,tf_depth,tf_label
 
def disp_records(record_file,resize_height, resize_width,show_nums=4):
    '''
    解析record檔案,並顯示show_nums張圖片,主要用於驗證生成record檔案是否成功
    :param tfrecord_file: record檔案路徑
    :param resize_height:
    :param resize_width:
    :param show_nums: 預設顯示前四張照片
    :return:
    '''
    tf_image, tf_height, tf_width, tf_depth, tf_label = read_records(record_file,resize_height, resize_width)  # 讀取函式
    # 顯示前show_nums個圖片
    init_op = tf.initialize_all_variables()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(show_nums):
            image,height,width,depth,label = sess.run([tf_image,tf_height,tf_width,tf_depth,tf_label])  # 在會話中取出image和label
            # image = tf_image.eval()
            # 直接從record解析的image是一個向量,需要reshape顯示
            # image = image.reshape([height,width,depth])
            print('shape:',image.shape,'label:',label)
            # pilimg = Image.fromarray(np.asarray(image_eval_reshape))
            # pilimg.show()
            show_image("image:%d"%(label),image)
        coord.request_stop()
        coord.join(threads)
 
 
def batch_test(record_file,resize_height, resize_width):
    '''
    :param record_file: record檔案路徑
    :param resize_height:
    :param resize_width:
    :return:
    :PS:image_batch, label_batch一般作為網路的輸入
    '''
 
    tf_image,tf_height,tf_width,tf_depth,tf_label = read_records(record_file,resize_height, resize_width) # 讀取函式
 
    # 使用shuffle_batch可以隨機打亂輸入:
    # shuffle_batch用法:https://blog.csdn.net/u013555719/article/details/77679964
    min_after_dequeue = 100#該值越大,資料越亂,必須小於capacity
    batch_size = 4
    # capacity = (min_after_dequeue + (num_threads + a small safety margin∗batchsize)
    capacity = min_after_dequeue + 3 * batch_size#容量:一個整數,佇列中的最大的元素數
 
    image_batch, label_batch = tf.train.shuffle_batch([tf_image, tf_label],
                                                      batch_size=batch_size,
                                                      capacity=capacity,
                                                      min_after_dequeue=min_after_dequeue)
 
    init = tf.global_variables_initializer()
    with tf.Session() as sess:  # 開始一個會話
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(4):
            # 在會話中取出images和labels
            images, labels = sess.run([image_batch, label_batch])
            # 這裡僅顯示每個batch裡第一張圖片
            show_image("image", images[0, :, :, :])
            print(images.shape, labels)
        # 停止所有執行緒
        coord.request_stop()
        coord.join(threads)
 
 
if __name__ == '__main__':
    # 引數設定
    image_dir='dataset/train'
    train_file = 'dataset/train.txt'  # 圖片路徑
    output_record_txt = 'dataset/record/record.txt'#指定儲存record的檔案列表
    resize_height = 224  # 指定儲存圖片高度
    resize_width = 224  # 指定儲存圖片寬度
    batchSize=8000     #batchSize一般設定為8000,即每batchSize張照片儲存為一個record檔案
    # 產生record檔案
    create_records(image_dir=image_dir,
                   file=train_file,
                   record_txt_path=output_record_txt,
                   batchSize=batchSize,
                   resize_height=resize_height,
                   resize_width=resize_width)
 
    # 測試顯示函式
    disp_records(output_record_txt,resize_height, resize_width)
 
    # batch_test(output_record_txt,resize_height, resize_width)

  

7,直接讀取檔案的方式

  之前,我們都是將資料轉存為tfrecord檔案,訓練時候再去讀取,如果不想轉為record檔案,想直接讀取影象檔案進行訓練,可以使用下面的方法:

  filename.txt

0.jpg 0
1.jpg 0
2.jpg 0
3.jpg 0
4.jpg 0
5.jpg 1
6.jpg 1
7.jpg 1
8.jpg 1
9.jpg 1

  程式碼如下:

# -*-coding: utf-8 -*-

import tensorflow as tf
import glob
import numpy as np
import os
import matplotlib.pyplot as plt
 
import cv2
def show_image(title, image):
    '''
    顯示圖片
    :param title: 影象標題
    :param image: 影象的資料
    :return:
    '''
    # plt.imshow(image, cmap='gray')
    plt.imshow(image)
    plt.axis('on')  # 關掉座標軸為 off
    plt.title(title)  # 影象題目
    plt.show()
 
 
def tf_read_image(filename, resize_height, resize_width):
    '''
    讀取圖片
    :param filename:
    :param resize_height:
    :param resize_width:
    :return:
    '''
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    # tf_image = tf.cast(image_decoded, tf.float32)
    tf_image = tf.cast(image_decoded, tf.float32) * (1. / 255.0)  # 歸一化
    if resize_width>0 and resize_height>0:
        tf_image = tf.image.resize_images(tf_image, [resize_height, resize_width])
    # tf_image = tf.image.per_image_standardization(tf_image)  # 標準化[0,1](減均值除方差)
    return tf_image
 
 
def get_batch_images(image_list, label_list, batch_size, labels_nums, resize_height, resize_width, one_hot=False, shuffle=False):
    '''
    :param image_list:影象
    :param label_list:標籤
    :param batch_size:
    :param labels_nums:標籤個數
    :param one_hot:是否將labels轉為one_hot的形式
    :param shuffle:是否打亂順序,一般train時shuffle=True,驗證時shuffle=False
    :return:返回batch的images和labels
    '''
    # 生成佇列
    image_que, tf_label = tf.train.slice_input_producer([image_list, label_list], shuffle=shuffle)
    tf_image = tf_read_image(image_que, resize_height, resize_width)
    min_after_dequeue = 200
    capacity = min_after_dequeue + 3 * batch_size  # 保證capacity必須大於min_after_dequeue引數值
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch([tf_image, tf_label],
                                                            batch_size=batch_size,
                                                            capacity=capacity,
                                                            min_after_dequeue=min_after_dequeue)
    else:
        images_batch, labels_batch = tf.train.batch([tf_image, tf_label],
                                                    batch_size=batch_size,
                                                    capacity=capacity)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch, labels_batch
 
 
def load_image_labels(filename):
    '''
    載圖txt檔案,檔案中每行為一個圖片資訊,且以空格隔開:影象路徑 標籤1,如:test_image/1.jpg 0
    :param filename:
    :return:
    '''
    images_list = []
    labels_list = []
    with open(filename) as f:
        lines = f.readlines()
        for line in lines:
            # rstrip:用來去除結尾字元、空白符(包括\n、\r、\t、' ',即:換行、回車、製表符、空格)
            content = line.rstrip().split(' ')
            name = content[0]
            labels = []
            for value in content[1:]:
                labels.append(int(value))
            images_list.append(name)
            labels_list.append(labels)
    return images_list, labels_list
 
 
def batch_test(filename, image_dir):
    labels_nums = 2
    batch_size = 4
    resize_height = 200
    resize_width = 200
    image_list, label_list = load_image_labels(filename)
    image_list=[os.path.join(image_dir,image_name) for image_name in image_list]
 
    image_batch, labels_batch = get_batch_images(image_list=image_list,
                                                 label_list=label_list,
                                                 batch_size=batch_size,
                                                 labels_nums=labels_nums,
                                                 resize_height=resize_height, resize_width=resize_width,
                                                 one_hot=False, shuffle=True)
    with tf.Session() as sess:  # 開始一個會話
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(4):
            # 在會話中取出images和labels
            images, labels = sess.run([image_batch, labels_batch])
            # 這裡僅顯示每個batch裡第一張圖片
            show_image("image", images[0, :, :, :])
            print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))
 
        # 停止所有執行緒
        coord.request_stop()
        coord.join(threads)
 
 
if __name__ == "__main__":
    image_dir = "./dataset/train"
    filename = "./dataset/train.txt"
    batch_test(filename, image_dir)
 
 

  

8,資料輸入管道:pipeline機制解釋如下:

  TensorFlow引入了tf.data.Dataset模組,使其資料讀入的操作變得更為方便,而支援多執行緒(程式)的操作,也在效率上獲得了一定程度的提高。使用tf.data.Dataset模組的pipline機制,可實現CPU多執行緒處理輸入的資料,如讀取圖片和圖片的一些的預處理,這樣GPU可以專注於訓練過程,而CPU去準備資料。
  參考資料:

https://blog.csdn.net/u014061630/article/details/80776975

(五星推薦)TensorFlow全新的資料讀取方式:Dataset API入門教程:http://baijiahao.baidu.com/s?id=1583657817436843385&wfr=spider&for=pc

  從tfrecord檔案建立TFRecordDataset方式如下:

# 用dataset讀取TFRecords檔案
dataset = tf.contrib.data.TFRecordDataset(input_file)

  解析tfrecord 檔案的每條記錄,即序列化後的 tf.train.Example;使用 tf.parse_single_example 來解析:

feats = tf.parse_single_example(serial_exmp, features=data_dict)

  其中,data_dict 是一個dict,包含的key 是寫入tfrecord檔案時用的key ,相應的value是對應不同的資料型別,我們直接使用程式碼看,如下:

def _parse_record(example_photo):
    features = {
        'name': tf.FixedLenFeature((), tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature((), tf.string)
    }
    parsed_features = tf.parse_single_example(example_photo,features=features)
    return parsed_features

  解析tfrecord檔案中的所有記錄,我們需要使用dataset 的map 方法,如下:

dataset = dataset.map(_parse_record)

  Dataset支援一類特殊的操作:Transformation。一個Dataset通過Transformation變成一個新的Dataset。通常我們可以通過Transformation完成資料變換,打亂,組成batch,生成epoch等一系列操作。常用的Transformation有:map、batch、shuffle和repeat。

  map方法可以接受任意函式對dataset中的資料進行處理;另外可以使用repeat,shuffle,batch方法對dataset進行重複,混洗,分批;用repeat賦值dataset以進行多個epoch;如下:

dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)

  解析完資料後,便可以取出資料進行使用,通過建立iterator來進行,如下:

iterator = dataset.make_one_shot_iterator()

features = sess.run(iterator.get_next())

  下面分別介紹

8.1,map

    使用 tf.data.Dataset.map,我們可以很方便地對資料集中的各個元素進行預處理。因為輸入元素之間時獨立的,所以可以在多個 CPU 核心上並行地進行預處理。map 變換提供了一個 num_parallel_calls引數去指定並行的級別。

dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)

8.2,prefetch

  tf.data.Dataset.prefetch 提供了 software pipelining 機制。該函式解耦了 資料產生的時間 和 資料消耗的時間。具體來說,該函式有一個後臺執行緒和一個內部快取區,在資料被請求前,就從 dataset 中預載入一些資料(進一步提高效能)。prefech(n) 一般作為最後一個 transformation,其中 n 為 batch_size。 prefetch 的使用方法如下:

dataset = dataset.batch(batch_size=FLAGS.batch_size)
dataset = dataset.prefetch(buffer_size=FLAGS.prefetch_buffer_size) # last transformation
return dataset

8.3,repeat

  repeat的功能就是將整個序列重複多次,主要用來處理機器學習中的epoch,假設原先的資料是一個epoch,使用repeat(5)就可以將之變成5個epoch:

    如果直接呼叫repeat()的話,生成的序列就會無限重複下去,沒有結束,因此也不會丟擲tf.errors.OutOfRangeError異常

8.4,完整程式碼如下:

# -*-coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import glob
import matplotlib.pyplot as plt
 
width=0
height=0
def show_image(title, image):
    '''
    顯示圖片
    :param title: 影象標題
    :param image: 影象的資料
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')  # 關掉座標軸為 off
    plt.title(title)  # 影象題目
    plt.show()
 
 
def tf_read_image(filename, label):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.cast(image_decoded, tf.float32)
    if width>0 and height>0:
        image = tf.image.resize_images(image, [height, width])
    image = tf.cast(image, tf.float32) * (1. / 255.0)  # 歸一化
    return image, label
 
 
def input_fun(files_list, labels_list, batch_size, shuffle=True):
    '''
    :param files_list:
    :param labels_list:
    :param batch_size:
    :param shuffle:
    :return:
    '''
    # 構建資料集
    dataset = tf.data.Dataset.from_tensor_slices((files_list, labels_list))
    if shuffle:
        dataset = dataset.shuffle(100)
    dataset = dataset.repeat()  # 空為無限迴圈
    dataset = dataset.map(tf_read_image, num_parallel_calls=4)  # num_parallel_calls一般設定為cpu核心數量
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(2)  # software pipelining 機制
    return dataset
 
 
if __name__ == '__main__':
    data_dir = 'dataset/image/*.jpg'
    # labels_list = tf.constant([0,1,2,3,4])
    # labels_list = [1, 2, 3, 4, 5]
    files_list = glob.glob(data_dir)
    labels_list = np.arange(len(files_list))
    num_sample = len(files_list)
    batch_size = 1
    dataset = input_fun(files_list, labels_list, batch_size=batch_size, shuffle=False)
 
    # 需滿足:max_iterate*batch_size <=num_sample*num_epoch,否則越界
    max_iterate = 3
    with tf.Session() as sess:
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.make_initializer(dataset)
        sess.run(init_op)
        iterator = iterator.get_next()
        for i in range(max_iterate):
            images, labels = sess.run(iterator)
            show_image("image", images[0, :, :, :])
            print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))

  

9,AttributeError: module 'tensorflow' has no attribute 'data' 解決方法

  當我們使用tf 中的 dataset時,可能會出現如下錯誤:

  原因是tf 版本不同導致的錯誤。

  在編寫程式碼的時候,使用的tf版本不同,可能導致其Dataset API 放置的位置不同。當使用TensorFlow1.3的時候,Dataset API是放在 contrib 包裡面,而當使用TensorFlow1.4以後的版本,Dataset API已經從contrib 包中移除了,而變成了核心API的一員。故會產生報錯。

  解決方法:

  將下面程式碼:

# 用dataset讀取TFRecords檔案
dataset = tf.data.TFRecordDataset(input_file)

   改為此程式碼:

# 用dataset讀取TFRecords檔案
dataset = tf.contrib.data.TFRecordDataset(input_file)

  問題解決。

10,tf.gfile.FastGfile()函式學習

  函式如下:

tf.gfile.FastGFile(path,decodestyle) 

  函式功能:實現對圖片的讀取

  函式引數:path:圖片所在路徑

       decodestyle:圖片的解碼方式(‘r’:UTF-8編碼; ‘rb’:非UTF-8編碼)

例子如下:

img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()

  

11,Python zip()函式學習

  zip() 函式用於將可迭代的物件作為引數,將物件中對應的元素打包成一個個元組,然後返回由這些元組組成的列表。如果各個迭代器的元素個數不一致,則返回列表長度與最短的物件相同,利用*號操作符,可以將元組解壓為列表。

  在 Python 3.x 中為了減少記憶體,zip() 返回的是一個物件。如需展示列表,需手動 list() 轉換。

zip([iterable, ...])

引數說明: iterabl——一個或多個迭代器

返回值:返回元組列表

  例項:

>>>a = [1,2,3]
>>> b = [4,5,6]
>>> c = [4,5,6,7,8]

>>> zipped = zip(a,b)     # 打包為元組的列表
[(1, 4), (2, 5), (3, 6)]

>>> zip(a,c)              # 元素個數與最短的列表一致
[(1, 4), (2, 5), (3, 6)]

>>> zip(*zipped)          # 與 zip 相反,*zipped 可理解為解壓,返回二維矩陣式
[(1, 2, 3), (4, 5, 6)]

  

12,下一步計劃

1,為什麼前面使用Dataset,而用大多數博文中的 QueueRunner 呢?

  A:這是因為 Dataset 比 QueueRunner 新,而且是官方推薦的,Dataset 比較簡單。

2,學習了 TFRecord 相關知識,下一步學習什麼?

  A:可以嘗試將常見的資料集如 MNIST 和 CIFAR-10 轉換成 TFRecord 格式。

 

 

 參考文獻:https://blog.csdn.net/u012759136/article/details/52232266

https://blog.csdn.net/tengxing007/article/details/56847828/

https://blog.csdn.net/briblue/article/details/80789608 (五星推薦)

https://blog.csdn.net/happyhorizion/article/details/77894055  (五星推薦)

相關文章