使用 TensorBoard 的簡單方法
在 TensorFlow 中，有時我們想用 TensorBoard 來監視訓練過程中一些指標（例如 loss、accuracy）的變化。下面給出一個小例子，用到的函式有 tf.summary.scalar()、tf.summary.FileWriter()，以及 FileWriter 物件的 add_summary() 方法（下文中的 file_writer.add_summary()）。
程式碼
tf.summary.scalar()
# Loss: mean sparse softmax cross-entropy over the batch, exported to
# TensorBoard as the scalar summary 'log_loss'.
with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y,
        logits=logits,
    )
    loss = tf.reduce_mean(xentropy, name="loss")
    loss_summary = tf.summary.scalar('log_loss', loss)
def log_dir(prefix="", root_logdir="tf_logs"):
    """Return a fresh, timestamped log directory path for one training run.

    The path has the form ``"<root_logdir>/[<prefix>-]run-<YYYYmmddHHMMSS>/"``,
    built from the current UTC time, so every run gets its own subdirectory.

    Args:
        prefix: Optional label placed before ``run-`` (joined with ``-``).
            A falsy value omits it.
        root_logdir: Parent directory (the one passed to
            ``tensorboard --logdir``). Defaults to ``"tf_logs"``.

    Returns:
        The path as a string with a trailing ``/``. The directory itself is
        not created here.
    """
    from datetime import timezone  # local import: the file only imports `datetime`

    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now"
    # formats to exactly the same digits.
    now = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
    name = (prefix + "-" if prefix else "") + "run-" + now
    return "{}/{}/".format(root_logdir, name)
# Create a per-run log directory and a writer that also records the graph.
logdir = log_dir("logreg")
file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())
with tf.Session() as sess:
    # Excerpt: variable initialization and the training loop are omitted.
    # In the full script below this runs once per epoch, which is where
    # `epoch` comes from.
    accuracy_val, loss_val, accuracy_summary_str, loss_summary_str = sess.run(
        [accuracy, loss, accuracy_summary, loss_summary],
        feed_dict={X: mnist.validation.images, y: mnist.validation.labels})
    # Write BOTH evaluated summaries (the loss one was previously dropped);
    # the second argument is the global step shown on TensorBoard's x-axis.
    file_writer.add_summary(accuracy_summary_str, epoch)
    file_writer.add_summary(loss_summary_str, epoch)
# Flush pending events so TensorBoard sees them.
file_writer.close()
tensorboard --logdir=tf_logs
以下是上述程式碼片段的出處：一個可以完整執行的 MNIST 手寫數字識別程式（使用 tf.contrib、tf.Session 等 API，需在 TensorFlow 1.x 下執行）。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import tensorflow as tf
from datetime import datetime
import os
from tensorflow.examples.tutorials.mnist import input_data
# Read MNIST from ../data/ and expose the train/test splits; labels are cast
# to integer class ids for the sparse cross-entropy loss.
mnist = input_data.read_data_sets("../data/")
X_train = mnist.train.images
X_test = mnist.test.images
y_train = mnist.train.labels.astype("int")
y_test = mnist.test.labels.astype("int")

# Network dimensions.
n_inputs = 28 * 28  # MNIST: flattened 28x28 images
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10

X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
# `(None)` is just `None` (completely unknown shape); `(None,)` declares the
# intended 1-D vector of labels, one per example in the batch.
y = tf.placeholder(tf.int64, shape=(None,), name="y")
# Two fully connected ReLU hidden layers using He (variance-scaling)
# initialization, followed by a linear output layer that produces the
# unnormalized logits consumed by the loss and eval ops.
with tf.name_scope("dnn"):
    he_init = tf.contrib.layers.variance_scaling_initializer()
    hidden1 = tf.layers.dense(X, n_hidden1, name="hidden1",
                              kernel_initializer=he_init,
                              activation=tf.nn.relu)
    hidden2 = tf.layers.dense(hidden1, n_hidden2, name="hidden2",
                              kernel_initializer=he_init,
                              activation=tf.nn.relu)
    logits = tf.layers.dense(hidden2, n_outputs, name="outputs",
                             kernel_initializer=he_init)
# --- Loss: mean sparse softmax cross-entropy, logged as 'log_loss'. ---
with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y,
        logits=logits,
    )
    loss = tf.reduce_mean(xentropy, name="loss")
    loss_summary = tf.summary.scalar('log_loss', loss)

# --- Optimizer: plain gradient descent on the loss above. ---
learning_rate = 0.01
with tf.name_scope("train"):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    training_op = optimizer.minimize(loss)

# --- Evaluation: top-1 accuracy, logged as 'accuracy'. ---
with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(predictions=logits, targets=y, k=1)
    accuracy = tf.reduce_mean(tf.cast(correct, dtype=tf.float32))
    accuracy_summary = tf.summary.scalar('accuracy', accuracy)
# Graph-wide variable initializer and a saver used for checkpoints.
init = tf.global_variables_initializer()
saver = tf.train.Saver()

# Training schedule: number of passes over the data and mini-batch size.
n_epochs, batch_size = 40, 50
def log_dir(prefix="", root_logdir="tf_logs"):
    """Return a fresh, timestamped log directory path for one training run.

    The path has the form ``"<root_logdir>/[<prefix>-]run-<YYYYmmddHHMMSS>/"``,
    built from the current UTC time, so every run gets its own subdirectory.

    Args:
        prefix: Optional label placed before ``run-`` (joined with ``-``).
            A falsy value omits it.
        root_logdir: Parent directory (the one passed to
            ``tensorboard --logdir``). Defaults to ``"tf_logs"``.

    Returns:
        The path as a string with a trailing ``/``. The directory itself is
        not created here.
    """
    from datetime import timezone  # local import: the file only imports `datetime`

    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now"
    # formats to exactly the same digits.
    now = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
    name = (prefix + "-" if prefix else "") + "run-" + now
    return "{}/{}/".format(root_logdir, name)
# TensorBoard event writer for this run; it also records the graph.
logdir = log_dir("logreg")
file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())

# Resumable checkpoint, the file holding the epoch to resume from, and the
# best model found so far.
checkpoint_path = "./tmp/my_logreg_model.ckpt"
checkpoint_epoch_path = checkpoint_path + ".epoch"
final_model_path = "./my_logreg_model"
# The checkpoint and its epoch file are written into ./tmp, which must exist.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Early-stopping state. np.infty was removed in NumPy 2.0; np.inf is the
# canonical name.
best_loss = np.inf
epochs_without_progress = 0
max_epochs_without_progress = 50
# Train the network and log validation accuracy/loss to TensorBoard every
# epoch. Every 5th epoch: save a checkpoint plus the epoch to resume from,
# save the best model so far to final_model_path, and stop early once the
# validation loss has not improved for more than max_epochs_without_progress
# consecutive epochs.
with tf.Session() as sess:
    if os.path.isfile(checkpoint_epoch_path):
        # A previous run was interrupted: restore its weights and resume at
        # the epoch number it recorded.
        with open(checkpoint_epoch_path, "rb") as f:
            start_epoch = int(f.read())
        print("Training was interrupted. Continuing at epoch", start_epoch)
        saver.restore(sess, checkpoint_path)
    else:
        start_epoch = 0
        sess.run(init)
    # Do not run `init` again here: re-initializing after saver.restore()
    # would discard the restored weights.

    # Resume from start_epoch rather than always restarting at 0.
    for epoch in range(start_epoch, n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            # Only the training op is needed per batch; summaries are
            # evaluated once per epoch on the validation set below.
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})

        accuracy_val, loss_val, accuracy_summary_str, loss_summary_str = sess.run(
            [accuracy, loss, accuracy_summary, loss_summary],
            feed_dict={X: mnist.validation.images, y: mnist.validation.labels})
        file_writer.add_summary(accuracy_summary_str, epoch)
        file_writer.add_summary(loss_summary_str, epoch)

        if epoch % 5 == 0:
            print("epoch:", epoch, "\tVal accuracy:{:.3f}%".format(
                accuracy_val * 100), "\tLoss:{:.5f}".format(loss_val))
            saver.save(sess, checkpoint_path)
            # Record the epoch to resume from if the process is interrupted.
            with open(checkpoint_epoch_path, "wb") as f:
                f.write(b"%d" % (epoch + 1))
            if loss_val < best_loss:
                saver.save(sess, final_model_path)
                best_loss = loss_val
                # Count only *consecutive* checks without improvement.
                epochs_without_progress = 0
            else:
                # Checks happen every 5 epochs, so each miss counts as 5.
                epochs_without_progress += 5
                if epochs_without_progress > max_epochs_without_progress:
                    print("Early stopping")
                    break

    # Training finished or stopped early: drop the resume marker so the next
    # run starts fresh. It is absent if no epoch ran (e.g. n_epochs == 0).
    if os.path.isfile(checkpoint_epoch_path):
        os.remove(checkpoint_epoch_path)

# Flush all pending events so TensorBoard sees the complete run.
file_writer.close()
# tensorboard --logdir=tf_logs
然後在命令列中執行 tensorboard --logdir=tf_logs，再用瀏覽器開啟 TensorBoard 顯示的網址（預設為 http://localhost:6006），即可查看 accuracy 與 log_loss 的曲線。
相關文章
- Tensorboard的使用ORB
- github的簡單使用方法Github
- (11)tensorboard的使用ORB
- tensorboard 使用ORB
- sqlmap簡單使用方法SQL
- Lumen/Laravel 使用 alipay 最簡單的方法Laravel
- 最簡單mysql的使用方法(轉)MySql
- iOS UICollectionView的簡單使用和常用代理方法iOSUIView
- 使用簡單方法排除路由器的故障路由器
- BBED的安裝及簡單的使用方法
- 簡單的排序方法排序
- 在ASP.NET中使用AJAX的簡單方法ASP.NET
- 簡單的方法掌握JS中slice,splice和split的使用方法JS
- Tensorboard 在伺服器上的使用ORB伺服器
- MySQL傳輸表空間的簡單使用方法MySql
- Kdevelop的簡單使用和簡單除錯dev除錯
- docker的簡單使用Docker
- postman的簡單使用Postman
- RecyclerView的簡單使用View
- git的簡單使用Git
- LayUi的簡單使用UI
- RocketMQ的簡單使用MQ
- Vue簡單的使用Vue
- Cookie的簡單使用Cookie
- HttpClient的簡單使用HTTPclient
- explain for 的簡單使用AI
- OD的簡單使用
- 使用原生 cookieStore 方法,讓 Cookie 操作更簡單Cookie
- 懶載入簡單的方法
- android 簡單的分享方法Android
- 關於axios以及jsonp的簡單使用方法iOSJSON
- 最簡單實現跨域的方法:使用nginx反向代理跨域Nginx
- jQuery外掛Tmpl使用方法簡單介紹jQuery
- shell script的簡單使用
- uuid的簡單使用UI
- Mackdown簡單的使用教程Mac
- react hooks 的簡單使用ReactHook
- vue框架的簡單使用Vue框架