tensorflow2.0在訓練資料集的時候,fit和fit_generator的使用

狙擊 妳吢發表於2020-11-24

model.fit函式

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,
    validation_split=0.0, validation_data=None, shuffle=True, class_weight=None,
    sample_weight=None, initial_epoch=0, steps_per_epoch=None,
    validation_steps=None, validation_batch_size=None, validation_freq=1,
    max_queue_size=10, workers=1, use_multiprocessing=False
)

為模型訓練固定的批次(資料集上的迭代)。

引數:

引數作用
x輸入資料.它可能是:Numpy陣列(或類似陣列的陣列)或陣列列表(如果模型具有多個輸入)。TensorFlow張量或張量列表(如果模型具有多個輸入)。如果模型已命名輸入,則dict將輸入名稱對映到相應的陣列/張量。一個tf.data資料集。應該返回(inputs, targets)或 的元組(inputs, targets, sample_weights)。生成器或keras.utils.Sequence返回(inputs, targets) 或(inputs, targets, sample_weights)。下面給出了迭代器型別(資料集,生成器,序列)的拆包行為的更詳細描述。
y目標資料。像輸入資料一樣x,它可以是Numpy陣列或TensorFlow張量。它應該與x(您不能有Numpy輸入和張量目標,或者相反)保持一致。如果x是keras.utils.Sequence,y則不應該指定,生成器或例項(因為將從中獲取目標x)。
batch_size整數或None。每個梯度更新的樣本數。如果未指定,則預設為32。如果資料是以資料集,生成器或例項的形式(因為它們生成批次),則不要指定。 batch_sizebatch_sizekeras.utils.Sequence
epochs整數。訓練模型的時期數。時期是整個x和所y 提供資料的迭代。請注意,在與結合, 應理解為“最後時期”。不會針對給出的多次迭代訓練模型,而只是對到達索引的時期進行訓練。 initial_epochepochsepochsepochs
verbose0、1或2。詳細模式。0 =靜音,1 =進度條,2 =每個時期一行。請注意,進度條在登入到檔案時不是特別有用,因此,如果不以互動方式執行(例如,在生產環境中),建議使用verbose = 2。
callbackskeras.callbacks.Callback例項 列表。訓練期間要應用的回撥列表。請參閱tf.keras.callbacks。
validation_split在0到1之間浮動。將訓練資料的分數用作驗證資料。模型將分開訓練資料的這一部分,不對其進行訓練,並且將在每個時期結束時評估此資料的損失和任何模型度量。在改組之前,從x和中y提供的最後一個樣本中選擇驗證資料。當x是資料集,生成器或keras.utils.Sequence例項時, 不支援此引數。
validation_data在每個時期結束時用於評估損失的資料和任何模型指標。該模型將不會根據此資料進行訓練。因此,請注意以下事實:使用 或不受正則化層(如噪聲和壓降)影響的資料驗證損失。 將覆蓋。 可能: validation_splitvalidation_datavalidation_datavalidation_splitvalidation_data 1.(x_val, y_val)Numpy陣列或張量的元組. 2(x_val, y_val, val_sample_weights)Numpy陣列的元組 .資料集對於前兩種情況,batch_size必須提供。對於最後一種情況,validation_steps可以提供。請注意,validation_data它並不支援xdict,generator或中支援的所有資料型別keras.utils.Sequence。
shuffle布林值(是否在每個紀元之前改組訓練資料)或str(用於“批處理”)。當x是生成器時,將忽略此引數。“批處理”是處理HDF5資料限制的特殊選項;它以批量大小的塊洗牌。當沒有任何效果是不是。 steps_per_epochNone
class_weight可選的字典對映類索引(整數)到權重(浮動)值,用於加權損失函式(僅在訓練期間)。這可能有助於告訴模型“更多關注”來自代表性不足的類的樣本。
sample_weight訓練樣本的可選Numpy權重陣列,用於加權損失函式(僅在訓練過程中)。您可以傳遞長度與輸入樣本相同的平坦(1D)Numpy陣列(權重和樣本之間的1:1對映),或者對於時間資料,可以傳遞帶有shape的2D陣列 以應用每個樣品的每個時間步均具有不同的權重。如果是資料集,生成器或 例項,而將sample_weights作為的第三個元素,則不支援此引數。 (samples, sequence_length)xkeras.utils.Sequencex
initial_epoch整數。開始訓練的時期(用於恢復以前的訓練執行)。
steps_per_epoch整數或None。宣告一個紀元完成並開始下一個紀元之前的總步數(一批樣品)。使用輸入張量(例如TensorFlow資料張量)進行訓練時,預設None值等於資料集中的樣本數除以批次大小;如果無法確定,則預設為1。如果x是 tf.data資料集,並且’steps_per_epoch’為None,則該紀元將執行直到輸入資料集用盡。傳遞無限重複的資料集時,必須指定 引數。陣列輸入不支援此引數。 steps_per_epoch
validation_steps僅在提供時才相關,並且是資料集。在每個時期結束時執行驗證時,在停止之前要繪製的步驟總數(樣本批次)。如果“ validation_steps”為“無”,則驗證將一直進行到資料集用盡。如果是無限重複的資料集,它將陷入無限迴圈。如果指定了“ validation_steps”,並且僅消耗了一部分資料集,則評估將在每個時期從資料集的開頭開始。這樣可以確保每次都使用相同的驗證樣本。 validation_datatf.datavalidation_data
validation_batch_size整數或None。每個驗證批次的樣品數量。如果未指定,則預設為。不要指定資料是資料集,生成器還是例項的形式(因為它們生成批處理)。 batch_sizevalidation_batch_sizekeras.utils.Sequence
validation_freq僅在提供驗證資料時才相關。整數或例項(例如列表,元組等)。如果為整數,則指定在執行新的驗證執行之前要執行多少個訓練時期,例如,每2個時期執行一次驗證。如果是容器,則指定要執行驗證的時期,例如,在第一個,第二個和第十個時期的末尾執行驗證。 collections_abc.Containervalidation_freq=2validation_freq=[1, 2, 10]
max_queue_size整數。keras.utils.Sequence 僅用於生成器或輸入。生成器佇列的最大大小。如果未指定,則預設為10。 max_queue_size
workers整數。keras.utils.Sequence僅用於生成器或輸入。使用基於程式的執行緒時,要啟動的最大程式數。如果未指定,workers 則預設為1。如果為0,將在主執行緒上執行生成器。
use_multiprocessing布林值。keras.utils.Sequence僅用於生成器或 輸入。如果為True,請使用基於程式的執行緒。如果未指定,則預設為 。請注意,由於此實現依賴於多處理,因此不應將不可拾取的引數傳遞給生成器,因為它們無法輕易傳遞給子程式。 use_multiprocessingFalse

類似於迭代器的輸入的拆包行為:一種常見的模式是將tf.data.Dataset,generator或tf.keras.utils.Sequence傳遞給fitx引數,這實際上不僅會產生特徵(x),而且會產生可選結果目標(y)和樣本權重。Keras要求此類類似迭代器的輸出必須明確。迭代器應返回長度為1、2或3的元組,其中可選的第二和第三元素將分別用於y和sample_weight。提供的任何其他型別將被包裹在一個元組的長度中,從而將所有內容有效地視為“ x”。發出命令時,它們仍應遵循頂級元組結構。例如({“x0”: x0, “x1”: x1}, y)。Keras不會嘗試從單個字典的鍵中分離特徵,目標和權重。值得注意的不受支援的資料型別是namedtuple。原因是它的行為類似於有序資料型別(元組)和對映資料型別(dict)。因此,給定形式的namedtuple: namedtuple(“example_tuple”, [“y”, “x”]) 在解釋值時是否反轉元素的順序是不明確的。更糟糕的是以下形式的元組: namedtuple(“other_tuple”, [“x”, “y”, “z”]) 尚不清楚該元組是否打算解包為x,y和sample_weight或作為單個元素傳遞給x。結果,如果資料處理程式碼遇到一個命名元組,它將僅引發ValueError。(以及糾正該問題的說明。)

函式返回:
歷史記錄物件。它的History.history屬性記錄了連續時期的訓練損失值和度量值,以及驗證損失值和驗證度量值(如果適用)。

fit_generator函式

fit_generator(
    generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None,
    validation_data=None, validation_steps=None, validation_freq=1,
    class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False,
    shuffle=True, initial_epoch=0
)

使模型適合Python生成器逐批生成的資料。

在使用tensorflow2.0.0版本的時候的,去執行了tensorflow2.2.0版本的yoloV4程式碼,導致報錯出現:
在這裡插入圖片描述
在這裡插入圖片描述

TypeError: int() argument must be a string, a bytes-like object or a number, not ‘tuple’

原來是因為在2.0.0版本的時候Model.fit不支援生成器建立的資料集,因此會出現錯誤。
把Model.fit函式換成Model.fit_generator函式,則成功解決這個問題。

如果幫助到您,點贊論評走一波!!!!!
謝謝各位大佬!!!!

相關文章