網路模型的交叉驗證
網路模型的交叉驗證
資料集的內容一般分為:trainset、valset和testset,將獲取的mnist資料集的(x,y)部分拆分為[x_train, y_train](訓練資料)和[x_val, y_val](評測資料)
#把[x, y]拆分成[x_train, y_train]和[x_val, y_val]
idx = tf.range(60000)
idx = tf.random.shuffle(idx)
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000]) #[0, 50000]
x_val, y_val = tf.gather(x, idx[-10000:]) , tf.gather(y, idx[-10000:]) #[50000,60000]
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
用這兩部分資料集進行邊訓練,邊測評,能夠檢視當前訓練的指標效果。可以檢視每次訓練後評價指標,找出最好的引數。
#一共training6次,每2次評測一次
# db_train:training資料集, epochs:資料集的training的次數, validation_data:用於評測的資料集, validation_freq:每n個epoch做一次評測
network.fit(db_train, epochs=6, validation_data=db_val, validation_freq=2)
用(x_test, y_test)資料集對模型進行進行測試。
network.evaluate(db_test)
完整程式碼:
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
def preprocess(x, y):
"""
x is a simple image, not a batch
"""
x = tf.cast(x, dtype=tf.float32) / 255.
x = tf.reshape(x, [28*28])
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
return x,y
batchsz = 128
(x, y), (x_test, y_test) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())
#把[x, y]拆分成[x_train, y_train]和[x_val, y_val]
idx = tf.range(60000)
idx = tf.random.shuffle(idx)
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000]) #[0, 50000]
x_val, y_val = tf.gather(x, idx[-10000:]) , tf.gather(y, idx[-10000:]) #[50000,60000]
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
db_train = tf.data.Dataset.from_tensor_slices((x_train,y_train))
db_train = db_train.map(preprocess).shuffle(50000).batch(batchsz)
db_val = tf.data.Dataset.from_tensor_slices((x_val,y_val))
db_val = db_val.map(preprocess).shuffle(10000).batch(batchsz)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(preprocess).batch(batchsz)
sample = next(iter(db_train))
print(sample[0].shape, sample[1].shape)
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
#一共training6次,每2次評測一次
# db_train:training資料集, epochs:資料集的training的次數, validation_data:用於評測的資料集, validation_freq:每n個epoch做一次評測
network.fit(db_train, epochs=6, validation_data=db_val, validation_freq=2)
#測試的資料集
print('Test performance:')
network.evaluate(db_test)
sample = next(iter(db_test))
x = sample[0]
y = sample[1] # one-hot
pred = network.predict(x) # [b, 10]
# convert back to number
y = tf.argmax(y, axis=1)
pred = tf.argmax(pred, axis=1)
print(pred)
print(y)
相關文章
- 模型評估與改進:交叉驗證模型
- K重交叉驗證和網格搜尋驗證
- TALK的網路驗證
- 理解「交叉驗證」(Cross Validation)ROS
- 交叉驗證(Cross validation)總結ROS
- 交叉驗證(Cross Validation)原理小結ROS
- 【模型評估與選擇】交叉驗證Cross-validation: evaluating estimator performance模型ROSORM
- webapi - 模型驗證WebAPI模型
- 多折交叉驗證有什麼用處
- 計算機網路驗證性實驗計算機網路
- 網路驗證碼的進化:從簡單圖文到無感驗證
- 夢亞網路驗證開源程式
- 網路驗證之授權碼使用
- 如何理解過擬合、正則化和交叉驗證
- 深度殘差收縮網路:(五)實驗驗證
- Gin 模型繫結驗證模型
- 【教程】無法驗證app需要網際網路連線以驗證是否信任開發者APP
- 如何快速全面驗證網路埠連通性
- Token的驗證原理是什麼?網路安全網路協議知識點協議
- 透過WLAN測試驗證網路的連線性
- 雙重保險——前端bootstrapValidator驗證+後臺MVC模型驗證前端bootMVC模型
- 實驗二 網路嗅探與身份認證
- 網路模型模型
- 網路滲透測試實驗二——網路嗅探與身份認證
- 基於gin的golang web開發:模型驗證GolangWeb模型
- 網路流量模型模型
- R語言邏輯迴歸、ROC曲線和十折交叉驗證R語言邏輯迴歸
- Docker 網路模型之 macvlan 詳解,圖解,實驗完整Docker模型Mac圖解
- 《神經網路的梯度推導與程式碼驗證》之LSTM前向和反向傳播的程式碼驗證神經網路梯度反向傳播
- 網路 Server 模型的演進Server模型
- win10如何取消聯網驗證_win10系統網路身份驗證提示框怎麼關閉Win10
- 計算機網路(一) --網路模型計算機網路模型
- 《神經網路的梯度推導與程式碼驗證》之vanilla RNN前向和反向傳播的程式碼驗證神經網路梯度RNN反向傳播
- 利用IPsec實現網路安全之四(CA證書實現身份驗證)
- 七層網路模型模型
- 網路I/O模型模型
- TCP/IP網路模型TCP模型
- Java 網路 IO 模型Java模型