深入淺出介紹TensorFlow資料集和估算器

谷歌開發者_發表於2017-09-26

TensorFlow 1.3 引入了兩個重要功能,您應當嘗試一下:

  • 資料集:一種建立輸入管道(即,將資料讀入您的程式)的全新方式。

  • 估算器:一種建立 TensorFlow 模型的高階方式。估算器包括適用於常見機器學習任務的預製模型,不過,您也可以使用它們建立自己的自定義模型。


下面是它們在 TensorFlow 架構內的裝配方式。結合使用這些估算器,可以輕鬆地建立 TensorFlow 模型和向模型提供資料:


640?wx_fmt=png&wxfrom=5&wx_lazy=1



我們的示例模型

為了探索這些功能,我們將構建一個模型並向您顯示相關的程式碼段。完整程式碼在這裡,其中包括獲取訓練和測試檔案的說明。請注意,編寫的程式碼旨在演示資料集和估算器的工作方式,並沒有為了實現最大效能而進行優化。

經過訓練的模型可以根據四個植物學特徵(萼片長度、萼片寬度、花瓣長度和花瓣寬度)對鳶尾花進行分類。因此,在推理期間,您可以為這四個特徵提供值,模型將預測花朵屬於以下三個美麗變種之中的哪一個:

0?wx_fmt=jpeg

▲ 從左到右依次為:山鳶尾(Radomil 攝影,CC BY-SA 3.0)、變色鳶尾(Dlanglois 攝影,CC BY-SA 3.0)和維吉尼亞鳶尾(Frank Mayfield 攝影,CC BY-SA 2.0)。


我們將使用下面的結構訓練深度神經網路分類器。所有輸入和輸出值都是 float32,輸出值的總和將等於 1(因為我們在預測屬於每種鳶尾花的可能性):

0?wx_fmt=jpeg


例如,輸出結果對山鳶尾來說可能是 0.05,對變色鳶尾是 0.9,對維吉尼亞鳶尾是 0.05,表示這種花有 90% 的可能性是變色鳶尾。

好了!我們現在已經定義模型,接下來看一看如何使用資料集和估算器訓練模型和進行預測。



資料集介紹

資料集是一種為 TensorFlow 模型建立輸入管道的新方式。使用此 API 的效能要比使用 feed_dict 或佇列式管道的效能高得多,而且此 API 更簡潔,使用起來更容易。儘管資料集在 1.3 版本中仍位於 tf.contrib.data 中,但是我們預計會在 1.4 版本中將此 API 移動到核心中,所以,是時候嘗試一下了。

從高層次而言,資料集由以下類組成:

0?wx_fmt=jpeg


其中:

  • 資料集:基類,包含用於建立和轉換資料集的函式。允許您從記憶體中的資料或從 Python 生成器初始化資料集。

  • TextLineDataset:從文字檔案中讀取各行內容。

  • TFRecordDataset:從 TFRecord 檔案中讀取記錄。

  • FixedLengthRecordDataset:從二進位制檔案中讀取固定大小的記錄。

  • 迭代器:提供了一種一次獲取一個資料集元素的方法。



我們的資料集

首先,我們來看一下要用來為模型提供資料的資料集。我們將從一個 CSV 檔案讀取資料,這個檔案的每一行都包含五個值 - 四個輸入值,加上標籤:

0?wx_fmt=jpeg


標籤的值如下所述:

  • 山鳶尾為 0

  • 變色鳶尾為 1

  • 維吉尼亞鳶尾為 2。



表示我們的資料集

為了說明我們的資料集,我們先來建立一個特徵列表:

feature_names = [
    'SepalLength',
    'SepalWidth',
    'PetalLength',
    'PetalWidth']


在訓練模型時,我們需要一個可以讀取輸入檔案並返回特徵和標籤資料的函式。估算器要求您建立一個具有以下格式的函式:

def input_fn():
    ...<code>...
    return ({ 'SepalLength':[values], ..<etc>.., 'PetalWidth':[values] },
            [IrisFlowerType])


返回值必須是一個按照如下方式組織的兩元素元組:

  • 第一個元素必須是一個字典(其中的每個輸入特徵都是一個鍵),然後是一個用於訓練批次的值列表。

  • 第二個元素是一個用於訓練批次的標籤列表。


由於我們要返回一批輸入特徵和訓練標籤,返回語句中的所有列表都將具有相同的長度。從技術角度而言,我們在這裡說的“列表”實際上是指 1-d TensorFlow 張量。

為了方便重複使用 input_fn,我們將向其中新增一些引數。這樣,我們就可以使用不同設定構建輸入函式。引數非常直觀:

  • file_path:要讀取的資料檔案。

  • perform_shuffle:是否應將記錄順序隨機化。

  • repeat_count:在資料集中迭代記錄的次數。例如,如果我們指定 1,那麼每個記錄都將讀取一次。如果我們不指定,迭代將永遠持續下去。


下面是我們使用 Dataset API 實現此函式的方式。我們會將它包裝到一個“輸入函式”中,這個輸入函式稍後將用於為我們的估算器模型提供資料:

def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
   def decode_csv(line):
       parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
       label = parsed_line[-1:] # Last element is the label
       del parsed_line[-1] # Delete last element
       features = parsed_line # Everything (but last element) are the features
       d = dict(zip(feature_names, features)), label
       return d

   dataset = (tf.contrib.data.TextLineDataset(file_path) # Read text file
       .skip(1) # Skip header row
       .map(decode_csv)) # Transform each elem by applying decode_csv fn
   if perform_shuffle:
       # Randomizes input using a window of 256 elements (read into memory)
       dataset = dataset.shuffle(buffer_size=256)
   dataset = dataset.repeat(repeat_count) # Repeats dataset this # times
   dataset = dataset.batch(32)  # Batch size to use
   iterator = dataset.make_one_shot_iterator()
   batch_features, batch_labels = iterator.get_next()
   return batch_features, batch_labels


注意以下內容:

  • TextLineDataset:在您使用 Dataset API 的檔案式資料集時,它將為您執行大量的記憶體管理工作。例如,您可以讀入比記憶體大得多的資料集檔案,或者以引數形式指定列表,讀入多個檔案。

  • shuffle:讀取 buffer_size 記錄,然後打亂(隨機化)它們的順序。

  • map:呼叫 decode_csv 函式,並將資料集中的每個元素作為一個引數(由於我們使用的是 TextLineDataset,每個元素都將是一行 CSV 文字)。然後,我們將向每一行應用 decode_csv 。

  • decode_csv:將每一行拆分成各個欄位,根據需要提供預設值。然後,返回一個包含欄位鍵和欄位值的字典。map 函式將使用字典更新資料集中的每個元素(行)。


以上是資料集的簡單介紹!為了娛樂一下,我們現在可以使用下面的函式列印第一個批次:

next_batch = my_input_fn(FILE, True) # Will return 32 random elements

# Now let's try it out, retrieving and printing one batch of data.
# Although this code looks strange, you don't need to understand
# the details.
with tf.Session() as sess:
    first_batch = sess.run(next_batch)
print(first_batch)

# Output
({'SepalLength': array([ 5.4000001, ...<repeat to 32 elems>], dtype=float32),
  'PetalWidth': array([ 0.40000001, ...<repeat to 32 elems>], dtype=float32),
  ...
 },
 [array([[2], ...<repeat to 32 elems>], dtype=int32) # Labels
)


這就是我們需要 Dataset API 在實現模型時所做的全部工作。不過,資料集還有很多功能;請參閱我們在“閱讀原文”末尾列出的更多資源。



估算器介紹

估算器是一種高階 API,使用這種 API,您在訓練 TensorFlow 模型時就不再像之前那樣需要編寫大量的樣板檔案程式碼。估算器也非常靈活,如果您對模型有具體的要求,它允許您替換預設行為。

使用估算器,您可以通過兩種可能的方式構建模型:

  • 預製估算器 - 這些是預先定義的估算器,旨在生成特定型別的模型。在這篇博文中,我們將使用 DNNClassifier 預製估算器。

  • 估算器(基類)- 允許您使用 model_fn 函式完全掌控模型的建立方式。我們將在單獨的博文中介紹如何操作。


下面是估算器的類圖:

0?wx_fmt=jpeg


我們希望在未來版本中新增更多的預製估算器。

正如您所看到的,所有估算器都使用 input_fn,它為估算器提供輸入資料。在我們的示例中,我們將重用 my_input_fn,這個函式是我們專門為演示定義的。

下面的程式碼可以將預測鳶尾花型別的估算器例項化:

# Create the feature_columns, which specifies the input to our model.
# All our input features are numeric, so use numeric_column for each one.
feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]

# Create a deep neural network regression classifier.
# Use the DNNClassifier pre-made estimator
classifier = tf.estimator.DNNClassifier(
   feature_columns=feature_columns, # The input features to our model
   hidden_units=[10, 10], # Two layers, each with 10 neurons
   n_classes=3,
   model_dir=PATH) # Path to where checkpoints etc are stored


我們現在有了一個可以開始訓練的估算器。


檢視全文(“訓練模型”“擴充”“總結”“資源列表”等)及文中連結,請點選文末“閱讀原文”。


0?wx_fmt=gif

相關文章