【NLP】TensorFlow實現CNN用於中文文字分類

widiot1發表於2018-02-04

程式碼基於 dennybritz/cnn-text-classification-tfclayandgithub/zh_cnn_text_classify
參考文章 瞭解用於NLP的卷積神經網路(譯)TensorFlow實現CNN用於文字分類(譯)
本文完整程式碼 - Widiot/cnn-zh-text-classification

1. 專案結構

以下是完整的目錄結構示例,包括執行之後形成的目錄和檔案:

cnn-zh-text-classification/
	data/
		maildata/
			cleaned_ham_5000.utf8
			cleaned_spam_5000.utf8
			ham_5000.utf8
			spam_5000.utf8
	runs/
		1517572900/
			checkpoints/
				...
			summaries/
				...
			prediction.csv
			vocab
	.gitignore
	README.md
	data_helpers.py
	eval.py
	text_cnn.py
	train.py

各個目錄及檔案的作用如下:

  • data 目錄用於存放資料
  • maildata 目錄用於存放郵件檔案,目前有四個檔案,ham_5000.utf8 及 spam_5000.utf8 分別為正常郵件和垃圾郵件,帶 cleaned 字首的檔案為清洗後的資料
  • runs 目錄用於存放每次執行產生的資料,以時間戳為目錄名
  • 1517572900 目錄用於存放每次執行產生的檢查點、日誌摘要、詞彙檔案及評估產生的結果
  • data_helpers.py 用於處理資料
  • eval.py 用於評估模型
  • text_cnn.py 是 CNN 模型類
  • train.py 用於訓練模型

2. 資料

2.1 資料格式

以分類正常郵件和垃圾郵件為例,如下是郵件資料的例子:

# 正常郵件
他們自己也是剛到北京不久 跟在北京讀書然後留在這裡工作的還不一樣 難免會覺得還有好多東西沒有安頓下來 然後來了之後還要帶著四處旅遊甚麼什麼的 卻是花費很大 你要不帶著出去玩,還真不行 這次我小表弟來北京玩,花了好多錢 就因為本來預定的幾個地方因為某種原因沒去 舅媽似乎就很不開心 結果就是錢全白花了 人家也是牢騷一肚子 所以是自己找出來的困難 退一萬步說 婆婆來幾個月
發文時難免欠點理智 我不怎麼灌水,沒想到上了十大了,拍的還挺歡,呵呵 寫這個貼子,是由於自己太鬱悶了,其時,我最主要的目的,是覺得,水木上肯定有一些嫁農村GG但現在很幸福的JJMM.我目前遇到的問題,我的確不知道怎麼解決,所以發上來,問一下成功解決這類問題的建議.因為沒有相同的經歷和體會,是不會理解的,我在我身邊就找不到可行的建議. 結果,無心得罪了不少人.呵呵,可能我想了太多關於城鄉差別的問題,意識的比較深刻,所以不經意寫了出來.
所以那些貴族1就要找一些特定的東西來章顯自己的與眾不同 這個東西一定是窮人買不起的,所以好多奢侈品也就營運誕生了 想想也是,他們要表也沒有啊, 我要是香paris hilton那麼有錢,就每天一個牌子的表,一個牌子的時裝,一個牌子的汽車,哈哈,。。。要得就是這個派 俺連表都不用, 帶手上都累贅, 上課又不能開手機, 所以俺只好經常退一下ppt去看右下腳的時間. 其實 貴族又不用趕時間, 要知道精確時間做啥? 表走的

# 垃圾郵件
中信(國際)電子科技有限公司推出新產品: 升職步步高、做生意發大財、連找情人都用的上,詳情進入 網  址:  http://www.usa5588.com/ccc 電話:020-33770208   服務熱線:013650852999
以下不能正確顯示請點此 IFRAME: http://www.ewzw.com/bbs/viewthread.php?tid=3809&fpage=1
尊敬的公司您好!打擾之處請見諒! 我深圳公司願在互惠互利、誠信為本代開3釐---2點國稅、地稅等發票。增值稅和海關繳款書就以2點---7點來代開。手機:13510631209       聯絡人:鄺先生  郵箱:ao998@163.com     祥細資料合作告知,希望合作。謝謝!!

每個句子單獨一行,正常郵件和垃圾郵件的資料分別存放在兩個檔案中。

2.2 資料處理

資料處理 data_helpers.py 的程式碼如下,與所參考的程式碼不同的是:

  • load_data_and_labels():將函式的引數修改為以逗號分隔的資料檔案的路徑字串,比如 './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8',這樣可以讀取多個類別的資料檔案以實現多分類問題
  • read_and_clean_zh_file():將函式的 output_cleaned_file 修改為 boolean 型別,控制是否儲存清洗後的資料,並在函式中判斷,如果已經存在清洗後的資料檔案則直接載入,否則進行清洗並選擇儲存

其他函式與所參考的程式碼相比變動不大:

import numpy as np
import re
import os


def load_data_and_labels(data_files):
    """
    1. 載入所有資料和標籤
    2. 可以進行多分類,每個類別的資料單獨放在一個檔案中
    2. 儲存處理後的資料
    """
    data_files = data_files.split(',')
    num_data_file = len(data_files)
    assert num_data_file > 1
    x_text = []
    y = []
    for i, data_file in enumerate(data_files):
        # 將資料放在一起
        data = read_and_clean_zh_file(data_file, True)
        x_text += data
        # 形成資料對應的標籤
        label = [0] * num_data_file
        label[i] = 1
        labels = [label for _ in data]
        y += labels
    return [x_text, np.array(y)]


def read_and_clean_zh_file(input_file, output_cleaned_file=False):
    """
    1. 讀取中文檔案並清洗句子
    2. 可以將清洗後的結果儲存到檔案
    3. 如果已經存在經過清洗的資料檔案則直接載入
    """
    data_file_path, file_name = os.path.split(input_file)
    output_file = os.path.join(data_file_path, 'cleaned_' + file_name)
    if os.path.exists(output_file):
        lines = list(open(output_file, 'r').readlines())
        lines = [line.strip() for line in lines]
    else:
        lines = list(open(input_file, 'r').readlines())
        lines = [clean_str(seperate_line(line)) for line in lines]
        if output_cleaned_file:
            with open(output_file, 'w') as f:
                for line in lines:
                    f.write(line + '\n')
    return lines


def clean_str(string):
    """
    1. 將除漢字外的字元轉為一個空格
    2. 將連續的多個空格轉為一個空格
    3. 除去句子前後的空格字元
    """
    string = re.sub(r'[^\u4e00-\u9fff]', ' ', string)
    string = re.sub(r'\s{2,}', ' ', string)
    return string.strip()


def seperate_line(line):
    """
    將句子中的每個字用空格分隔開
    """
    return ''.join([word + ' ' for word in line])


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    '''
    生成一個batch迭代器
    '''
    data = np.array(data)
    data_size = len(data)
    num_batches_per_epoch = int((data_size - 1) / batch_size) + 1
    for epoch in range(num_epochs):
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_idx = batch_num * batch_size
            end_idx = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_idx:end_idx]


if __name__ == '__main__':
    data_files = './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8'
    x_text, y = load_data_and_labels(data_files)
    print(x_text)

2.3 清洗標準

將原始資料進行清洗,僅保留漢字,並把每個漢字用一個空格分隔開,各個類別清洗後的資料分別存放在 cleaned 字首的檔案中,清洗後的資料格式如下:

本 公 司 有 部 分 普 通 發 票 商 品 銷 售 發 票 增 值 稅 發 票 及 海 關 代 徵 增 值 稅 專 用 繳 款 書 及 其 它 服 務 行 業 發 票 公 路 內 河 運 輸 發 票 可 以 以 低 稅 率 為 貴 公 司 代 開 本 公 司 具 有 內 外 貿 生 意 實 力 保 證 我 司 開 具 的 票 據 的 真 實 性 希 望 可 以 合 作 共 同 發 展 敬 侯 您 的 來 電 洽 談 諮 詢 聯 系 人 李 先 生 聯 系 電 話 如 有 打 擾 望 諒 解 祝 商 琪

3. 模型

CNN 模型類 text_cnn.py 的程式碼如下,修改的地方如下:

  • 將 concat 和 reshape 的操作結點放在 concat 名稱空間下,這樣在 TensorBoard 中的節點圖更加清晰合理
  • 將計算損失值的操作修改為通過 collection 進行,並只計算 W 的 L2 損失值,刪去了計算 b 的 L2 損失值的程式碼
import tensorflow as tf
import numpy as np


class TextCNN(object):
    """
    字元級CNN文字分類
    詞嵌入層->卷積層->池化層->softmax層
    """

    def __init__(self,
                 sequence_length,
                 num_classes,
                 vocab_size,
                 embedding_size,
                 filter_sizes,
                 num_filters,
                 l2_reg_lambda=0.0):

        # 輸入,輸出,dropout的佔位符
        self.input_x = tf.placeholder(
            tf.int32, [None, sequence_length], name='input_x')
        self.input_y = tf.placeholder(
            tf.float32, [None, num_classes], name='input_y')
        self.dropout_keep_prob = tf.placeholder(
            tf.float32, name='dropout_keep_prob')

        # l2正則化損失值(可選)
        #l2_loss = tf.constant(0.0)

        # 詞嵌入層
        # W為詞彙表,大小為0~詞彙總數,索引對應不同的字,每個字對映為128維的陣列,比如[3800,128]
        with tf.name_scope('embedding'):
            self.W = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                name='W')
            self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
            self.embedded_chars_expanded = tf.expand_dims(
                self.embedded_chars, -1)

        # 卷積層和池化層
        # 為3,4,5分別建立128個過濾器,總共3×128個過濾器
        # 過濾器形狀為[3,128,1,128],表示一次能過濾三個字,最後形成188×128的特徵向量
        # 池化核形狀為[1,188,1,1],128維中的每一維表示該句子的不同向量表示,池化即從每一維中提取最大值表示該維的特徵
        # 池化得到的特徵向量為128維
        pooled_outputs = []
        for i, filter_size in enumerate(filter_sizes):
            with tf.name_scope('conv-maxpool-%s' % filter_size):
                # 卷積層
                filter_shape = [filter_size, embedding_size, 1, num_filters]
                W = tf.Variable(
                    tf.truncated_normal(filter_shape, stddev=0.1), name='W')
                b = tf.Variable(
                    tf.constant(0.1, shape=[num_filters]), name='b')
                conv = tf.nn.conv2d(
                    self.embedded_chars_expanded,
                    W,
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name='conv')
                # ReLU啟用
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')
                # 池化層
                pooled = tf.nn.max_pool(
                    h,
                    ksize=[1, sequence_length - filter_size + 1, 1, 1],
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name='pool')
                pooled_outputs.append(pooled)

        # 組合所有池化後的特徵
        # 將三個過濾器得到的特徵向量組合成一個384維的特徵向量
        num_filters_total = num_filters * len(filter_sizes)
        with tf.name_scope('concat'):
            self.h_pool = tf.concat(pooled_outputs, 3)
            self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])

        # dropout
        with tf.name_scope('dropout'):
            self.h_drop = tf.nn.dropout(self.h_pool_flat,
                                        self.dropout_keep_prob)

        # 全連線層
        # 分數和預測結果
        with tf.name_scope('output'):
            W = tf.Variable(
                tf.truncated_normal(
                    [num_filters_total, num_classes], stddev=0.1),
                name='W')
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name='b')
            if l2_reg_lambda:
                W_l2_loss = tf.contrib.layers.l2_regularizer(l2_reg_lambda)(W)
                tf.add_to_collection('losses', W_l2_loss)
            self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name='scores')
            self.predictions = tf.argmax(self.scores, 1, name='predictions')

        # 計算交叉損失熵
        with tf.name_scope('loss'):
            mse_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=self.scores, labels=self.input_y))
            tf.add_to_collection('losses', mse_loss)
            self.loss = tf.add_n(tf.get_collection('losses'))

        # 正確率
        with tf.name_scope('accuracy'):
            correct_predictions = tf.equal(self.predictions,
                                           tf.argmax(self.input_y, 1))
            self.accuracy = tf.reduce_mean(
                tf.cast(correct_predictions, 'float'), name='accuracy')

最終的神經網路結構圖在 TensorBoard 中的樣式如下:

4. 訓練

訓練模型的 train.py 程式碼如下,修改的地方如下:

  • 將資料檔案的路徑引數修改為一個用逗號分隔開的字串,便於實現多分類問題
  • tf.flags 重新命名為 flags,更加簡潔
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn

# 引數
# ==================================================

flags = tf.flags

# 資料載入引數
flags.DEFINE_float('dev_sample_percentage', 0.1,
                   'Percentage of the training data to use for validation')
flags.DEFINE_string(
    'data_files',
    './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8',
    'Comma-separated data source files')

# 模型超引數
flags.DEFINE_integer('embedding_dim', 128,
                     'Dimensionality of character embedding (default: 128)')
flags.DEFINE_string('filter_sizes', '3,4,5',
                    'Comma-separated filter sizes (default: "3,4,5")')
flags.DEFINE_integer('num_filters', 128,
                     'Number of filters per filter size (default: 128)')
flags.DEFINE_float('dropout_keep_prob', 0.5,
                   'Dropout keep probability (default: 0.5)')
flags.DEFINE_float('l2_reg_lambda', 0.0,
                   'L2 regularization lambda (default: 0.0)')

# 訓練引數
flags.DEFINE_integer('batch_size', 64, 'Batch Size (default: 64)')
flags.DEFINE_integer('num_epochs', 10,
                     'Number of training epochs (default: 10)')
flags.DEFINE_integer(
    'evaluate_every', 100,
    'Evaluate model on dev set after this many steps (default: 100)')
flags.DEFINE_integer('checkpoint_every', 100,
                     'Save model after this many steps (default: 100)')
flags.DEFINE_integer('num_checkpoints', 5,
                     'Number of checkpoints to store (default: 5)')

# 其他引數
flags.DEFINE_boolean('allow_soft_placement', True,
                     'Allow device soft device placement')
flags.DEFINE_boolean('log_device_placement', False,
                     'Log placement of ops on devices')

FLAGS = flags.FLAGS
FLAGS._parse_flags()
print('\nParameters:')
for attr, value in sorted(FLAGS.__flags.items()):
    print('{}={}'.format(attr.upper(), value))
print('')

# 資料準備
# ==================================================

# 載入資料
print('Loading data...')
x_text, y = data_helpers.load_data_and_labels(FLAGS.data_files)

# 建立詞彙表
max_document_length = max([len(x.split(' ')) for x in x_text])
vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
x = np.array(list(vocab_processor.fit_transform(x_text)))

# 隨機混淆資料
np.random.seed(10)
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
y_shuffled = y[shuffle_indices]

# 劃分train/test資料集
# TODO: 這種做法比較暴力,應該用交叉驗證
dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y)))
x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]

del x, y, x_shuffled, y_shuffled

print('Vocabulary Size: {:d}'.format(len(vocab_processor.vocabulary_)))
print('Train/Dev split: {:d}/{:d}'.format(len(y_train), len(y_dev)))
print('')

# 訓練
# ==================================================

with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        cnn = TextCNN(
            sequence_length=x_train.shape[1],
            num_classes=y_train.shape[1],
            vocab_size=len(vocab_processor.vocabulary_),
            embedding_size=FLAGS.embedding_dim,
            filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))),
            num_filters=FLAGS.num_filters,
            l2_reg_lambda=FLAGS.l2_reg_lambda)

        # 定義訓練相關操作
        global_step = tf.Variable(0, name='global_step', trainable=False)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars = optimizer.compute_gradients(cnn.loss)
        train_op = optimizer.apply_gradients(
            grads_and_vars, global_step=global_step)

        # 跟蹤梯度值和稀疏性(可選)
        grad_summaries = []
        for g, v in grads_and_vars:
            if g is not None:
                grad_hist_summary = tf.summary.histogram(
                    '{}/grad/hist'.format(v.name), g)
                sparsity_summary = tf.summary.scalar('{}/grad/sparsity'.format(
                    v.name), tf.nn.zero_fraction(g))
                grad_summaries.append(grad_hist_summary)
                grad_summaries.append(sparsity_summary)
        grad_summaries_merged = tf.summary.merge(grad_summaries)

        # 模型和摘要的儲存目錄
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, 'runs', timestamp))
        print('\nWriting to {}\n'.format(out_dir))

        # 損失值和正確率的摘要
        loss_summary = tf.summary.scalar('loss', cnn.loss)
        acc_summary = tf.summary.scalar('accuracy', cnn.accuracy)

        # 訓練摘要
        train_summary_op = tf.summary.merge(
            [loss_summary, acc_summary, grad_summaries_merged])
        train_summary_dir = os.path.join(out_dir, 'summaries', 'train')
        train_summary_writer = tf.summary.FileWriter(train_summary_dir,
                                                     sess.graph)

        # 開發摘要
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(out_dir, 'summaries', 'dev')
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # 檢查點目錄,預設存在
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, 'checkpoints'))
        checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(
            tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

        # 寫入詞彙表檔案
        vocab_processor.save(os.path.join(out_dir, 'vocab'))

        # 初始化變數
        sess.run(tf.global_variables_initializer())

        def train_step(x_batch, y_batch):
            """
            一個訓練步驟
            """
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
            }
            _, step, summaries, loss, accuracy = sess.run([
                train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy
            ], feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print('{}: step {}, loss {:g}, acc {:g}'.format(
                time_str, step, loss, accuracy))
            train_summary_writer.add_summary(summaries, step)

        def dev_step(x_batch, y_batch, writer=None):
            """
            在開發集上驗證模型
            """
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: 1.0
            }
            step, summaries, loss, accuracy = sess.run(
                [global_step, dev_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print('{}: step {}, loss {:g}, acc {:g}'.format(
                time_str, step, loss, accuracy))
            if writer:
                writer.add_summary(summaries, step)

        # 生成batches
        batches = data_helpers.batch_iter(
            list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
        # 迭代訓練每個batch
        for batch in batches:
            x_batch, y_batch = zip(*batch)
            train_step(x_batch, y_batch)
            current_step = tf.train.global_step(sess, global_step)
            if current_step % FLAGS.evaluate_every == 0:
                print('\nEvaluation:')
                dev_step(x_dev, y_dev, writer=dev_summary_writer)
                print('')
            if current_step % FLAGS.checkpoint_every == 0:
                path = saver.save(
                    sess, checkpoint_prefix, global_step=current_step)
                print('Saved model checkpoint to {}\n'.format(path))

訓練過程的輸出如下:

Parameters:
ALLOW_SOFT_PLACEMENT=True
BATCH_SIZE=64
CHECKPOINT_EVERY=100
DATA_FILES=./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8
DEV_SAMPLE_PERCENTAGE=0.1
DROPOUT_KEEP_PROB=0.5
EMBEDDING_DIM=128
EVALUATE_EVERY=100
FILTER_SIZES=3,4,5
L2_REG_LAMBDA=0.0
LOG_DEVICE_PLACEMENT=False
NUM_CHECKPOINTS=5
NUM_EPOCHS=10
NUM_FILTERS=128

Loading data...
Vocabulary Size: 3628
Train/Dev split: 9001/1000

Writing to /home/widiot/workspace/tensorflow-ws/tensorflow-gpu/text-classification/cnn-zh-text-classification/runs/1517734186

2018-02-04T16:50:03.709761: step 1, loss 5.36006, acc 0.46875
2018-02-04T16:50:03.786874: step 2, loss 4.61227, acc 0.390625
2018-02-04T16:50:03.857796: step 3, loss 2.50795, acc 0.5625
...
2018-02-04T16:50:10.819505: step 98, loss 0.622567, acc 0.90625
2018-02-04T16:50:10.899140: step 99, loss 1.10189, acc 0.875
2018-02-04T16:50:10.983192: step 100, loss 0.359102, acc 0.9375

Evaluation:
2018-02-04T16:50:11.848838: step 100, loss 0.132987, acc 0.961

Saved model checkpoint to /home/widiot/workspace/tensorflow-ws/tensorflow-gpu/text-classification/cnn-zh-text-classification/runs/1517734186/checkpoints/model-100

2018-02-04T16:50:12.019749: step 101, loss 0.512838, acc 0.890625
2018-02-04T16:50:12.100965: step 102, loss 0.164333, acc 0.96875
2018-02-04T16:50:12.184899: step 103, loss 0.145344, acc 0.921875
...

訓練之後會在 runs 目錄下生成對應的資料目錄,包含檢查點、日誌摘要和詞彙檔案。

訓練時的正確率變化如下:

5. 評估

評估模型的 eval.py 程式碼如下,修改的地方如下:

  • train.py 將資料檔案路徑引數修改為逗號分隔開的字串,便於實現多分類問題
  • 新增對自己未經處理的資料的清洗操作,便於直接分類評估資料
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
import csv
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn

# 引數
# ==================================================

flags = tf.flags

# 資料引數
flags.DEFINE_string(
    'data_files',
    './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8',
    'Comma-separated data source files')

# 評估引數
flags.DEFINE_integer('batch_size', 64, 'Batch Size (default: 64)')
flags.DEFINE_string('checkpoint_dir', './runs/1517572900/checkpoints',
                    'Checkpoint directory from training run')
flags.DEFINE_boolean('eval_train', False, 'Evaluate on all training data')

# 其他引數
flags.DEFINE_boolean('allow_soft_placement', True,
                     'Allow device soft device placement')
flags.DEFINE_boolean('log_device_placement', False,
                     'Log placement of ops on devices')

FLAGS = flags.FLAGS
FLAGS._parse_flags()
print('\nParameters:')
for attr, value in sorted(FLAGS.__flags.items()):
    print('{}={}'.format(attr.upper(), value))
print('')

# 載入訓練資料或者修改測試句子
if FLAGS.eval_train:
    x_raw, y_test = data_helpers.load_data_and_labels(FLAGS.data_files)
    y_test = np.argmax(y_test, axis=1)
else:
    x_raw = [
        '親愛的CFer,您獲得了英雄級道具。還有全新英雄級道具在等你來拿,立即登入遊戲領取吧!',
        '第一個build錯誤的解決方法能再說一下嗎,我還是不懂怎麼解決', '請聯絡張經理獲取最新資訊'
    ]
    y_test = [0, 1, 0]

# 對自己的資料的處理
x_raw_cleaned = [
    data_helpers.clean_str(data_helpers.seperate_line(line)) for line in x_raw
]
print(x_raw_cleaned)

# 將資料轉為詞彙表的索引
vocab_path = os.path.join(FLAGS.checkpoint_dir, '..', 'vocab')
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)
x_test = np.array(list(vocab_processor.transform(x_raw_cleaned)))

print('\nEvaluating...\n')

# 評估
# ==================================================

checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
graph = tf.Graph()
with graph.as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # 載入儲存的元圖和變數
        saver = tf.train.import_meta_graph('{}.meta'.format(checkpoint_file))
        saver.restore(sess, checkpoint_file)

        # 通過名字從圖中獲取佔位符
        input_x = graph.get_operation_by_name('input_x').outputs[0]
        # input_y = graph.get_operation_by_name('input_y').outputs[0]
        dropout_keep_prob = graph.get_operation_by_name(
            'dropout_keep_prob').outputs[0]

        # 我們想要評估的tensors
        predictions = graph.get_operation_by_name(
            'output/predictions').outputs[0]

        # 生成每個輪次的batches
        batches = data_helpers.batch_iter(
            list(x_test), FLAGS.batch_size, 1, shuffle=False)

        # 收集預測值
        all_predictions = []

        for x_test_batch in batches:
            batch_predictions = sess.run(predictions, {
                input_x: x_test_batch,
                dropout_keep_prob: 1.0
            })
            all_predictions = np.concatenate(
                [all_predictions, batch_predictions])

# 如果提供了標籤則列印正確率
if y_test is not None:
    correct_predictions = float(sum(all_predictions == y_test))
    print('\nTotal number of test examples: {}'.format(len(y_test)))
    print('Accuracy: {:g}'.format(correct_predictions / float(len(y_test))))

# 儲存評估為csv
predictions_human_readable = np.column_stack((np.array(x_raw),
                                              all_predictions))
out_path = os.path.join(FLAGS.checkpoint_dir, '..', 'prediction.csv')
print('Saving evaluation to {0}'.format(out_path))
with open(out_path, 'w') as f:
    csv.writer(f).writerows(predictions_human_readable)

評估過程中的輸出如下:

Parameters:
ALLOW_SOFT_PLACEMENT=True
BATCH_SIZE=64
CHECKPOINT_DIR=./runs/1517572900/checkpoints
DATA_FILES=./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8
EVAL_TRAIN=False
LOG_DEVICE_PLACEMENT=False

['親 愛 的 您 獲 得 了 英 雄 級 道 具 還 有 全 新 英 雄 級 道 具 在 等 你 來 拿 立 即 登 錄 遊 戲 領 取 吧', '第 一 個 錯 誤 的 解 決 方 法 能 再 說 一 下 嗎 我 還 是 不 懂 怎 麼 解 決', '請 聯 系 張 經 理 獲 取 最 新 資 訊']

Evaluating...

Total number of test examples: 3
Accuracy: 1
Saving evaluation to ./runs/1517572900/checkpoints/../prediction.csv

評估之後會在 runs 目錄對應的資料夾下生成一個 prediction.csv 檔案,如下所示:

親愛的CFer,您獲得了英雄級道具。還有全新英雄級道具在等你來拿,立即登入遊戲領取吧!,0.0
第一個build錯誤的解決方法能再說一下嗎,我還是不懂怎麼解決,1.0
請聯絡張經理獲取最新資訊,0.0

相關文章