深度學習2.0-25.Train-Val-Test劃分檢測過擬合(交叉驗證)
1.train_val劃分
實戰
# train_val
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
def preprocess(x, y):
"""
x is a simple image, not a batch
"""
x = tf.cast(x, dtype=tf.float32) / 255.
x = tf.reshape(x, [28 * 28])
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
return x, y
batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)
sample = next(iter(db))
print(sample[0].shape, sample[1].shape)
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary()
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
network.fit(db, epochs=5, validation_data=ds_val,
validation_steps=2)
network.evaluate(ds_val)
sample = next(iter(ds_val))
x = sample[0]
y = sample[1] # one-hot
pred = network.predict(x) # [b, 10]
# convert back to number
y = tf.argmax(y, axis=1)
pred = tf.argmax(pred, axis=1)
print(pred)
print(y)
2.train_val_test
# train_val_test
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
def preprocess(x, y):
"""
x is a simple image, not a batch
"""
x = tf.cast(x, dtype=tf.float32) / 255.
x = tf.reshape(x, [28 * 28])
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
return x, y
batchsz = 128
(x, y), (x_test, y_test) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())
idx = tf.range(60000)
idx = tf.random.shuffle(idx)
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
x_val, y_val = tf.gather(x, idx[-10000:]), tf.gather(y, idx[-10000:])
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.map(preprocess).shuffle(50000).batch(batchsz)
db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
db_val = db_val.map(preprocess).shuffle(10000).batch(batchsz)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(preprocess).batch(batchsz)
sample = next(iter(db_train))
print(sample[0].shape, sample[1].shape)
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary()
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
network.fit(db_train, epochs=6, validation_data=db_val, validation_freq=2)
print('Test performance:')
network.evaluate(db_test)
sample = next(iter(db_test))
x = sample[0]
y = sample[1] # one-hot
pred = network.predict(x) # [b, 10]
# convert back to number
y = tf.argmax(y, axis=1)
pred = tf.argmax(pred, axis=1)
print(pred)
print(y)
3.交叉驗證
程式碼部分有問題
相關文章
- 如何理解過擬合、正則化和交叉驗證
- 深度學習中的欠擬合和過擬合簡介深度學習
- Tensorflow-交叉熵&過擬合熵
- 深度學習:乳腺x檢測深度學習
- 深度學習中“過擬合”的產生原因和解決方法深度學習
- 深度學習之目標檢測深度學習
- 使用深度學習檢測瘧疾深度學習
- 深度學習之瑕疵缺陷檢測深度學習
- 深度學習“吃雞外掛”——目標檢測 SSD 實驗深度學習
- 【深度學習篇】--神經網路中的調優二,防止過擬合深度學習神經網路
- 52 個深度學習目標檢測模型深度學習模型
- 用深度學習進行欺詐檢測深度學習
- 使用深度學習的交通標誌檢測深度學習
- 深度學習框架跑分測驗(TensorFlow/Caffe/MXNet/Keras/PyTorch)深度學習框架KerasPyTorch
- 演算法學習之路|檢驗身份證演算法
- Pytorch_第八篇_深度學習 (DeepLearning) 基礎 [4]---欠擬合、過擬合與正則化PyTorch深度學習
- 深度學習之影像目標檢測速覽深度學習
- 【深度學習】檢測CUDA、cuDNN、Pytorch是否可用深度學習DNNPyTorch
- 理解「交叉驗證」(Cross Validation)ROS
- 網路模型的交叉驗證模型
- K重交叉驗證和網格搜尋驗證
- 我通過OCP認證的學習經驗
- 應用統計學與R語言實現學習筆記(七)——擬合優度檢驗R語言筆記
- 交叉驗證(Cross validation)總結ROS
- 基於深度學習的停車場車輛檢測演算法matlab模擬深度學習演算法Matlab
- 深度學習之目標檢測與目標識別深度學習
- 全面學習MySQL中的檢視(1) 檢視安全驗證的方式MySql
- 《神經網路和深度學習》系列文章二十五:過擬合與正則化(2)神經網路深度學習
- 模型評估與改進:交叉驗證模型
- 交叉驗證(Cross Validation)原理小結ROS
- Python高效深度學習機器識別驗證碼教程分享Python深度學習
- 深度學習目標檢測(object detection)系列(六)YOLO2深度學習ObjectYOLO
- 深度學習目標檢測(object detection)系列(五) R-FCN深度學習Object
- 深度學習目標檢測(object detection)系列(一) R-CNN深度學習ObjectCNN
- faced:基於深度學習的CPU實時人臉檢測深度學習
- 深度學習與CV教程(13) | 目標檢測 (SSD,YOLO系列)深度學習YOLO
- 揭祕人工智慧(系列):深度學習是否過分誇大?人工智慧深度學習
- 機器學習–過度擬合 欠擬合機器學習