RNN程式碼解讀之char-RNN with TensorFlow(model.py)

CopperDong發表於2018-05-27

此工程解讀連結(建議按順序閱讀): 
RNN程式碼解讀之char-RNN with TensorFlow(model.py) 
RNN程式碼解讀之char-RNN with TensorFlow(train.py) 
RNN程式碼解讀之char-RNN with TensorFlow(util.py) 
RNN程式碼解讀之char-RNN with TensorFlow(sample.py)

最近一直在學習RNN的相關知識,個人認為相比於CNN各種模型在detection/classification/segmentation等方面超人的表現,RNN還有很長的一段路要走,畢竟現在的nlp模型單從output質量上來看只是差強人意,要和人相比還有一段距離。CNN+RNN的任務比如image caption更是有很多有待研究和提高的地方。

關於對CNN和RNN相關內容的學習和探討,我將會在近期更新對一些經典論文的解讀以及自己的看法,屆時歡迎大家給予指導。

當然,CS231n中有一句名言“Don’t think too hard, just cross your fingers.” 想法還是要落地才可以看到成果,那麼我們今天就一起來看一下大牛Adrew Karparthy的char-RNN模型,AK使用lua基於torch寫的,git上已經有人及時的復現了TensorFlow with Python版本(https://github.com/sherjilozair/char-rnn-tensorflow)。

網上已經有很多相關的解析了,但大部分只是針對model進行解釋,這對於整體模型的巨集觀理解以及TensorFlow的學習都是很不利的。因此,這裡我會給出自己對所有程式碼的理解,若有錯誤歡迎及時指正。

這一個版本的程式碼共分為四個模組:model.py,train.py, util.py以及sample.py,我們將按照這個順序,分四篇博文對四個模組進行梳理。我在程式碼中對所有我認為重要的地方都寫了註釋,有的部分甚至每一行都有明確的註釋,但難免有的基本方法會讓人產生疑惑。面對這種問題,我強烈建議大家一邊debug一步一步的執行看結果,一邊百度或者google。這樣梳理一遍程式碼一定會全身舒暢,豁然開朗,感覺開啟了新世界的大門,對於RNN模型的TensorFlow實現也會更有把握。

當然理解這一個工程並不是我們的終極目的,針對後面跟新的paper中提到的有創新的方法,我們也會再此模型的基礎上進一步實現,走上我們的科研之路。

廢話說太多了,下面我們先開始看最重點的model.py 
注意:這裡註釋解釋的只是訓練過程中的理解,在infer過程中batch=1,sequence=1,大體理解沒有差別,但是具體思想還需要大家到時候再推敲推敲。此外,此class中的sample方法這一節不討論,到第四節sample.py的時候一併討論。

#-*-coding:utf-8-*-
import tensorflow as tf
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import seq2seq

import numpy as np

class Model():
    def __init__(self, args, infer=False):
        self.args = args
        #在測試狀態下(inference)才用如下選項
        if infer:
            args.batch_size = 1
            args.seq_length = 1
        #幾種備選的rnn型別
        if args.model == 'rnn':
            cell_fn = rnn_cell.BasicRNNCell
        elif args.model == 'gru':
            cell_fn = rnn_cell.GRUCell
        elif args.model == 'lstm':
            cell_fn = rnn_cell.BasicLSTMCell
        else:
            raise Exception("model type not supported: {}".format(args.model))
        #固定格式是例:cell = rnn_cell.GRUCelll(rnn_size)
        #rnn_size指的是每個rnn單元中的神經元個數(雖然RNN途中只有一個圓圈代表,但這個圓圈代表了rnn_size個神經元)
        #這裡state_is_tuple根據官網解釋,每個cell返回的h和c狀態是儲存在一個list裡還是兩個tuple裡,官網建議設定為true
        cell = cell_fn(args.rnn_size, state_is_tuple=True)
        #固定格式,有幾層rnn
        self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
        #input_data&target(標籤)格式:[batch_size, seq_length]
        self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
        self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
        #cell的初始狀態設為0,因為在前面設定cell時,cell_size已經設定好了,因此這裡只需給出batch_size即可
        #(一個batch內有batch_size個sequence的輸入)
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)

        #rnnlm = recurrent neural network language model
        #variable_scope就是變數的作用域
        with tf.variable_scope('rnnlm'):
            #softmax層的引數
            softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
            with tf.device("/cpu:0"):
                #推薦使用tf.get_variable而不是tf.variable
                #embedding矩陣是將輸入轉換到了cell_size,因此這樣的大小設定
                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
                #關於tf.nn.embedding_lookup(embedding, self.input_data):
                #   呼叫tf.nn.embedding_lookup,索引與train_dataset對應的向量,相當於用train_dataset作為一個id,去檢索矩陣中與這個id對應的embedding
                #將第三個引數,在第1維度,切成seq_length長短的片段
                #embeddinglookup得到的look_up尺寸是[batch_size, seq_length, rnn_size],這裡是[50,50,128]
                look_up = tf.nn.embedding_lookup(embedding, self.input_data)
                #將上面的[50,50,128]切開,得到50個[50,1,128]的inputs
                inputs = tf.split(1, args.seq_length, look_up)
                #之後將 1 squeeze掉,50個[50,128]
                inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

        #在infer的時候方便檢視
        def loop(prev, _):
            prev = tf.matmul(prev, softmax_w) + softmax_b
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            return tf.nn.embedding_lookup(embedding, prev_symbol)

        #seq2seq.rnn_decoder基於schedule sampling實現,相當於一個黑盒子,可以直接呼叫
        #得到的兩個引數shape均為50個50*128的張量,和輸入是一樣的
        outputs, last_state = seq2seq.rnn_decoder(inputs,
                                                  self.initial_state, cell,
                                                  loop_function=loop if infer else None,
                                                  scope='rnnlm')
        #將outputsreshape在一起,形成[2500,128]的張量
        output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
        #logits和probs的大小都是[2500,65]([2500,128]*[128,65])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        #得到length為2500的loss(即每一個batch的sequence中的每一個單詞輸入,都會最終產生一個loss,50*50=2500)
        loss = seq2seq.sequence_loss_by_example([self.logits],
                [tf.reshape(self.targets, [-1])],
                [tf.ones([args.batch_size * args.seq_length])],
                args.vocab_size)
        #得到一個batch的cost後面用於求梯度
        self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
        #將state轉換一下,便於下一次繼續訓練
        self.final_state = last_state
        #因為學習率不需要BPTT更新,因此trainable=False
        #具體的learning_rate是由train.py中args引數傳過來的,這裡只是初始化設了一個0
        self.lr = tf.Variable(0.0, trainable=False)
        #返回了包括前面的softmax_w/softmax_b/embedding等所有變數
        tvars = tf.trainable_variables()
        #求grads要使用clip避免梯度爆炸,這裡設定的閾值是5(見args)
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                args.grad_clip)
        #使用adam優化方法
        optimizer = tf.train.AdamOptimizer(self.lr)
        #參考tensorflow手冊,
        # 將計算出的梯度應用到變數上,是函式minimize()的第二部分,返回一個應用指定的梯度的操作Operation,對global_step做自增操作
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95

以上就是對於model.py的程式碼分析,總體來說就是“模型定義+引數設定+優化”的思路,如果有哪裡出錯還望大家多多指教啦~!

參考資料: 
http://blog.csdn.net/mydear_11000/article/details/52776295 
https://github.com/sherjilozair/char-rnn-tensorflow 
http://www.tensorfly.cn/tfdoc/api_docs/python/constant_op.html#truncated_normal

相關文章