使用tf.estimator.Estimator訓練神經網路

大雄沒有叮噹貓發表於2018-09-16

參考: https://github.com/aymericdamien/TensorFlow-Examples/

from __future__ import print_function

from tensorflow.examples.tutorials.mnist import input_data

 

mnist=input_data.read_data_sets("/data/machine_learning/mnist/",one_hot=False)

print(mnist.train.labels.shape)

 

import tensorflow as tf

 

#引數設定

learning_rate=0.1 #學習率

num_steps=1000 #迭代次數

batch_size=128 #批處理大小

display_step=100 #輸出間隔

 

#網路引數

n_hidden_1=256 #第一個隱藏層神經元

n_hidden_2=256 #第二個隱藏層神經元

num_input=784 #28*28

num_classes=10 #標籤類別

 

#定義網路結構

def neural_net(x_dict):

    x=x_dict['images']

    layer_1=tf.layers.dense(x,n_hidden_1) #全連線層

    layer_2=tf.layers.dense(layer_1,n_hidden_2) #全連線層

    out_layer=tf.layers.dense(layer_2,num_classes) #全連線層,輸出層

    return out_layer

 

def model_fn(features,labels,mode):

    logits=neural_net(features) #輸出

   

    #預測

    pred_classes=tf.argmax(logits,axis=1)

    pred_probas=tf.nn.softmax(logits)

   

    if mode==tf.estimator.ModeKeys.PREDICT:

        return tf.estimator.EstimatorSpec(mode,predictions=pred_classes)

   

    #定義損失和優化函式

    loss_op=tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,labels=tf.cast(labels,dtype=tf.int32)))

    optimizer=tf.train.GradientDescentOptimizer(learning_rate=learning_rate) #梯度下降優化器

    train_op=optimizer.minimize(loss_op,global_step=tf.train.get_global_step()) #最小化損失

   

    acc_op=tf.metrics.accuracy(labels=labels,predictions=pred_classes) #精度

   

    estim_specs=tf.estimator.EstimatorSpec(mode=mode,predictions=pred_classes,loss=loss_op,train_op=train_op,eval_metric_ops={'accuracy':acc_op})

    return estim_specs

 

 

model=tf.estimator.Estimator(model_fn)

 

input_fn=tf.estimator.inputs.numpy_input_fn(x={'images':mnist.train.images},y=mnist.train.labels,batch_size=batch_size,num_epochs=None,shuffle=True)

 

model.train(input_fn,steps=num_steps)

 

input_fn=tf.estimator.inputs.numpy_input_fn(x={'images':mnist.test.images},y=mnist.test.labels,batch_size=batch_size,shuffle=False)

 

e=model.evaluate(input_fn)

print("測試精度:",e['accuracy'])

 

 

相關文章