LSTM入門必讀:從入門基礎到工作方式詳解

機器之心發表於2017-07-24

長短期記憶(LSTM)是一種非常重要的神經網路技術,其在語音識別和自然語言處理等許多領域都得到了廣泛的應用。。在這篇文章中,Edwin Chen 對 LSTM 進行了系統的介紹。機器之心對本文進行了編譯。

我第一次學習 LSTM 的時候,它就吸引了我的眼球。然而並不是那種看到果凍甜圈圈時候的驚喜的形式。事實證明 LSTM 是對神經網路的一個相當簡單的擴充套件,而且在最近幾年裡深度學習所實現的驚人成就背後都有它們的身影。所以我會盡可能直觀地來呈現它們——以便你們自己就可以弄明白。

首先,讓我們來看一幅圖:

LSTM入門必讀:從入門基礎到工作方式詳解

LSTM 很漂亮吧?讓我們開始吧!

(提示:如果你已經熟知神經網路和 LSTM,請直接跳到中間部分,本文的前半部分是一個入門。)

神經網路

想象一下,我們有一部電影的影像序列,我們想用一個活動來標記每一副影像(例如,這是一場戰鬥嗎?圖中的人物在交談嗎?圖中的人物在吃東西嗎......)

我們如何做到這一點呢?

一種方法就是忽略影像的順序本質,構造將每幅影像單獨考慮的影像分類器。例如,在提供足夠多的影像和標籤時:

  • 我們的演算法首先檢測到較低水平的模式,例如形狀和邊緣。
  • 在更多的資料下,它可能學會將這些模式組合成更加複雜的模式,例如人臉(兩個圓形東西下面有一個三角形的東西,下面還有一個橢圓形的東西),或者貓。
  • 甚至在更多的資料下,它可能學會把這些高水平的模式對映到活動本身(具有嘴巴、牛排和叉子的情景可能與吃有關)。

那麼,這就是一個深度神經網路(deep neural network):它使用一副圖片作為輸入返回一個活動作為輸出,就像我們可以在不瞭解任何關於狗的知識就可以學會在狗的行為中檢測到模式一樣(在看了足夠多的柯基犬之後,我們會發現一些諸如毛茸茸的屁股和鼓槌般的腿),深度神經網路可以透過隱藏層的表徵來學會表示圖片。

數學描述

我假定讀者早已熟悉了基本的神經網路,下面讓我們來快速地複習一下吧。

  • 只有一個單獨的隱藏層的神經網路將一個向量 x 作為輸入,我們可以將它看做一組神經元。
  • 每個輸入神經元都被透過一組學習得到的權重連線到隱藏層。
  • 第 j 個隱藏神經元的輸出如下:(其中ϕ 是一個啟用函式)LSTM入門必讀:從入門基礎到工作方式詳解
  • 隱藏層是全連線到輸出層的,第 j 個輸出神經元的輸出 yj 如下:如果我們需要輸出機率,我們可以透過 softmax 函式對輸出做一下變換。LSTM入門必讀:從入門基礎到工作方式詳解

寫成矩陣形式如下:

LSTM入門必讀:從入門基礎到工作方式詳解

其中

  • x 是輸入向量
  • W 是連線輸入和隱藏層的權重矩陣
  • V 是連線隱藏層和輸出的權重矩陣
  • 常用的啟用函式ϕ分別是 sigmoid 函式σ(x),它可以將數字壓縮在(0,1)的範圍;雙曲正切函式(hyperbolic tangent)tanh(x),它將數字壓縮在(-1,1)的範圍;以及修正線性單元函式(rectified linear unit)函式,ReLU(x)=max(0,x)。

下面用一幅圖來描述神經網路:

LSTM入門必讀:從入門基礎到工作方式詳解

(注意:為了使符號更加簡潔,我假設 x 和 h 各包含一個代表學習偏差權重的固定為 1 的附加偏置神經元(bias neuron)。)

使用迴圈神經網路(RNN)記憶資訊

然而忽略電影影像的序列資訊只是最簡單的機器學習。如果我們看見了一副沙灘的景象,我們應該在之後的幀裡強調沙灘的活動:某人在水中的圖片應該被更多可能地標記為游泳,而不是洗澡;某人閉著眼睛躺著的圖片應該被更多地標記為日光浴。如果我們記得 Bob 剛剛到了一家超市,那麼即使沒有任何特別的超市特徵,Bob 拿著一塊培根的照片應該更可能地被歸類為購物而不是烹飪。

所以我們想要的就是讓我們的模型去追蹤這個世界的狀態:

1. 在看完每一張圖片之後,模型會輸出一個標籤,也會更新關於這個世界的知識。例如,模型可能學會自動地發現和追蹤位置(目前的場景是在室內還是在沙灘?)、一天中的時間(如果場景中包含月亮,那麼模型應該記住現在是晚上)以及電影中的進度(這是第一張圖還是第 100 幀?)等資訊。至關重要的是,就像神經網路能夠在沒有被饋送資訊的情況下自動地發現隱藏的邊緣、形狀以及人臉等影像一樣,我們的模型也應該依靠它們自己來發現一些有用的資訊。

2. 在被給定一張新圖片的時候,模型應該結合已經收集到的知識來做出更好的工作。

這就是一個迴圈神經網路(RNN)。除了簡單地輸入一幅影像並返回一個活動標籤,RNN 也會維護內部關於這個世界的知識(就是分配給不同資訊片段的權重),以幫助執行它的分類。

數學描述

所以,讓我們把內部知識(internal knowledge)的概念加入到我們的方程中吧,我們可以將內部記憶看做網路會隨著時間進行維護的資訊片段的記憶。

但是這是容易的:我們知道神經網路的隱藏層早已將輸入的有用資訊做了編碼,所以我們為何不把這些隱藏層作為記憶呢?這就有了我們的 RNN 方程:

LSTM入門必讀:從入門基礎到工作方式詳解

注意在時間 t 計算得到的隱藏狀態 ht(ht 就是我們這裡的內部知識)會被反饋到下一個時間。(另外,我會使用例如隱藏狀態、知識、記憶以及信念這樣的詞語來變換地描述 ht)

LSTM入門必讀:從入門基礎到工作方式詳解

透過 LSTM 來實現更長時間的記憶

讓我們來思考一下模型是如何更新關於這個世界的知識的。到目前為止,我們還沒有給這種更新施加任何限制,所以它的知識可能變得非常混亂:在一幀影像裡面它會認為人物在美國,在下一幀它看到人在吃壽司,就會認為人是在日本,在其後的一幀它看到了北極熊,就會認為他們是在伊茲拉島。或者也許它有大量的資訊表明 Alice 是一名投資分析師,但是在它看到了她的廚藝之後它就會認定她是一名職業殺手。

這種混亂意味著資訊在快速地轉移和消失,模型難以保持長期的記憶。所以我們想要的是讓網路學會如何讓它以一種更加溫和的方式來進化自己關於這個世界的知識,從而更新自己的信念(沒有 Bob 的場景不應該改變關於 Bob 的資訊包含 Alice 的場景應該聚焦於收集關於她的一些細節資訊)。

下面是我們如何做這件事的 4 種方式:

1. 新增一個遺忘機制(forgetting mechanism):如果一個場景結束了,模型應該忘記當前場景中的位置,一天的時間並且重置任何與場景相關的資訊;然而,如果場景中的一個人死掉了,那麼模型應該一直記住那個死去的人已經不再活著了。因此,我們想要模型學會一種有區分的遺忘/記憶機制:當新的輸入到來時,它需要知道記住哪些信念,以及丟棄哪些信念。

2. 新增一個儲存機制(saving mechanism):當模型看到一副新的圖片時,它需要學習關於這張圖片的資訊是否值得使用和儲存。或許你媽媽給了你一片關於凱莉·詹娜的文章,但是誰會在乎呢?

3. 所以當新的輸入來臨時,模型首先要忘掉任何它認為不再需要的長期記憶資訊。然後學習新輸入的哪些部分是值得利用的,並將它們儲存在自己的長期記憶中。

4. 將長期記憶聚焦在工作記憶中:最後,模型需要學習長期記憶中的哪些部分是即刻有用的。例如,Bob 的年齡可能是一條需要長期保持的資訊(兒童很可能正在玩耍,而成年人很可能正在工作),但是如果他不在當前的場景中,那麼這條資訊很可能就不是特別相關。所以,模型學習去聚焦哪一部分,而不總是使用完全的長期記憶。

這就是一個長短期記憶網路(long short-term memory network)。LSTM 會以一種非常精確的方式來傳遞記憶——使用了一種特定的學習機制:哪些部分的資訊需要被記住,哪些部分的資訊需要被更新,哪些部分的資訊需要被注意。與之相反,迴圈神經網路會以一種不可控制的方式在每一個時間步驟都重寫記憶。這有助於在更長的時間內追蹤資訊。


數學描述

讓我們來對 LSTM 做一下數學描述。

在時間 t,我們收到了新的輸入 xt。我們也有自己的從之前的時間步中傳遞下來的長期記憶和工作記憶,ltm(t−1)以及 wm(t−1)(兩者都是 n 維向量),這就是我們想要更新的東西。

我們將要開始我們的長期記憶。首先,我們需要知道哪些長期記憶需要保持,哪些需要丟棄,所以我們想要使用新的輸入和我們的工作記憶來學習一個由 n 個介於 0 和 1 之間的數字組成的記憶門,每一個數字都決定一個長期記憶的元素被保持多少。(1 意味著完全保持,0 意味著完全丟棄。)

自然地我們可以使用一個小型神經網路來學習這個記憶門:

LSTM入門必讀:從入門基礎到工作方式詳解

(注意與我們之前的神經網路方程的相似性;這只是一個淺層的神經網路。並且,我們使用了 sigmoid 啟用函式,因為我們需要的數字是介於 0 和 1 之間的。)

接下來,我們需要計算我們能夠從 xt 中學習到的資訊,也就是我們長期記憶中的候選者:

LSTM入門必讀:從入門基礎到工作方式詳解

ϕ是一個啟用函式,通常選擇雙曲正切函式。

然而,在我們將這個候選者加進我們的記憶之前,我們想要學到哪些部分是實際上值得使用和儲存的:

LSTM入門必讀:從入門基礎到工作方式詳解

(思考一下當你在網頁上讀到某些內容的時候會發生什麼。當一條新聞文章可能包含希拉蕊的資訊時,如果訊息來源是 Breitbart,那你就應該忽略它。)

現在讓我們把所有這些步驟結合起來。在忘掉我們認為將來不會再次用到的資訊以及儲存有用的新來的資訊之後,我們就有了更新的長期記憶:

LSTM入門必讀:從入門基礎到工作方式詳解

接下來,來更新我們的工作記憶:我們想要學習如何將我們的長期記憶專注於那些將會即刻有用的資訊上。(換句話說,我們想要學習將哪些資訊從外部硬碟移動到正在工作的筆記本記憶體上。)所以我們會學習一個聚焦/注意向量(focus/attention vector):

LSTM入門必讀:從入門基礎到工作方式詳解

然後我們的工作記憶就成為了:

LSTM入門必讀:從入門基礎到工作方式詳解

換言之,我們將全部注意集中在 focus 為 1 的元素上,並且忽略那些 focus 是 0 的元素。

然後我們對長期記憶的工作就完成了!也希望這能夠稱為你的長期記憶。

總結:一個普通的 RNN 用一個方程來更新隱藏狀態/記憶:

LSTM入門必讀:從入門基礎到工作方式詳解

而 LSTM 使用數個方程:

LSTM入門必讀:從入門基礎到工作方式詳解

其中每一個記憶/注意子機制只是 LSTM 的一個迷你形式:

LSTM入門必讀:從入門基礎到工作方式詳解

(注意:我在這裡使用的術語和變數的名字和通常文獻中是有所不同的。以下是一些標準名稱,以後我將會交換使用:

  • 長期記憶 ltm(t), 通常被稱為**cell state**, 簡寫 c(t).
  • 工作記憶 wm(t) 通常被稱為**hidden state**, 簡寫 h(t)。這個和普通 RNN 中的隱藏狀態是類似的。
  • 記憶向量 remember(t),通常被稱為**forget gate** (儘管遺忘門中,1 仍舊意味著完全保持記憶 0 意味著完全忘記),簡稱 f(t)。
  • 儲存向量 save(t),通常被稱為 input gate,(因為它決定輸入中有多少被允許進入 cell state),簡稱 i(t)。
  • 注意向量 focus(t),通常被稱為 output gate,簡稱 o(t)。

LSTM入門必讀:從入門基礎到工作方式詳解

卡比獸

寫這篇博文的時間我本可以抓一百隻 Pidgeys,請看下面的漫畫。

神經網路

LSTM入門必讀:從入門基礎到工作方式詳解

神經網路會以 0.6 的機率判定輸入圖片中的卡比獸正在淋浴,以 0.3 的機率判定卡比獸正在喝水,以 0.1 的機率判定卡比獸正在遭遇襲擊。

迴圈神經網路

LSTM入門必讀:從入門基礎到工作方式詳解

當迴圈神經網路被用來做這件事的時候,它具有對前一幅圖的記憶。最終結果是卡比獸正在遭遇襲擊的機率為 0.6,卡比獸正在淋浴的機率是 0.3,卡比獸正在喝水的機率是 0.1。結果要明顯好於上一幅圖中的神經網路。

LSTM

LSTM入門必讀:從入門基礎到工作方式詳解

具備長期記憶的 LSTM,在記憶了多種相關資訊的前提下,將對卡通圖畫中的場景描述準確的機率提高到了 0.9。

學會程式設計

讓我們來看一下一個 LSTM 可以做到的一些例子吧。遵循著 Andrej Karpathy 的精湛的博文(http://karpathy.github.io/2015/05/21/rnn-effectiveness/),我將使用字元級別的 LSTM 模型,這些模型接受字元序列的輸入,被訓練來預測序列中的下一個字元。

雖然這看起來有點玩笑,但是字元級別的模型確實是非常有用的,甚至比單詞級別的模型更加有用。例如:

  • 試想一個自動程式設計器足夠智慧,能夠允許你在你的手機上程式設計。從理論上講,一個 LSTM 模型能夠追蹤你當前所在函式的返回型別,可以更好地建議你返回那個變數;它也能夠在不經過編譯的情況下透過返回的錯誤型別就知道你是不是已經造成了一個 bug。
  • 像機器翻譯這樣的自然語言處理應用在處理罕見詞條的時候經常會出現問題。你如何翻譯一個從未見過的單詞呢,或者你如何將一個形容詞轉換成動詞呢?即使你知道一條推文的意思,你如何生成一個新的標籤來描述它呢?字元級別的模型可以空想出新的項,所以這是另外一個具有有趣應用的領域。

所以就開始了,我啟動了一個 EC2 p2.xlarge spot 例項,並在 Apache Commons Lang 程式碼庫(連結:https://github.com/apache/commons-lang)上訓練了一個 3 層的 LSTM 模型。以下是幾個小時後生成的程式:

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.linear;

import java.text.NumberFormat;
import java.io.ByteArrayInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math4.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math4.ml.neuralnet.sofm.NeuronSquareMesh2D;
import org.apache.commons.math4.distribution.DescriptiveStatistics;
import org.apache.commons.math4.optim.nonlinear.scalar.NodeFieldIntegrator;
import org.apache.commons.math4.optim.nonlinear.scalar.GradientFunction;
import org.apache.commons.math4.optim.PointValuePair;
import org.apache.commons.numbers.core.Precision;

/**
 * <p>Natural infinite is defined in basic eigenvalues of a transform are in a subconsider for the optimization ties.</p>
 *
 * <p>This implementation is the computation at a collection of a set of the solvers.</p>
 * <p>
 * This class is returned the default precision parameters after a new value for the interpolation interpolators for barycenter.
 * <p>
 * The distribution values do not ratio example function containing this interface, which should be used in uniform real distributions.</p>
 * <p>
 * This class generates a new standard deviation of the following conventions, the variance was reached as
 * constructor, and invoke the interpolation arrays</li>
 * <li>{@code a < 1} and {@code this} the regressions returned by calling
 * the same special corresponding to a representation.
 * </p>
 *
 * @since 1.2
 */
public class SinoutionIntegrator implements Serializable {

    /** Serializable version identifier */
    private static final long serialVersionUID = -7989543519820244888L;

    /**
     * Start distance between the instance and a result (does not all lead to the number of seconds).
     * <p>
     * Note that this implementation this can prevent the permutation of the preneved statistics.
     * </p>
     * <p>
     * <strong>Preconditions</strong>: <ul>
     * <li>Returns number of samples and the designated subarray, or
     * if it is null, {@code null}. It does not dofine the base number.</p>
     *
     * @param source the number of left size of the specified value
     * @param numberOfPoints number of points to be checked
     * @return the parameters for a public function.
     */
    public static double fitness(final double[] sample) {
        double additionalComputed = Double.POSITIVE_INFINITY;
        for (int i = 1; i < dim; i++) {
            final double coefficients[i] = point[i] * coefficients[i];
            double diff = a * FastMath.cos(point[i]);
            final double sum = FastMath.max(random.nextDouble(), alpha);
            final double sum = FastMath.sin(optimal[i].getReal() - cholenghat);
            final double lower = gamma * cHessian;
            final double fs = factor * maxIterationCount;
            if (temp > numberOfPoints - 1) {
                final int pma = points.size();
                boolean partial = points.toString();
                final double segments = new double[2];
                final double sign = pti * x2;
                double n = 0;
                for (int i = 0; i < n; i++) {
                    final double ds = normalizedState(i, k, difference * factor);
                    final double inv = alpha + temp;
                    final double rsigx = FastMath.sqrt(max);
                    return new String(degree, e);
                }
            }
            // Perform the number to the function parameters from one count of the values
            final PointValuePair part = new PointValuePair[n];
            for (int i = 0; i < n; i++) {
                if (i == 1) {
                    numberOfPoints = 1;
                }
                final double dev = FastMath.log(perturb(g, norm), values[i]);
                if (Double.isNaN(y) &&
                                     NaN) {
                    sum /= samples.length;
                }
                double i = 1;
                for (int i = 0; i < n; i++) {
                    statistics[i] = FastMath.abs(point[i].sign() + rhs[i]);
                }
                return new PointValuePair(true, params);
            }
        }
    }

    /**
     * Computes the number of values
     * @throws NotPositiveException if {@code NumberIsTooSmallException if {@code seed <= 0}.
     * @throws NullArgumentException if row or successes is null
     */
    public static double numericalMean(double value) {
        if (variance == null) {
            throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SUBCORSE_TRANSTOR_POPULATIONS_COEFFICIENTS,
                                                        p, numberOfSuccesses, true);
        }
        return sum;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public LeastSquaresProblem create(final StatisticalSummary sampleStats1,
                                       final double[] values, final double alpha) throws MathIllegalArgumentException {
        final double sum = sumLogImpl.toSubSpace(sample);
        final double relativeAccuracy = getSumOfLogs();
        final double[] sample1 = new double[dimension];

        for (int i = 0; i < result.length; i++) {
            verifyInterval.solve(params, alpha);
        }
        return max;
    }

    /**
     * Test creates a new PolynomialFunction function
     * @see #applyTo(double)
     */
    @Test
    public void testCosise() {
        final double p = 7.7;
        final double expected = 0.0;
        final SearchInterval d = new Power(1.0, 0.0);
        final double penalty = 1e-03;
        final double init = 0.245;
        final double t = 0.2;
        final double result = (x + 1.0) / 2.0;
        final double numeratorAdd = 13;
        final double bhigh = 2 * (k - 1) * Math.acos();

        Assert.assertEquals(0.0, true);
        Assert.assertTrue(percentile.evaluate(singletonArray), 0);
        Assert.assertEquals( 0.0, getNumberOfTrials(0, 0), 1E-10);
        Assert.assertEquals(0.201949230731, percentile.evaluate(specialValues), 1.0e-3);
        Assert.assertEquals(-10.0, distribution.inverseCumulativeProbability(0.50), 0);
        Assert.assertEquals(0.0, solver.solve(100, f, 1.0, 0.5), 1.0e-10);
    }

儘管這段程式碼確實不是完美的,但是它比很多我認識的資料科學家要做的好一些。我們可以發現 LSTM 已經學會了很多有趣的(也是正確的!)程式設計行為。

  • 它懂得如何構造類: 最頂部有 license 相關的資訊,緊跟著是 package 和 import,再然後是註釋和類的定義,再到後面是變數和函式。類似地,它知道如何建立函式:註釋遵循正確的順序(描述,然後是 @param,然後是 @return,等等),decorator 被正確放置,非空函式能夠以合適的返回語句結束。關鍵是,這種行為跨越了大篇幅的程式碼——你看圖中的程式碼塊有多大!
  • 它還能夠追蹤子程式和巢狀級別:縮排總是正確的,if 語句和 for 迴圈總能夠被處理好。
  • 它甚至還懂得如何構造測試。

那麼模型是如何做到這一點的呢?讓我們來看一下幾個隱藏狀態。

下面是一個貌似在追蹤程式碼外層縮排的神經元(當讀取字元作為輸入的時候,也就是說,在嘗試生成下一個字元的時候,每一個字元都被著上了神經元狀態的顏色;紅色的單元是負的,藍色的單元是正的):

LSTM入門必讀:從入門基礎到工作方式詳解

下面是一個統計空格數量的神經元:

LSTM入門必讀:從入門基礎到工作方式詳解

娛樂一下,下面是在 TensorFlow 程式碼庫上訓練得到的另一個不同的 3 層 LSTM 模型的輸出:

"""Tests for softplus layer tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.python.platform import test


class InvalidAllOpCost(Experiment):

  def _runTestToIndForDead(self):
    return self._divs()

  def testPad(self):
    with ops.Graph().as_default():
      var = sess.run(bucketized_op)
      self.assertAllClose(
          list(variables.global_variables()), status.eval())

  def testHttptimenaterRoutingOptimizerSize(self):
    with self.test_session() as sess:
      table = lookup_ops.IdTableWithHashBuckets(
          keys=['id', 'z'],
          example_id_column='price',
          num_outputs=6,
          input_columns=['dummy_range', 'feature', 'dimensions'])

    with self.assertRaisesRegexp(ValueError, 'Expected dict of rank dimensions'):
      fc.numeric_column('aaa', indices=[[0, 0], [1, 0]], dtype=dtypes.int64)
    output = table.lookup(input_string)

    # all input tensors in SparseColumn has dimensions [end_back_prob, dimension] in the format.
    with self.assertRaisesRegexp(
        TypeError, "Shape of values must be specified during training."):
      fc.bucketized_column(attrs, boundaries=[62, 62])

網路上還有很多有趣的例子,如果你想了解更多,請檢視:http://karpathy.github.io/2015/05/21/rnn-effectiveness/

研究 LSTM 的內部

讓我們再稍往深處挖掘一下。我們看一下上一部分隱藏狀態的例子,但是我也想玩轉 LSTM cell 狀態以及其他的記憶機制。我們期待著,它們會迸發出火花呢,還是會有令人驚喜的畫面?

計數

為了研究,讓我們從教一個 LSTM 計數開始。(你應該還記得 Java 和 Python 的 LSTM 模型是如何生成合適的縮排的!)所以我生成了如下形式的序列:

aaaaaXbbbbb

(N 個字母「a」,後面跟著一個字母分隔符 X,後面是 N 個字母「b」,其中 1 <= N <= 10),然後訓練一個具有 10 個隱藏神經元的單層 LSTM。

不出所料,LSTM 模型在訓練期間完美地學習--甚至能夠將生成推廣到幾步之外。(即使在開始的時候當我們嘗試讓它記到 19 的時候它失敗了。)

aaaaaaaaaaaaaaaXbbbbbbbbbbbbbbb
aaaaaaaaaaaaaaaaXbbbbbbbbbbbbbbbb
aaaaaaaaaaaaaaaaaXbbbbbbbbbbbbbbbbb
aaaaaaaaaaaaaaaaaaXbbbbbbbbbbbbbbbbbb
aaaaaaaaaaaaaaaaaaaXbbbbbbbbbbbbbbbbbb # Here it begins to fail: the model is given 19 "a"s, but outputs only 18 "b"s.

我們期望找到一個隱藏狀態神經元,它能夠在我們觀察模型內部的時候計出每一個 a 的數目。正如我們做的:

LSTM入門必讀:從入門基礎到工作方式詳解

我開發了一個可以讓你玩轉 LSTM 的小型的 web app,神經元 #2 貌似既能夠記錄已經看到的 a 的數目,也能記錄已經看到的字元 b 的數目。(請記得,單元的顏色是根據啟用程度著色的,從深紅色的 [-1] 到深藍色的 [+1]。)

那麼 cell 的狀態是怎麼樣的呢?它的行為類似於這樣:

LSTM入門必讀:從入門基礎到工作方式詳解

有趣的是,工作記憶就像是長期記憶的「銳化版」。但是這個在一般情況是否成立呢?

這確實是成立的。(我正是我們所期望的,因為長期記憶被雙曲正切啟用函式進行了壓縮,而且輸出門限制了透過它的內容。)例如,下圖是所有的 10 個 cell 在某一時刻的狀態。我們看到了大量的顏色很清淡的 cell,這代表它們的值接近 0。

LSTM入門必讀:從入門基礎到工作方式詳解

相比之下,10 個工作記憶的神經元看起來更加聚焦。第 1、3、5、7 個神經元甚至在序列的前半部分全是 0。

LSTM入門必讀:從入門基礎到工作方式詳解

讓我們再回頭看一下神經元 #2。這裡有一些候選的記憶和輸入門。它們在每個序列的前半部分或者後半部分都是相對不變的——就像神經元在每一步都在進行 a+=1 或者 a-=1 的計算。

LSTM入門必讀:從入門基礎到工作方式詳解

LSTM入門必讀:從入門基礎到工作方式詳解

最後,這裡是神經元 2 的整體概覽:

LSTM入門必讀:從入門基礎到工作方式詳解

如果你想自己研究一下不同計數神經元,你可以在這個視覺化 web app 中自己玩一下。

LSTM入門必讀:從入門基礎到工作方式詳解

(注意:這遠遠不是一個 LSTM 模型可以學會計數的唯一方式,我在這裡只描述了一個而已。但是我認為觀察網路行為是有趣的,並且這有助於構建更好的模型;畢竟,神經網路中的很多思想都是來自於人腦。如果我們看到了意料之外的行為,我們也許會有能力設計出更加有效地學習機制。)

來自計數的計數

讓我們來看一下一個稍微有點複雜的計數器。這次,我生成了如下的序列形式:

aaXaXaaYbbbbb

(N 個 a 中間隨機地插入 X,後邊跟一個分隔符 Y,再後邊是 N 個 b。)LSTM 仍然必須數清楚 a 的數目,但是這一次需要忽略 X 的數目。

在這個連結中檢視整個 LSTM(http://blog.echen.me/lstm-explorer/#/network?file=selective_counter)我們希望看到一個正在計數的神經元——一個正在計數的、每看到一個 X 輸入門就變成 0 的神經元。在我們做到了!

LSTM入門必讀:從入門基礎到工作方式詳解


上圖是 neuron 20 的 cell 狀態。它的值一直保持增大,直到遇到分割字元 Y,然後就一直減小,直到序列的末尾——就像在計算一個隨著 a 增大,隨著 b 減小的變數 num_bs_left_to_print 一樣。

如果我們觀察它的輸入門,會看到它確實是將 X 的數量忽略了:

LSTM入門必讀:從入門基礎到工作方式詳解

然而,有趣的是,候選的記憶會在有關聯的 X 上被完全啟用--這證明了為什麼需要哪些輸入門。(但是,如果輸入門不是模型架構的一部分,至少在這個簡單的例子中,網路也會以其他的方式忽略 X 的數量。)

LSTM入門必讀:從入門基礎到工作方式詳解

我們再來看一下神經元 10。

LSTM入門必讀:從入門基礎到工作方式詳解

這個神經元是有趣的,因為它僅僅在讀取到 Y 的時候才會被啟用—然而它還是能夠對序列中遇到的 a 字元進行編碼。(在圖中可能很難區分出來,但是序列中 a 的數目一樣的時候,Y 的顏色是相同的,即便不相同,差距也在 0.1% 以內。你可以看到,a 比較少的序列中 Y 的顏色要淺一些。)或許其他的神經元會看到神經元 10 比較鬆弛。

LSTM入門必讀:從入門基礎到工作方式詳解

記憶狀態

下面我想研究一下 LSTM 是如何記憶狀態的。同樣的,我生成了以下形式的序列:

AxxxxxxYa
BxxxxxxYb

(也就是說,一個「A」或者「B」,後面跟著 1-10 個 x,然後是一個分割字元「Y」,最終以一個起始字元的小寫形式結尾。)這種情況下,網路需要記住到底是一個「狀態 A」還是一個「B」狀態。

我們希望找到一個神經元能夠在記得序列是以「A」開頭的,希望找到另一個神經元記得序列是以「B」開頭的。我們做到了。

例如這裡是一個「A」神經元,當讀取到「A」的時候它會啟用,持續記憶,直到需要生成最後一個字母的時候。要注意,輸入門忽略了序列中所有的 x。

LSTM入門必讀:從入門基礎到工作方式詳解

下面是對應的「B」神經元:

LSTM入門必讀:從入門基礎到工作方式詳解

有趣的一點是,即使在讀取到分隔符「Y」之前,關於 A 和 B 的知識是不需要的,但是隱藏狀態在所有的中間輸入中都是存在的。這看上去有一點「低效」,因為神經元在計數 x 的過程中做了一些雙重任務。

LSTM入門必讀:從入門基礎到工作方式詳解

複製任務

最後,讓我們來看一下 LSTM 是如何學會複製資訊的。(回想一下我們的 Java 版的 LSTM 曾經學會了記憶並且複製一個 Apache license。)

(注意:如果你思考 LSTM 是如何工作的,記住大量的單獨的、細節的資訊其實並不是它們所擅長的事情。例如,你可能已經注意到了 LSTM 生成的程式碼的一個主要缺陷就是它經常使用未定義的變數—LSTM 無法記住哪些變數已經在環境中了。這並不是令人驚奇的事情,因為很難使用單個 cell 就能有效地對想字元一樣的多值資訊進行編碼,並且 LSTM 並沒有一種自然的機制來連線相鄰的記憶以形成單詞。記憶網路(memory networks)和神經圖靈機(neutral turing machine)就是兩種能夠有助於修正這個缺點的神經網路的擴充套件形式,透過增加外部記憶元件。所以儘管複製並不是 LSTM 可以很有效地完成的,但是無論如何,去看一下它是如何完成這個工作是有趣的。)

針對這個複製任務,我訓練了一個很小的兩層 LSTM 來生成如下形式的序列:

baaXbaa
abcXabc

(也就是說,一個由 a、b、c3 種字元組成的子序列,後面跟著一個分隔符「X」,後面再跟著一個同樣的子序列)。

我不確定「複製神經元」到底應該是長什麼樣子的,所以為了找到能夠記住部分初始子序列的神經元,我觀察了一下它們在讀取分隔符 X 時的隱藏狀態。由於神經網路需要編碼初始子序列,它的狀態應該依據它們學到的東西而看起來有所不同。

例如,下面的這一幅圖畫出了神經元 5 在讀入分隔符「X」時候的隱藏狀態。這個神經元明顯將那些以「c」開頭的序列從那些不是以「c」開頭的序列中區分出來。

LSTM入門必讀:從入門基礎到工作方式詳解

另一個例子,這是神經元 20 在讀入分隔符「X」時的隱藏狀態。看起來它選擇了那些以「b」開頭的子序列。

LSTM入門必讀:從入門基礎到工作方式詳解

有趣的是,如果我們觀察神經元 20 的 cell 狀態,它貌似能夠捕捉這三種子序列。

LSTM入門必讀:從入門基礎到工作方式詳解

這裡是神經元 20 關於整個序列的 cell 狀態個隱藏狀態。請注意在整個初始序列中它的隱藏狀態是關閉的(也許這是期望之中的,因為它的記憶僅僅需要在某一點被動保持)。

LSTM入門必讀:從入門基礎到工作方式詳解

然而,如果我們看得更加仔細一些,就會發現,只要下一個字元是「b」, 它就是正的。所以,與其說是以 b 字母開頭的序列,還不如說是下一個字元是 b 的序列。


就我所知,這個模式在整個網路中都存在——所有的神經元貌似都在預測下一個字元,而不是在記住處在當前位置的字元。例如,神經元 5 貌似就是一個「下一個字元」預測器.

LSTM入門必讀:從入門基礎到工作方式詳解

我不確定這是不是 LSTM 在學習複製資訊時候的預設型別,或者複製機制還有哪些型別呢?

LSTM入門必讀:從入門基礎到工作方式詳解

擴充套件

讓我們來回顧一下你如何自己來探索 LSTM。

首先,我們想要解決的大多數問題都是階段性的,所以我們應該把一些過去的學習結合到我們的模型中。但是我們早已知道神經網路的隱藏層在編碼自己的資訊,所以為何不使用這些隱藏層,將它們作為我們向下一步傳遞的記憶呢?這樣一來,我們就有了迴圈神經網路(RNN)。

但是從我們的行為就能知道,我們是不願意去追蹤知識的;當我們閱讀一篇新的政論文章時,我們並不會立即相信它所談論的內容並將其與我們自己對這個世界的信念所結合。我們選擇性地儲存哪些資訊,丟棄哪些資訊,以及哪些資訊可以用來決定如何處理下一次讀到的新聞。因此,我們想要學習收集、更新以及應用資訊——為何不透過它們自己的小型神經網路來學習這些東西呢?如此,我們就有了 LSTM。

現在我們已經走通了這個過程,我們也可以想出我們的修正:

  • 例如,或許你認為 LSTM 區分長期記憶和工作記憶是愚蠢的行為—為何不使用一種記憶呢?或者,或許你能夠發現區分記憶門和儲存門是多餘的--任何我們忘記地東西都應該被新的資訊代替,反之亦然。所以我們現在想出了一種流行的 LSTM 變種,門控迴圈神經網路(GRU):https://arxiv.org/abs/1412.3555
  • 或者你可能認為,當決定哪些資訊需要被記住、儲存、注意的時候,我們不應該僅僅依靠我們的工作記憶—為什麼不同時使用長期記憶呢?如此,你發現了 Peephole LSTM

讓我們看一下最後的例子,使用一個兩層多的 LSTM 來訓練 Trump 的推特,儘管這是很大規模的資料集,但是這個 LSTM 已經足以學到很多模式。

例如,這是一個在標籤、URL 以及 @mention 中跟蹤位置的神經元:

LSTM入門必讀:從入門基礎到工作方式詳解

這是一個合適的名詞檢測器(注意它並不是簡單的注重大寫單詞):

LSTM入門必讀:從入門基礎到工作方式詳解

這是一個助動詞+「to be」的檢測器(例如 will be, I've always been,has never been)

LSTM入門必讀:從入門基礎到工作方式詳解

這是一個引文屬性:

LSTM入門必讀:從入門基礎到工作方式詳解

這是一個 MAGA 和大小寫神經元:

LSTM入門必讀:從入門基礎到工作方式詳解

這裡是一些用 LSTM 生成的公告(ok,其中有一個是一條真正的推特,你猜一下哪個是):

LSTM入門必讀:從入門基礎到工作方式詳解

不幸的是,LSTM 僅僅學會了像瘋子一樣瘋狂書寫。

相關文章