[譯] RNN 迴圈神經網路系列 3:編碼、解碼器

歐長坤發表於2019-03-03

本系列文章彙總

  1. RNN 迴圈神經網路系列 1:基本 RNN 與 CHAR-RNN
  2. RNN 迴圈神經網路系列 2:文字分類
  3. RNN 迴圈神經網路系列 3:編碼、解碼器
  4. RNN 迴圈神經網路系列 4:注意力機制
  5. RNN 迴圈神經網路系列 5:自定義單元

RNN 迴圈神經網路系列 3:編碼、解碼器

在本文中,我將介紹基本的編碼器(encoder)和解碼器(decoder),用於處理諸如機器翻譯之類的 seq2seq 任務。我們不會在這篇文章中介紹注意力機制,而在下一篇文章中去實現它。

如下圖所示,我們將輸入序列輸入給編碼器,然後將生成一個最終的隱藏狀態,並將其輸入到解碼器中。即編碼器的最後一個隱藏狀態就是解碼器的新初始狀態。我們將使用 softmax 來處理解碼器輸出,並將其與目標進行比較,從而計算我們的損失函式。你可以從這篇博文中找到更多關於我對原始論文中提出這個模型的介紹。這裡的主要區別在於,我沒有向編碼器的輸入新增 EOS(譯註:句子結束符,end-of-sentence)token,同時我也沒有讓編碼器對句子進行反向讀取。

Screen Shot 2016-11-19 at 4.48.03 PM.png
Screen Shot 2016-11-19 at 4.48.03 PM.png

資料

我想建立一個非常小的資料集來使用(20 個英語和西班牙語的句子)。本教程的重點是瞭解如何構建一個編碼解碼器系統,而不是去關注這個系統對諸如機器翻譯和其他 seq2seq 處理等任務的處理。所以我自己寫了幾個句子,然後把它們翻譯成西班牙語。這就是我們的資料集。

首先,我們將這些句子分隔為 token,然後將這些 token 轉換為 token ID。在這個過程中,我們收集一個詞彙字典和一個反向詞彙字典,以便在 token 和 token ID 之間來回轉換。對於我們的目標語言(西班牙語)來說,我們將新增一個額外的 EOS token。然後,我們會將源 token 和目標 token 都填充到(對應資料集中最長句子的)最大長度。這是我們模型的輸入資料。對於編碼器而言,我們將填充後的源內容直接進行輸入,而對於目標內容做進一步處理,以獲得我們的解碼器輸入和輸出。

最後,輸入結果是這個樣子的:

Screen Shot 2016-11-19 at 4.20.54 PM.png
Screen Shot 2016-11-19 at 4.20.54 PM.png

這只是某個批次中的一個樣本。其中 0 是填充的值,1 是 GO token,2 則是 EOS token。下圖是資料變換更一般的表示形式。請無視目標權重,我們不會在實現中使用它們。

screen-shot-2016-11-16-at-5-09-10-pm
screen-shot-2016-11-16-at-5-09-10-pm

編碼器

編碼器只接受編碼器的輸入,而我們唯一關心的是最終的隱藏狀態。這個隱藏的狀態包含了所有輸入的資訊。我們不會像原始論文所建議的那樣反轉編碼器的輸入,因為我們使用的是 dynamic_rnnseq_len。它會基於 seq_len 自動返回最後一個對應的隱藏狀態。

with tf.variable_scope(`encoder`) as scope:

    # RNN 編碼器單元
    self.encoder_stacked_cell = rnn_cell(FLAGS, self.dropout,
        scope=scope)

    # 嵌入 RNN 編碼器輸入
    W_input = tf.get_variable("W_input",
        [FLAGS.en_vocab_size, FLAGS.num_hidden_units])
    self.embedded_encoder_inputs = rnn_inputs(FLAGS,
        self.encoder_inputs, FLAGS.en_vocab_size, scope=scope)
    #initial_state = encoder_stacked_cell.zero_state(FLAGS.batch_size, tf.float32)

    # RNN 編碼器的輸出
    self.all_encoder_outputs, self.encoder_state = tf.nn.dynamic_rnn(
        cell=self.encoder_stacked_cell,
        inputs=self.embedded_encoder_inputs,
        sequence_length=self.en_seq_lens, time_major=False,
        dtype=tf.float32)複製程式碼

我們將使用這個最終的隱藏狀態作為解碼器的新初始狀態。

解碼器

這個簡單的解碼器將編碼器的最終的隱藏狀態作為自己的初始狀態。我們還將接入解碼器的輸入,並使用 RNN 解碼器來處理它們。輸出的結果將通過 softmax 進行歸一化處理,然後與目標進行比較。注意,解碼器輸入從一個 GO token 開始,從而用來預測第一個目標 token。解碼器輸入的最後一個對應的 token 則是用來預測 EOS 目標 token 的。

with tf.variable_scope(`decoder`) as scope:

    # 初始狀態是編碼器的最後一個對應狀態
    self.decoder_initial_state = self.encoder_state

    # RNN 解碼器單元
    self.decoder_stacked_cell = rnn_cell(FLAGS, self.dropout,
        scope=scope)

    # 嵌入 RNN 解碼器輸入
    W_input = tf.get_variable("W_input",
        [FLAGS.sp_vocab_size, FLAGS.num_hidden_units])
    self.embedded_decoder_inputs = rnn_inputs(FLAGS, self.decoder_inputs,
        FLAGS.sp_vocab_size, scope=scope)

    # RNN 解碼器的輸出
    self.all_decoder_outputs, self.decoder_state = tf.nn.dynamic_rnn(
        cell=self.decoder_stacked_cell,
        inputs=self.embedded_decoder_inputs,
        sequence_length=self.sp_seq_lens, time_major=False,
        initial_state=self.decoder_initial_state)複製程式碼

那填充值會發生什麼呢?它們也會預測一些輸出目標,而我們並不關心這些內容,但如果我們把它們考慮進去,它們仍然會影響我們的損失函式。接下來我們將遮蔽掉這些損失以消除對目標結果的影響。

損失遮蔽

我們會檢查目標,並將目標中被填充的部分遮蔽為 0。因此,當我們獲得最後一個有關的解碼器 token 時,目標就會是表示 EOS 的 token ID。而對於下一個解碼器的輸入而言,目標就會是 PAD ID,這也就是遮蔽開始的地方。

# Logit
self.decoder_outputs_flat = tf.reshape(self.all_decoder_outputs,
    [-1, FLAGS.num_hidden_units])
self.logits_flat = rnn_softmax(FLAGS, self.decoder_outputs_flat,
    scope=scope)

# 損失遮蔽
targets_flat = tf.reshape(self.targets, [-1])
losses_flat = tf.nn.sparse_softmax_cross_entropy_with_logits(
    self.logits_flat, targets_flat)
mask = tf.sign(tf.to_float(targets_flat))
masked_losses = mask * losses_flat
masked_losses = tf.reshape(masked_losses,  tf.shape(self.targets))
self.loss = tf.reduce_mean(
    tf.reduce_sum(masked_losses, reduction_indices=1))複製程式碼

注意到可以使用 PAD ID 為 0 這個事實作為遮蔽手段,我們便只需計算(一個批次中樣本的)每一行損失之和即可,然後取所有樣本損失的平均值,從而得到一個批次的損失。這時,我們就可以通過最小化這個損失函式來進行訓練了。

以下是訓練結果:

Screen Shot 2016-11-19 at 4.56.18 PM.png
Screen Shot 2016-11-19 at 4.56.18 PM.png

我們不會在這裡做任何的模型推斷,但是你可以在接下來的關於注意力機制的文章中看到。如果你真的想在這裡實現模型推斷,使用相同的模型就可以了,但你還得將預測目標的結果作為輸入接入下一個 RNN 解碼器單元。同時你還要將相同的權重集嵌入解碼器中,並將其作為 RNN 的另一個輸入。這意味著對於初始的 GO token 而言,你得嵌入一些偽造的 token 進行輸入。

結論

這個編碼解碼器模型非常簡單,但是在理解 seq2seq 實現之前,它是一個必要的基礎。在下一篇 RNN 教程中,我們將涵蓋 Attention 模型及其在編碼解碼器模型結構上的優勢。

程式碼

GitHub 倉庫 (正在更新,敬請期待!)


掘金翻譯計劃 是一個翻譯優質網際網路技術文章的社群,文章來源為 掘金 上的英文分享文章。內容覆蓋 AndroidiOSReact前端後端產品設計 等領域,想要檢視更多優質譯文請持續關注 掘金翻譯計劃官方微博知乎專欄

相關文章