RNN程式碼解讀之char-RNN with TensorFlow(model.py)
此工程解讀連結(建議按順序閱讀):
RNN程式碼解讀之char-RNN with TensorFlow(model.py)
RNN程式碼解讀之char-RNN with TensorFlow(train.py)
RNN程式碼解讀之char-RNN with TensorFlow(util.py)
RNN程式碼解讀之char-RNN with TensorFlow(sample.py)
最近一直在學習RNN的相關知識,個人認為相比於CNN各種模型在detection/classification/segmentation等方面超人的表現,RNN還有很長的一段路要走,畢竟現在的nlp模型單從output質量上來看只是差強人意,要和人相比還有一段距離。CNN+RNN的任務比如image caption更是有很多有待研究和提高的地方。
關於對CNN和RNN相關內容的學習和探討,我將會在近期更新對一些經典論文的解讀以及自己的看法,屆時歡迎大家給予指導。
當然,CS231n中有一句名言“Don’t think too hard, just cross your fingers.” 想法還是要落地才可以看到成果,那麼我們今天就一起來看一下大牛Adrew Karparthy的char-RNN模型,AK使用lua基於torch寫的,git上已經有人及時的復現了TensorFlow with Python版本(https://github.com/sherjilozair/char-rnn-tensorflow)。
網上已經有很多相關的解析了,但大部分只是針對model進行解釋,這對於整體模型的巨集觀理解以及TensorFlow的學習都是很不利的。因此,這裡我會給出自己對所有程式碼的理解,若有錯誤歡迎及時指正。
這一個版本的程式碼共分為四個模組:model.py,train.py, util.py以及sample.py,我們將按照這個順序,分四篇博文對四個模組進行梳理。我在程式碼中對所有我認為重要的地方都寫了註釋,有的部分甚至每一行都有明確的註釋,但難免有的基本方法會讓人產生疑惑。面對這種問題,我強烈建議大家一邊debug一步一步的執行看結果,一邊百度或者google。這樣梳理一遍程式碼一定會全身舒暢,豁然開朗,感覺開啟了新世界的大門,對於RNN模型的TensorFlow實現也會更有把握。
當然理解這一個工程並不是我們的終極目的,針對後面跟新的paper中提到的有創新的方法,我們也會再此模型的基礎上進一步實現,走上我們的科研之路。
廢話說太多了,下面我們先開始看最重點的model.py
注意:這裡註釋解釋的只是訓練過程中的理解,在infer過程中batch=1,sequence=1,大體理解沒有差別,但是具體思想還需要大家到時候再推敲推敲。此外,此class中的sample方法這一節不討論,到第四節sample.py的時候一併討論。
#-*-coding:utf-8-*-
import tensorflow as tf
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import seq2seq
import numpy as np
class Model():
def __init__(self, args, infer=False):
self.args = args
#在測試狀態下(inference)才用如下選項
if infer:
args.batch_size = 1
args.seq_length = 1
#幾種備選的rnn型別
if args.model == 'rnn':
cell_fn = rnn_cell.BasicRNNCell
elif args.model == 'gru':
cell_fn = rnn_cell.GRUCell
elif args.model == 'lstm':
cell_fn = rnn_cell.BasicLSTMCell
else:
raise Exception("model type not supported: {}".format(args.model))
#固定格式是例:cell = rnn_cell.GRUCelll(rnn_size)
#rnn_size指的是每個rnn單元中的神經元個數(雖然RNN途中只有一個圓圈代表,但這個圓圈代表了rnn_size個神經元)
#這裡state_is_tuple根據官網解釋,每個cell返回的h和c狀態是儲存在一個list裡還是兩個tuple裡,官網建議設定為true
cell = cell_fn(args.rnn_size, state_is_tuple=True)
#固定格式,有幾層rnn
self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
#input_data&target(標籤)格式:[batch_size, seq_length]
self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
#cell的初始狀態設為0,因為在前面設定cell時,cell_size已經設定好了,因此這裡只需給出batch_size即可
#(一個batch內有batch_size個sequence的輸入)
self.initial_state = cell.zero_state(args.batch_size, tf.float32)
#rnnlm = recurrent neural network language model
#variable_scope就是變數的作用域
with tf.variable_scope('rnnlm'):
#softmax層的引數
softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
with tf.device("/cpu:0"):
#推薦使用tf.get_variable而不是tf.variable
#embedding矩陣是將輸入轉換到了cell_size,因此這樣的大小設定
embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
#關於tf.nn.embedding_lookup(embedding, self.input_data):
# 呼叫tf.nn.embedding_lookup,索引與train_dataset對應的向量,相當於用train_dataset作為一個id,去檢索矩陣中與這個id對應的embedding
#將第三個引數,在第1維度,切成seq_length長短的片段
#embeddinglookup得到的look_up尺寸是[batch_size, seq_length, rnn_size],這裡是[50,50,128]
look_up = tf.nn.embedding_lookup(embedding, self.input_data)
#將上面的[50,50,128]切開,得到50個[50,1,128]的inputs
inputs = tf.split(1, args.seq_length, look_up)
#之後將 1 squeeze掉,50個[50,128]
inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
#在infer的時候方便檢視
def loop(prev, _):
prev = tf.matmul(prev, softmax_w) + softmax_b
prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
return tf.nn.embedding_lookup(embedding, prev_symbol)
#seq2seq.rnn_decoder基於schedule sampling實現,相當於一個黑盒子,可以直接呼叫
#得到的兩個引數shape均為50個50*128的張量,和輸入是一樣的
outputs, last_state = seq2seq.rnn_decoder(inputs,
self.initial_state, cell,
loop_function=loop if infer else None,
scope='rnnlm')
#將outputsreshape在一起,形成[2500,128]的張量
output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
#logits和probs的大小都是[2500,65]([2500,128]*[128,65])
self.logits = tf.matmul(output, softmax_w) + softmax_b
self.probs = tf.nn.softmax(self.logits)
#得到length為2500的loss(即每一個batch的sequence中的每一個單詞輸入,都會最終產生一個loss,50*50=2500)
loss = seq2seq.sequence_loss_by_example([self.logits],
[tf.reshape(self.targets, [-1])],
[tf.ones([args.batch_size * args.seq_length])],
args.vocab_size)
#得到一個batch的cost後面用於求梯度
self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
#將state轉換一下,便於下一次繼續訓練
self.final_state = last_state
#因為學習率不需要BPTT更新,因此trainable=False
#具體的learning_rate是由train.py中args引數傳過來的,這裡只是初始化設了一個0
self.lr = tf.Variable(0.0, trainable=False)
#返回了包括前面的softmax_w/softmax_b/embedding等所有變數
tvars = tf.trainable_variables()
#求grads要使用clip避免梯度爆炸,這裡設定的閾值是5(見args)
grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
args.grad_clip)
#使用adam優化方法
optimizer = tf.train.AdamOptimizer(self.lr)
#參考tensorflow手冊,
# 將計算出的梯度應用到變數上,是函式minimize()的第二部分,返回一個應用指定的梯度的操作Operation,對global_step做自增操作
self.train_op = optimizer.apply_gradients(zip(grads, tvars))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
以上就是對於model.py的程式碼分析,總體來說就是“模型定義+引數設定+優化”的思路,如果有哪裡出錯還望大家多多指教啦~!
參考資料:
http://blog.csdn.net/mydear_11000/article/details/52776295
https://github.com/sherjilozair/char-rnn-tensorflow
http://www.tensorfly.cn/tfdoc/api_docs/python/constant_op.html#truncated_normal
相關文章
- 【Tensorflow_DL_Note9】Tensorflow原始碼解讀1原始碼
- [譯] TensorFlow 中的 RNN 串流RNN
- tensorflow教程:tf.contrib.rnn.DropoutWrapperRNNAPP
- OceanBase 原始碼解讀(九):儲存層程式碼解讀之「巨集塊儲存格式」原始碼
- kafka程式碼解讀Kafka
- 【精讀】自然語言處理基礎之RNN自然語言處理RNN
- DeepSort之原始碼解讀原始碼
- 資深 Googler 深度解讀 TensorFlowGo
- Linklist程式碼實現以及程式碼解讀
- React原始碼解讀之setStateReact原始碼
- React原始碼解讀之componentMountReact原始碼
- Tensorflow實現RNN(LSTM)手寫數字識別RNN
- 【原始碼解讀(一)】EFCORE原始碼解讀之建立DBContext查詢攔截原始碼Context
- Java之Integer#highestOneBit程式碼品讀Java
- Element UI 原始碼解讀之 Table 元件UI原始碼元件
- AspNetCore7.0原始碼解讀之UseMiddlewareNetCore原始碼
- RNN 結構詳解RNN
- tensorflow原始碼解析之framework-resource原始碼Framework
- tensorflow原始碼解析之framework-allocator原始碼Framework
- [原始碼解析] TensorFlow 分散式之 ClusterCoordinator原始碼分散式
- [原始碼解析] TensorFlow 分散式之 MirroredStrategy原始碼分散式
- AttributeError: module ‘tensorflow._api.v1.nn.rnn_cell‘ has no attribute ‘InputProjectionWrapper‘ErrorAPIRNNProjectAPP
- graph attention network(ICLR2018)官方程式碼詳解(tensorflow)ICLR
- 目標識別程式碼解讀整理
- 零程式碼的多方面解讀
- Python進階學習之程式碼閱讀Python
- 夢斷程式碼閱讀筆記之六筆記
- [譯] RNN 迴圈神經網路系列 3:編碼、解碼器RNN神經網路
- RxJava2原始碼解讀之 Map、FlatMapRxJava原始碼
- D3原始碼解讀系列之Chord原始碼
- D3原始碼解讀系列之Dispatches原始碼
- D3原始碼解讀系列之Force原始碼
- D3原始碼解讀系列之Hierarchies原始碼
- D3原始碼解讀系列之Path原始碼
- D3原始碼解讀系列之Quadtrees原始碼
- D3原始碼解讀系列之Requests原始碼
- D3原始碼解讀系列之Selections原始碼
- D3原始碼解讀系列之Shape原始碼