關於Tensorflow2.0 keras的子類式多輸入多輸出

ckxllf發表於2021-03-11

  1.關鍵程式碼

  子類 API 的多輸入多輸出做法:在 `call()` 中以列表(陣列)接收多個輸入,並以列表回傳多個輸出;建立模型後,以輸入形狀的列表呼叫 `model.build()`。訓練與評估時,同樣以列表分別傳入各輸入與各輸出對應的資料。

  網路模型搭建

  class WideDeepModel(tf.keras.models.Model):
      """Wide & Deep regression model with two inputs and two outputs.

      ``call`` expects a two-element sequence of inputs:
          inputs[0] -- "wide" features, concatenated directly with the output
                       of the deep branch.
          inputs[1] -- "deep" features, passed through two ReLU Dense layers.

      ``call`` returns a two-element list:
          output1 -- Dense(1) over concat(wide features, deep hidden state).
          output2 -- Dense(1) over the deep hidden state alone.
      """

      def __init__(self, hidden_units=30, **kwargs):
          """Create the model's layers.

          Args:
              hidden_units: width of each of the two hidden Dense layers
                  (defaults to 30, the previously hard-coded value).
              **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
          """
          super().__init__(**kwargs)
          self.hidden1_layer = tf.keras.layers.Dense(hidden_units, activation='relu')
          self.hidden2_layer = tf.keras.layers.Dense(hidden_units, activation='relu')
          self.output_layer1 = tf.keras.layers.Dense(1)
          self.output_layer2 = tf.keras.layers.Dense(1)

      def call(self, inputs, training=None, mask=None):
          """Run the forward pass and return ``[output1, output2]``."""
          input_wide, input_deep = inputs[0], inputs[1]  # input 1, input 2
          hidden1 = self.hidden1_layer(input_deep)
          hidden2 = self.hidden2_layer(hidden1)
          # Same result as layers.concatenate (default axis=-1), but without
          # instantiating a new Concatenate layer on every forward pass.
          concat = tf.concat([input_wide, hidden2], axis=-1)
          output1 = self.output_layer1(concat)   # output 1
          output2 = self.output_layer2(hidden2)  # output 2
          return [output1, output2]  # combined outputs

  # Build the network: instantiate the subclassed model, then create its
  # weights by declaring one input shape per input (batch dimension = None).
  model = WideDeepModel()
  model.build(input_shape=[
      (None, 5),  # inputs[0]: wide features
      (None, 6),  # inputs[1]: deep features
  ])
  print(model.layers)
  model.summary()

  完整程式碼:

  import pprint

  import sys

  import matplotlib as mpl

  import matplotlib.pyplot as plt

  import numpy as np

  import pandas as pd

  import sklearn

  import tensorflow as tf

  from tensorflow import keras

  # Report interpreter and library versions for reproducibility.
  print(tf.__version__)
  print(sys.version_info)
  # Not every module object is guaranteed to expose __version__ (tf.keras has
  # lacked it in some TensorFlow releases), so fall back to a placeholder
  # instead of aborting the script with AttributeError.
  for module in mpl, np, pd, sklearn, keras, tf:
      print(module.__name__, getattr(module, "__version__", "unknown"))

  from sklearn.datasets import fetch_california_housing

  # 1. Load the dataset: California housing price regression
  #    (not Boston housing -- fetch_california_housing is what is loaded).
  housing = fetch_california_housing()
  print(housing.DESCR)
  print(housing.data.shape)
  print(housing.target.shape)
  # Peek at the first five samples and their targets.
  pprint.pprint(housing.data[:5])
  pprint.pprint(housing.target[:5])

  from sklearn.model_selection import train_test_split

  # 2. Split the data.
  # First hold out 20% of all samples as the test set ...
  x_train_all, x_test, y_train_all, y_test = train_test_split(
      housing.data, housing.target, test_size=0.20, random_state=7)
  # ... then hold out 20% of the remainder as the validation set.
  x_train, x_valid, y_train, y_valid = train_test_split(
      x_train_all, y_train_all, test_size=0.20, random_state=11)
  for _features, _labels in ((x_train, y_train),
                             (x_valid, y_valid),
                             (x_test, y_test)):
      print(_features.shape, _labels.shape)

  from sklearn.preprocessing import StandardScaler

  # 3. Preprocessing: standardize the features with StandardScaler.
  # The scaler is fitted on the training split only; the validation and test
  # splits are transformed with those same statistics.
  scaler = StandardScaler()
  x_train_scaled = scaler.fit_transform(x_train)
  x_valid_scaled, x_test_scaled = (
      scaler.transform(split) for split in (x_valid, x_test))

  # 4、網路模型的搭建

  # 子類API

  class WideDeepModel(tf.keras.models.Model):
      """Wide & Deep regression model with two inputs and two outputs.

      ``call`` expects a two-element sequence of inputs:
          inputs[0] -- "wide" features, concatenated directly with the output
                       of the deep branch.
          inputs[1] -- "deep" features, passed through two ReLU Dense layers.

      ``call`` returns a two-element list:
          output1 -- Dense(1) over concat(wide features, deep hidden state).
          output2 -- Dense(1) over the deep hidden state alone.
      """

      def __init__(self, hidden_units=30, **kwargs):
          """Create the model's layers.

          Args:
              hidden_units: width of each of the two hidden Dense layers
                  (defaults to 30, the previously hard-coded value).
              **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
          """
          super().__init__(**kwargs)
          self.hidden1_layer = tf.keras.layers.Dense(hidden_units, activation='relu')
          self.hidden2_layer = tf.keras.layers.Dense(hidden_units, activation='relu')
          self.output_layer1 = tf.keras.layers.Dense(1)
          self.output_layer2 = tf.keras.layers.Dense(1)

      def call(self, inputs, training=None, mask=None):
          """Run the forward pass and return ``[output1, output2]``."""
          input_wide, input_deep = inputs[0], inputs[1]  # input 1, input 2
          hidden1 = self.hidden1_layer(input_deep)
          hidden2 = self.hidden2_layer(hidden1)
          # Same result as layers.concatenate (default axis=-1), but without
          # instantiating a new Concatenate layer on every forward pass.
          concat = tf.concat([input_wide, hidden2], axis=-1)
          output1 = self.output_layer1(concat)   # output 1
          output2 = self.output_layer2(hidden2)  # output 2
          return [output1, output2]  # combined outputs

  # Build the network: instantiate the subclassed model, then create its
  # weights by declaring one input shape per input (batch dimension = None).
  model = WideDeepModel()
  model.build(input_shape=[
      (None, 5),  # inputs[0]: wide features
      (None, 6),  # inputs[1]: deep features
  ])
  print(model.layers)
  model.summary()

  # 5. Compile the model: Adam optimizer with a mean-squared-error loss
  #    (the single loss string is used for both outputs).
  model.compile(optimizer='adam',
                loss='mean_squared_error')
  # 6. Callbacks: stop training once the monitored loss has not improved by
  #    at least 1e-3 for 5 consecutive epochs.
  callbacks = [
      tf.keras.callbacks.EarlyStopping(min_delta=1e-3, patience=5),
  ]

  # 7. Train the network.
  # The wide input takes feature columns [0, 5) and the deep input takes
  # columns [2, end) -- 5 and 6 columns respectively, matching model.build.
  x_train_scaled_wide, x_train_scaled_deep = x_train_scaled[:, :5], x_train_scaled[:, 2:]
  x_valid_scaled_wide, x_valid_scaled_deep = x_valid_scaled[:, :5], x_valid_scaled[:, 2:]
  x_test_scaled_wide, x_test_scaled_deep = x_test_scaled[:, :5], x_test_scaled[:, 2:]
  # Both outputs regress the same target, so each label array is given twice.
  history = model.fit(
      x=[x_train_scaled_wide, x_train_scaled_deep],
      y=[y_train, y_train],
      validation_data=([x_valid_scaled_wide, x_valid_scaled_deep],
                       [y_valid, y_valid]),
      epochs=20,
      callbacks=callbacks,
  )

  # 8. Plot the training history.
  def plot_learning_curves(hst, ylim=(0, 1)):
      """Plot every series recorded in a Keras ``History`` object.

      Args:
          hst: the object returned by ``model.fit``; each entry of its
              ``history`` dict becomes one line in the plot.
          ylim: (bottom, top) limits for the y axis. Defaults to (0, 1) as
              before; pass None to let matplotlib autoscale, e.g. when loss
              values exceed 1 and would otherwise be clipped out of view.
      """
      ax = pd.DataFrame(hst.history).plot()
      ax.grid(True)
      if ylim is not None:
          ax.set_ylim(*ylim)
      plt.show()

  plot_learning_curves(history)

  # 9. Evaluate on the held-out test split; both outputs are scored against
  #    the same target array.
  model.evaluate(x=[x_test_scaled_wide, x_test_scaled_deep],
                 y=[y_test, y_test])


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69945560/viewspace-2762384/,如需轉載,請註明出處,否則將追究法律責任。

相關文章