面向初學者的快速入門tensorflow

山河執手發表於2021-01-03
# 將tensorflow載入程式
import tensorflow as tf

# 載入並準備好資料集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 將樣本從整數轉換為浮點數
x_train, x_test = x_train / 255.0, x_test / 255.0
#搭建tf.keras.models.Sequential模型
model=tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28)),#Flatten:變平  將多維陣列變成一位陣列(784)
    tf.keras.layers.Dense(128,activation='relu'),#Dense:降維 設定輸出節點數為128,啟用函式型別為Relu
    tf.keras.layers.Dropout(0.2),#Dropout:隨機拋棄,防止過擬合
    tf.keras.layers.Dense(10,activation='softmax')
])
#為訓練選擇優化器,損失函式,度量
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
#訓練並驗證模型
model.fit(x_train,y_train,epochs=5)
model.evaluate(x_test,y_test,verbose=2)#verbose = 2 為每個epoch輸出一行記錄

model.save_weights('D:weight', save_format='tf')  # 儲存模型

相關文章