關於Tensorflow2.0 keras的子類式多輸入多輸出
1.關鍵程式碼
在定義好輸入層、輸出層後使用類 配置inputs outputs引數(陣列)
網路模型搭建
class WideDeepModel(tf.keras.models.Model):
def __init__(self):
super(WideDeepModel, self).__init__()
self.hidden1_layer = tf.keras.layers.Dense(30, activation='relu')
self.hidden2_layer = tf.keras.layers.Dense(30, activation='relu')
self.output_layer1 = tf.keras.layers.Dense(1)
self.output_layer2 = tf.keras.layers.Dense(1)
def call(self, inputs, training=None, mask=None):
"""完成模型的正向計算"""
input_wide = inputs[0] # 輸入1
input_deep = inputs[1] # 輸入2
hidden1 = self.hidden1_layer(input_deep)
hidden2 = self.hidden2_layer(hidden1)
concat = tf.keras.layers.concatenate([input_wide, hidden2])
output1 = self.output_layer1(concat) # 輸出1
output2 = self.output_layer2(hidden2) # 輸出2
return [output1, output2] # 輸出組合
# 構建網路
model = WideDeepModel()
model.build(input_shape=[(None, 5), (None, 6)])
print(model.layers)
model.summary()
完整程式碼:
import pprint
import sys
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, keras, tf:
print(module.__name__, module.__version__)
from sklearn.datasets import fetch_california_housing
# 1.載入資料集 波士頓房價預測
housing = fetch_california_housing()
print(housing.DESCR)
print(housing.data.shape)
print(housing.target.shape)
pprint.pprint(housing.data[:5])
pprint.pprint(housing.target[:5])
from sklearn.model_selection import train_test_split
# 2.拆分資料集
# 訓練集與測試集拆分
x_train_all, x_test, y_train_all, y_test = train_test_split(housing.data,
housing.target,
random_state=7,
test_size=0.20)
# 訓練集與驗證集的拆分
x_train, x_valid, y_train, y_valid = train_test_split(
x_train_all, y_train_all, random_state=11, test_size=0.20)
print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
print(x_test.shape, y_test.shape)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
# 3、資料預處理 資料集的歸一化
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)
# 4、網路模型的搭建
# 子類API
class WideDeepModel(tf.keras.models.Model):
def __init__(self):
super(WideDeepModel, self).__init__()
self.hidden1_layer = tf.keras.layers.Dense(30, activation='relu')
self.hidden2_layer = tf.keras.layers.Dense(30, activation='relu')
self.output_layer1 = tf.keras.layers.Dense(1)
self.output_layer2 = tf.keras.layers.Dense(1)
def call(self, inputs, training=None, mask=None):
"""完成模型的正向計算"""
input_wide = inputs[0] # 輸入1
input_deep = inputs[1] # 輸入2
hidden1 = self.hidden1_layer(input_deep)
hidden2 = self.hidden2_layer(hidden1)
concat = tf.keras.layers.concatenate([input_wide, hidden2])
output1 = self.output_layer1(concat)
output2 = self.output_layer2(hidden2)
return [output1, output2]
# 構建網路 大連專業人流醫院
model = WideDeepModel()
model.build(input_shape=[(None, 5), (None, 6)])
print(model.layers)
model.summary()
# 5、模型的編譯 設定損失函式 最佳化器
model.compile(loss='mean_squared_error',
optimizer='adam')
# 6、設定回撥函式
callbacks = [tf.keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3)]
# 7、訓練網路
x_train_scaled_wide = x_train_scaled[:, :5]
x_train_scaled_deep = x_train_scaled[:, 2:]
x_valid_scaled_wide = x_valid_scaled[:, :5]
x_valid_scaled_deep = x_valid_scaled[:, 2:]
x_test_scaled_wide = x_test_scaled[:, :5]
x_test_scaled_deep = x_test_scaled[:, 2:]
history = model.fit([x_train_scaled_wide, x_train_scaled_deep],
[y_train, y_train],
validation_data=(
[x_valid_scaled_wide, x_valid_scaled_deep],
[y_valid, y_valid]),
epochs=20,
callbacks=callbacks)
# 8、繪製訓練過程資料
def plot_learning_curves(hst):
pd.DataFrame(hst.history).plot()
plt.grid(True)
plt.gca().set_ylim(0, 1)
plt.show()
plot_learning_curves(history)
# 9.驗證資料
model.evaluate([x_test_scaled_wide, x_test_scaled_deep], [y_test, y_test])
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69945560/viewspace-2762384/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- StreamingPro 支援多輸入,多輸出配置
- Python3常用輸入模式:-輸入多組,固定組,多個輸入Python模式
- MR多輸入
- 關於torch.nn.LSTM()的輸入和輸出
- converter設計模式擴充套件,多種輸入輸出與標準輸入輸出的轉化方案設計模式套件
- 常用輸入輸出函式函式
- 輸出輸入函式彙總函式
- .NET Standard中配置TargetFrameworks輸出多版本類庫Framework
- 關於運放的共模輸入範圍和輸出擺幅
- FMC293-基於FMC 16路LVDS輸入或者輸出子卡
- 多種格式資料輸出
- HTML如何輸入多個空格HTML
- 輸入輸出
- 新手學python之Python的輸入輸出函式Python函式
- linux中的輸入與輸出管理(重定向輸入,輸出,管道符)Linux
- 關於輸入框的細節
- 資料的輸入輸出
- 輸入輸出流
- Keras輸出網路結構圖Keras
- kissat的多輸出-學習與修改1
- ncurses輸出函式:字元+字串的輸出函式字元字串
- ncurses輸入函式:字元+字串的輸入函式字元字串
- ACM的Python版輸入輸出ACMPython
- 嵌入式Linux—輸入子系統Linux
- 關於輸出的小語法點
- C語言_輸入輸出函式_PAGE5C語言函式
- 1.輸入輸出
- 【C++】輸入輸出C++
- 輸入輸出系統
- shell——shell輸入輸出
- Java 輸入輸出流Java
- filebeat輸出結果到elasticsearch的多個索引Elasticsearch索引
- Python資料的輸入與輸出Python
- AUTOCAD——圖形的輸入與輸出
- Java------簡單的輸入/輸出Java
- python:檔案的輸入與輸出Python
- 關於java輸入易錯點Java
- C語言之輸入輸出C語言