從rnn結構說起
根據輸出和輸入序列不同數量rnn可以有多種不同的結構,不同結構自然就有不同的引用場合。如下圖,
- one to one 結構,僅僅只是簡單的給一個輸入得到一個輸出,此處並未體現序列的特徵,例如影像分類場景。
- one to many 結構,給一個輸入得到一系列輸出,這種結構可用於生產圖片描述的場景。
- many to one 結構,給一系列輸入得到一個輸出,這種結構可用於文字情感分析,對一些列的文字輸入進行分類,看是消極還是積極情感。
- many to many 結構,給一些列輸入得到一系列輸出,這種結構可用於翻譯或聊天對話場景,對輸入的文字轉換成另外一些列文字。
- 同步 many to many 結構,它是經典的rnn結構,前一輸入的狀態會帶到下一個狀態中,而且每個輸入都會對應一個輸出,我們最熟悉的就是用於字元預測了,同樣也可以用於視訊分類,對視訊的幀打標籤。
seq2seq
在 many to many 的兩種模型中,上圖可以看到第四和第五種是有差異的,經典的rnn結構的輸入和輸出序列必須要是等長,它的應用場景也比較有限。而第四種它可以是輸入和輸出序列不等長,這種模型便是seq2seq模型,即Sequence to Sequence。它實現了從一個序列到另外一個序列的轉換,比如google曾用seq2seq模型加attention模型來實現了翻譯功能,類似的還可以實現聊天機器人對話模型。經典的rnn模型固定了輸入序列和輸出序列的大小,而seq2seq模型則突破了該限制。
其實對於seq2seq的decoder,它在訓練階段和預測階段對rnn的輸出的處理可能是不一樣的,比如在訓練階段可能對rnn的輸出不處理,直接用target的序列作為下時刻的輸入,如上圖一。而預測階段會將rnn的輸出當成是下一時刻的輸入,因為此時已經沒有target序列可以作為輸入了,如上圖二。
encoder-decoder結構
seq2seq屬於encoder-decoder結構的一種,這裡看看常見的encoder-decoder結構,基本思想就是利用兩個RNN,一個RNN作為encoder,另一個RNN作為decoder。encoder負責將輸入序列壓縮成指定長度的向量,這個向量就可以看成是這個序列的語義,這個過程稱為編碼,如下圖,獲取語義向量最簡單的方式就是直接將最後一個輸入的隱狀態作為語義向量C。也可以對最後一個隱含狀態做一個變換得到語義向量,還可以將輸入序列的所有隱含狀態做一個變換得到語義變數。
而decoder則負責根據語義向量生成指定的序列,這個過程也稱為解碼,如下圖,最簡單的方式是將encoder得到的語義變數作為初始狀態輸入到decoder的rnn中,得到輸出序列。可以看到上一時刻的輸出會作為當前時刻的輸入,而且其中語義向量C只作為初始狀態參與運算,後面的運算都與語義向量C無關。
decoder處理方式還有另外一種,就是語義向量C參與了序列所有時刻的運算,如下圖,上一時刻的輸出仍然作為當前時刻的輸入,但語義向量C會參與所有時刻的運算。
encoder-decoder模型對輸入和輸出序列的長度沒有要求,應用場景也更加廣泛。
如何訓練
前面有介紹了encoder-decoder模型的簡單模型,但這裡以下圖稍微複雜一點的模型說明訓練的思路,不同的encoder-decoder模型結構有差異,但訓練的核心思想都大同小異。
我們知道RNN是可以學習概率分佈然後進行預測的,比如我們輸入t個時刻的資料後預測t+1時刻的資料,最經典的就是字元預測的例子,可在前面的《迴圈神經網路》和《TensorFlow構建迴圈神經網路》瞭解到更加詳細的說明。為了得到概率分佈一般會在RNN的輸出層使用softmax啟用函式,就可以得到每個分類的概率。
對於RNN,對於某個序列,對於時刻t,它的輸出概率為$p(x_t|x1,…,x{t-1})$,則softmax層每個神經元的計算如下:
其中$h_t$是隱含狀態,它與上一時刻的狀態及當前輸入有關,即$ht = f(h{t-1},x_t)$。
那麼整個序列的概率就為
而對於encoder-decoder模型,設有輸入序列
,輸出序列
,輸入序列和輸出序列的長度可能不同。那麼其實就是要根據輸入序列去得到輸出序列的可能,於是有下面的條件概率,
發生的情況下
發生的概率等於
連乘。其中v表示
對應的隱含狀態向量,它其實可以等同表示輸入序列。
此時,
,decoder的隱含狀態與上一時刻狀態、上一時刻輸出和狀態向量v都有關,這裡不同於RNN,RNN是與當前時刻輸入相關,而decoder是將上一時刻的輸出輸入到RNN中。於是decoder的某一時刻的概率分佈可用下式表示,
所以對於訓練樣本,我們要做的就是在整個訓練樣本下,所有樣本的
概率之和最大,對應的對數似然條件概率函式為,
,使之最大化,$ heta$則是待確定的模型引數。對於rnn、lstm和gru的結構可以看這幾篇文章《迴圈神經網路》 《LSTM神經網路》 《GRU神經網路》。
========廣告時間========
鄙人的新書《Tomcat核心設計剖析》已經在京東銷售了,有需要的朋友可以到 item.jd.com/12185360.ht… 進行預定。感謝各位朋友。
=========================
歡迎關注: