深入淺出介紹TensorFlow資料集和估算器
TensorFlow 1.3 引入了兩個重要功能,您應當嘗試一下:
資料集:一種建立輸入管道(即,將資料讀入您的程式)的全新方式。
估算器:一種建立 TensorFlow 模型的高階方式。估算器包括適用於常見機器學習任務的預製模型,不過,您也可以使用它們建立自己的自定義模型。
下面是它們在 TensorFlow 架構內的裝配方式。結合使用這些估算器,可以輕鬆地建立 TensorFlow 模型和向模型提供資料:
我們的示例模型
為了探索這些功能,我們將構建一個模型並向您顯示相關的程式碼段。完整程式碼在這裡,其中包括獲取訓練和測試檔案的說明。請注意,編寫的程式碼旨在演示資料集和估算器的工作方式,並沒有為了實現最大效能而進行優化。
經過訓練的模型可以根據四個植物學特徵(萼片長度、萼片寬度、花瓣長度和花瓣寬度)對鳶尾花進行分類。因此,在推理期間,您可以為這四個特徵提供值,模型將預測花朵屬於以下三個美麗變種之中的哪一個:
▲ 從左到右依次為:山鳶尾(Radomil 攝影,CC BY-SA 3.0)、變色鳶尾(Dlanglois 攝影,CC BY-SA 3.0)和維吉尼亞鳶尾(Frank Mayfield 攝影,CC BY-SA 2.0)。
我們將使用下面的結構訓練深度神經網路分類器。所有輸入和輸出值都是 float32,輸出值的總和將等於 1(因為我們在預測屬於每種鳶尾花的可能性):
例如,輸出結果對山鳶尾來說可能是 0.05,對變色鳶尾是 0.9,對維吉尼亞鳶尾是 0.05,表示這種花有 90% 的可能性是變色鳶尾。
好了!我們現在已經定義模型,接下來看一看如何使用資料集和估算器訓練模型和進行預測。
資料集介紹
資料集是一種為 TensorFlow 模型建立輸入管道的新方式。使用此 API 的效能要比使用 feed_dict 或佇列式管道的效能高得多,而且此 API 更簡潔,使用起來更容易。儘管資料集在 1.3 版本中仍位於 tf.contrib.data 中,但是我們預計會在 1.4 版本中將此 API 移動到核心中,所以,是時候嘗試一下了。
從高層次而言,資料集由以下類組成:
其中:
資料集:基類,包含用於建立和轉換資料集的函式。允許您從記憶體中的資料或從 Python 生成器初始化資料集。
TextLineDataset:從文字檔案中讀取各行內容。
TFRecordDataset:從 TFRecord 檔案中讀取記錄。
FixedLengthRecordDataset:從二進位制檔案中讀取固定大小的記錄。
迭代器:提供了一種一次獲取一個資料集元素的方法。
我們的資料集
首先,我們來看一下要用來為模型提供資料的資料集。我們將從一個 CSV 檔案讀取資料,這個檔案的每一行都包含五個值 - 四個輸入值,加上標籤:
標籤的值如下所述:
山鳶尾為 0
變色鳶尾為 1
維吉尼亞鳶尾為 2。
表示我們的資料集
為了說明我們的資料集,我們先來建立一個特徵列表:
feature_names = [ 'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
在訓練模型時,我們需要一個可以讀取輸入檔案並返回特徵和標籤資料的函式。估算器要求您建立一個具有以下格式的函式:
def input_fn(): ...<code>... return ({ 'SepalLength':[values], ..<etc>.., 'PetalWidth':[values] }, [IrisFlowerType])
返回值必須是一個按照如下方式組織的兩元素元組:
第一個元素必須是一個字典(其中的每個輸入特徵都是一個鍵),然後是一個用於訓練批次的值列表。
第二個元素是一個用於訓練批次的標籤列表。
由於我們要返回一批輸入特徵和訓練標籤,返回語句中的所有列表都將具有相同的長度。從技術角度而言,我們在這裡說的“列表”實際上是指 1-d TensorFlow 張量。
為了方便重複使用 input_fn,我們將向其中新增一些引數。這樣,我們就可以使用不同設定構建輸入函式。引數非常直觀:
file_path:要讀取的資料檔案。
perform_shuffle:是否應將記錄順序隨機化。
repeat_count:在資料集中迭代記錄的次數。例如,如果我們指定 1,那麼每個記錄都將讀取一次。如果我們不指定,迭代將永遠持續下去。
下面是我們使用 Dataset API 實現此函式的方式。我們會將它包裝到一個“輸入函式”中,這個輸入函式稍後將用於為我們的估算器模型提供資料:
def my_input_fn(file_path, perform_shuffle=False, repeat_count=1): def decode_csv(line): parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]]) label = parsed_line[-1:] # Last element is the label del parsed_line[-1] # Delete last element features = parsed_line # Everything (but last element) are the features d = dict(zip(feature_names, features)), label return d dataset = (tf.contrib.data.TextLineDataset(file_path) # Read text file .skip(1) # Skip header row .map(decode_csv)) # Transform each elem by applying decode_csv fn if perform_shuffle: # Randomizes input using a window of 256 elements (read into memory) dataset = dataset.shuffle(buffer_size=256) dataset = dataset.repeat(repeat_count) # Repeats dataset this # times dataset = dataset.batch(32) # Batch size to use iterator = dataset.make_one_shot_iterator() batch_features, batch_labels = iterator.get_next() return batch_features, batch_labels
注意以下內容:
TextLineDataset:在您使用 Dataset API 的檔案式資料集時,它將為您執行大量的記憶體管理工作。例如,您可以讀入比記憶體大得多的資料集檔案,或者以引數形式指定列表,讀入多個檔案。
shuffle:讀取 buffer_size 記錄,然後打亂(隨機化)它們的順序。
map:呼叫 decode_csv 函式,並將資料集中的每個元素作為一個引數(由於我們使用的是 TextLineDataset,每個元素都將是一行 CSV 文字)。然後,我們將向每一行應用 decode_csv 。
decode_csv:將每一行拆分成各個欄位,根據需要提供預設值。然後,返回一個包含欄位鍵和欄位值的字典。map 函式將使用字典更新資料集中的每個元素(行)。
以上是資料集的簡單介紹!為了娛樂一下,我們現在可以使用下面的函式列印第一個批次:
next_batch = my_input_fn(FILE, True) # Will return 32 random elements # Now let's try it out, retrieving and printing one batch of data. # Although this code looks strange, you don't need to understand # the details. with tf.Session() as sess: first_batch = sess.run(next_batch) print(first_batch) # Output ({'SepalLength': array([ 5.4000001, ...<repeat to 32 elems>], dtype=float32), 'PetalWidth': array([ 0.40000001, ...<repeat to 32 elems>], dtype=float32), ... }, [array([[2], ...<repeat to 32 elems>], dtype=int32) # Labels )
這就是我們需要 Dataset API 在實現模型時所做的全部工作。不過,資料集還有很多功能;請參閱我們在“閱讀原文”末尾列出的更多資源。
估算器介紹
估算器是一種高階 API,使用這種 API,您在訓練 TensorFlow 模型時就不再像之前那樣需要編寫大量的樣板檔案程式碼。估算器也非常靈活,如果您對模型有具體的要求,它允許您替換預設行為。
使用估算器,您可以通過兩種可能的方式構建模型:
預製估算器 - 這些是預先定義的估算器,旨在生成特定型別的模型。在這篇博文中,我們將使用 DNNClassifier 預製估算器。
估算器(基類)- 允許您使用 model_fn 函式完全掌控模型的建立方式。我們將在單獨的博文中介紹如何操作。
下面是估算器的類圖:
我們希望在未來版本中新增更多的預製估算器。
正如您所看到的,所有估算器都使用 input_fn,它為估算器提供輸入資料。在我們的示例中,我們將重用 my_input_fn,這個函式是我們專門為演示定義的。
下面的程式碼可以將預測鳶尾花型別的估算器例項化:
# Create the feature_columns, which specifies the input to our model. # All our input features are numeric, so use numeric_column for each one. feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names] # Create a deep neural network regression classifier. # Use the DNNClassifier pre-made estimator classifier = tf.estimator.DNNClassifier( feature_columns=feature_columns, # The input features to our model hidden_units=[10, 10], # Two layers, each with 10 neurons n_classes=3, model_dir=PATH) # Path to where checkpoints etc are stored
我們現在有了一個可以開始訓練的估算器。
檢視全文(“訓練模型”“擴充”“總結”“資源列表”等)及文中連結,請點選文末“閱讀原文”。
相關文章
- 深入淺出MyBatis:JDBC和MyBatis介紹MyBatisJDBC
- 深入淺出MongoDB複製【經典介紹】MongoDB
- 深入淺出JMS(一)——JMS簡單介紹
- MNIST資料集介紹
- Cora 資料集介紹
- nuPlan資料集介紹
- 深入淺出資料字典摘要
- Tensorflow介紹和安裝
- 使用 TensorFlow Hub 和估算器構建文字分類模型文字分類模型
- 無監控,不運維!深入淺出介紹ChengYing監控設計和使用運維
- 介紹Ext JS 4.2的新特性的《深入淺出Ext JS》上市JS
- Oracle資料庫字符集介紹Oracle資料庫
- 深入淺出React和ReduxReactRedux
- 深入淺出Redis-redis哨兵叢集Redis
- 深入淺出 Java 同步器Java
- Redis介紹、使用、資料結構和叢集模式總結Redis資料結構模式
- 深入淺出 Runtime(二):資料結構資料結構
- 深入淺出FE(十四)深入淺出websocketWeb
- 深入淺出大端和小端
- 深入淺出Websocket(二)分散式Websocket叢集Web分散式
- 深入淺出搜尋系列文章彙集
- 評侯捷的《深入淺出MFC》和李久進的《MFC深入淺出》
- 深入淺出瀏覽器渲染原理瀏覽器
- 深入淺出Lua虛擬機器虛擬機
- 深入淺出 Vue 系列 -- 資料劫持實現原理Vue
- 深入淺出:5G和HTTPHTTP
- TensorFlow 入門(MNIST資料集)
- RxRetrofit – 終極封裝 – 深入淺出 & 資料快取封裝快取
- 深入淺出:瞭解時序資料庫 InfluxDB資料庫UX
- 深入淺出FaaS應用場景:資料編排
- RxRetrofit - 終極封裝 - 深入淺出 & 資料快取封裝快取
- 資料庫設計正規化深入淺出(轉)資料庫
- 深入iOS系統底層之指令集介紹iOS
- TensorFlow除錯程式介紹除錯
- 深入淺出——MVCMVC
- 深入淺出mongooseGo
- HTTP深入淺出HTTP
- 深入淺出IO