TensorFlow for Machine Intelligence in Practice, Part 7: Word Vector Embeddings

Posted by CopperDong on 2018-05-26

This section implements a model that learns word embeddings (vector representations of words), a powerful way of representing words for NLP tasks.

Representing words by their co-occurrences is a long-standing approach to the problem of semantic association. The basic idea is to walk over a large text corpus and, for every word, count how often other words appear within a certain distance of it. Each word is then represented by the normalized counts of its nearby words; the underlying assumption is that words used in similar contexts are semantically similar. The occurrence vectors can then be made denser by reducing their dimensionality with PCA or a similar method. While this approach performs well, it requires keeping track of the whole co-occurrence matrix, a square matrix whose width and height both equal the size of the vocabulary.
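To make this classic approach concrete, here is a minimal sketch, not taken from the book, that builds a co-occurrence matrix for a made-up toy corpus and compresses it with a truncated SVD (the linear-algebra core of PCA); the corpus, window size, and target dimensionality are all arbitrary example choices.

import numpy as np

# Toy corpus and context window, chosen only for illustration.
corpus = [['the', 'cat', 'sat', 'on', 'the', 'mat'],
          ['the', 'dog', 'sat', 'on', 'the', 'rug']]
window = 2

vocabulary = sorted({word for sentence in corpus for word in sentence})
index = {word: i for i, word in enumerate(vocabulary)}

# Count how often each word appears within the window of every other word.
counts = np.zeros((len(vocabulary), len(vocabulary)))
for sentence in corpus:
    for i, word in enumerate(sentence):
        start = max(0, i - window)
        stop = min(len(sentence), i + window + 1)
        for j in range(start, stop):
            if j != i:
                counts[index[word], index[sentence[j]]] += 1

# Normalize the rows and compress them to two dimensions.
rows = counts / counts.sum(axis=1, keepdims=True)
u, s, _ = np.linalg.svd(rows, full_matrices=False)
dense = u[:, :2] * s[:2]
print(dict(zip(vocabulary, dense.round(2))))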


In 2013, Mikolov et al. proposed a practical and efficient way to compute such representations from context. Their skip-gram model starts from random representations and uses a simple classifier that tries to predict a context word from the current word. The error is propagated both through the classifier weights and through the word representations, and both are adjusted to reduce the prediction error. Training this model over a large corpus was found to make the representation vectors approximate compressed co-occurrence vectors. Below, we implement the skip-gram model with TensorFlow.

Reference: Mikolov et al., "Efficient Estimation of Word Representations in Vector Space", 2013.
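To make the training signal concrete, here is a tiny illustration, using a made-up sentence and a fixed window size, of the (current word, context word) pairs the skip-gram classifier is trained on; the actual generator used later (skipgrams.py) samples the window size randomly for each word.

# Hypothetical sentence and window size, purely for illustration.
sentence = ['the', 'quick', 'brown', 'fox']
window = 1
pairs = [(sentence[i], sentence[j])
         for i in range(len(sentence))
         for j in range(max(0, i - window), min(len(sentence), i + window + 1))
         if j != i]
print(pairs)
# [('the', 'quick'), ('quick', 'the'), ('quick', 'brown'),
#  ('brown', 'quick'), ('brown', 'fox'), ('fox', 'brown')]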

Preparing the Wikipedia corpus: we use an English Wikipedia dump file. By default, a dump contains the complete revision history of every page; a dump of only the current page versions (like the one used for training below) is sufficient. The available dumps are listed at:

                     https://dumps.wikimedia.org/backup-index.html

Several further steps are needed to bring the data into the right format. Collecting and cleaning the data is a demanding and important task in itself. The goal is to iterate over Wikipedia pages represented as one-hot encoded words, which requires the following steps:

1) Download the dump file and extract its pages and the words they contain.

2) Count word occurrences and build a vocabulary of the most common words.

3) Encode the extracted pages using this vocabulary.

Model structure: an embedding matrix holds one vector per vocabulary word; the embedding of the current word is looked up and fed into a classifier that tries to predict one of its context words.

Noise-contrastive classifier: instead of computing a full softmax over the vocabulary, the classifier is trained with noise-contrastive estimation (tf.nn.nce_loss), which only has to distinguish the true context word from a small number of randomly sampled contrastive examples.

Training the model: for the full corpus, use https://dumps.wikimedia.org/enwiki/20160501/enwiki-20160501-pages-meta-current.xml.bz2. A sketch of a training script tying the components together is given at the end of this post.

Wikipedia.py

import bz2
import collections
import os
import re
from lxml import etree
from helpers import download

class Wikipedia:
    TOKEN_REGEX = re.compile(r'[A-Za-z]+|[!?.:,()]')
    def __init__(self, url, cache_dir, vocabulary_size=10000):
        self._cache_dir = os.path.expanduser(cache_dir)
        self._pages_path = os.path.join(self._cache_dir, 'pages.bz2')
        self._vocabulary_path = os.path.join(self._cache_dir, 'vocabulary.bz2')
        if not os.path.isfile(self._pages_path):
            print('Read pages')
            self._read_pages(url)
        if not os.path.isfile(self._vocabulary_path):
            print('Build vocabulary')
            self._build_vocabulary(vocabulary_size)
        with bz2.open(self._vocabulary_path, 'rt') as vocabulary:
            print('Read vocabulary')
            self._vocabulary = [x.strip() for x in vocabulary]
        self._indices = {x: i for i, x in enumerate(self._vocabulary)}

    def __iter__(self):
        """Iterate over pages represented as lists of word indices."""
        with bz2.open(self._pages_path, 'rt') as pages:
            for page in pages:
                words = page.strip().split()
                words = [self.encode(x) for x in words]
                yield words

    @property
    def vocabulary_size(self):
        return len(self._vocabulary)

    def encode(self, word):
        """Get the vocabulary index of a string word."""
        return self._indices.get(word, 0)

    def decode(self, index):
        """Get back the string word from a vocabulary index."""
        return self._vocabulary[index]

    def _read_pages(self, url):
        """
        Extract plain words from a Wikipedia dump and store them to the pages
        file. Each page will be a line with words separated by spaces.
        """
        wikipedia_path = download(url, self._cache_dir)
        with bz2.open(wikipedia_path) as wikipedia, \
                bz2.open(self._pages_path, 'wt') as pages:
            for _, element in etree.iterparse(wikipedia, tag='{*}page'):
                if element.find('./{*}redirect') is not None:
                    continue
                page = element.findtext('./{*}revision/{*}text')
                words = self._tokenize(page)
                pages.write(' '.join(words) + '\n')
                element.clear()

    def _build_vocabulary(self, vocabulary_size):
        """
        Count words in the pages file and write a list of the most frequent
        words to the vocabulary file.
        """
        counter = collections.Counter()
        with bz2.open(self._pages_path, 'rt') as pages:
            for page in pages:
                words = page.strip().split()
                counter.update(words)
        # Reserve index 0 for unknown/out-of-vocabulary words. most_common()
        # returns (word, count) pairs, so keep only the words themselves.
        common = [x[0] for x in counter.most_common(vocabulary_size - 1)]
        common = ['<unk>'] + common
        with bz2.open(self._vocabulary_path, 'wt') as vocabulary:
            for word in common:
                vocabulary.write(word + '\n')

    @classmethod
    def _tokenize(cls, page):
        words = cls.TOKEN_REGEX.findall(page)
        words = [x.lower() for x in words]
        return words
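A short usage sketch of the class above: the dump URL matches the one mentioned earlier, while the cache directory and the inspection loop are arbitrary example choices.

# Hypothetical usage; the cache directory is an example path.
WIKI_DUMP = ('https://dumps.wikimedia.org/enwiki/20160501/'
             'enwiki-20160501-pages-meta-current.xml.bz2')
corpus = Wikipedia(WIKI_DUMP, '~/wikipedia', vocabulary_size=10000)
print(corpus.vocabulary_size)
for page in corpus:  # each page is a list of word indices
    print([corpus.decode(i) for i in page[:10]])
    break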

batch.py

import numpy as np

def batched(iterator, batch_size):
    """Group a numerical stream into batches and yield them as Numpy arrays."""
    while True:
        data = np.zeros(batch_size)
        target = np.zeros(batch_size)
        try:
            for index in range(batch_size):
                data[index], target[index] = next(iterator)
        except StopIteration:
            # The pair stream is exhausted; stop cleanly instead of letting the
            # StopIteration escape the generator (an error since PEP 479).
            return
        yield data, target

skipgrams.py

import random

def skipgrams(pages, max_context):
    """Form training pairs according to the skip-gram model."""
    for words in pages:
        for index, current in enumerate(words):
            context = random.randint(1, max_context)
            for target in words[max(0, index - context): index]:
                yield current, target
            for target in words[index + 1: index + context + 1]:
                yield current, target
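The two helpers compose naturally: skipgrams turns the index-encoded pages into a stream of (current, context) pairs, and batched groups that stream into fixed-size Numpy arrays. A minimal wiring sketch, assuming the corpus object from the earlier usage example and arbitrary example values for max_context and batch_size:

pairs = skipgrams(corpus, max_context=10)
batches = batched(pairs, batch_size=1000)
data, target = next(batches)
print(data.shape, target.shape)  # (1000,) (1000,)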

EmbeddingModel.py

import tensorflow as tf
import numpy as np
from helpers import lazy_property

class EmbeddingModel:
    def __init__(self, data, target, params):
        self.data = data
        self.target = target
        self.params = params
        self.embeddings
        self.cost
        self.optimize

    @lazy_property
    def embeddings(self):
        initial = tf.random_uniform(
            [self.params.vocabulary_size, self.params.embedding_size],
            -1.0, 1.0)
        return tf.Variable(initial)

    @lazy_property
    def optimize(self):
        optimizer = tf.train.MomentumOptimizer(
            self.params.learning_rate, self.params.momentum)
        return optimizer.minimize(self.cost)

    @lazy_property
    def cost(self):
        embedded = tf.nn.embedding_lookup(self.embeddings, self.data)
        weight = tf.Variable(tf.truncated_normal(
            [self.params.vocabulary_size, self.params.embedding_size],
            stddev=1.0 / self.params.embedding_size ** 0.5))
        bias = tf.Variable(tf.zeros([self.params.vocabulary_size]))
        target = tf.expand_dims(self.target, 1)
        # Keyword arguments guard against the argument-order change between
        # TensorFlow releases (1.x expects labels before inputs).
        return tf.reduce_mean(tf.nn.nce_loss(
            weights=weight, biases=bias, labels=target, inputs=embedded,
            num_sampled=self.params.contrastive_examples,
            num_classes=self.params.vocabulary_size))
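The post stops at the model class, so to close the loop here is a hedged sketch of a training script that ties the pieces above together. The hyperparameter values, the cache directory, and the use of types.SimpleNamespace as a stand-in for the book's parameter object are assumptions made for this example, not settings from the original.

train_embeddings.py (sketch)

import collections
import types

import numpy as np
import tensorflow as tf

from batch import batched
from EmbeddingModel import EmbeddingModel
from skipgrams import skipgrams
from Wikipedia import Wikipedia

# Example hyperparameters; not tuned values from the book.
params = types.SimpleNamespace(
    vocabulary_size=10000,
    embedding_size=200,
    max_context=10,
    learning_rate=0.5,
    momentum=0.5,
    batch_size=1000,
    contrastive_examples=100)

data = tf.placeholder(tf.int32, [None])
target = tf.placeholder(tf.int64, [None])
model = EmbeddingModel(data, target, params)

corpus = Wikipedia(
    'https://dumps.wikimedia.org/enwiki/20160501/'
    'enwiki-20160501-pages-meta-current.xml.bz2',
    '~/wikipedia',  # example cache directory
    params.vocabulary_size)
examples = skipgrams(corpus, params.max_context)
batches = batched(examples, params.batch_size)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    average = collections.deque(maxlen=100)
    # One pass over the corpus; the batch generator stops when the
    # skip-gram pair stream is exhausted.
    for index, (current, context) in enumerate(batches):
        feed_dict = {data: current, target: context}
        cost, _ = sess.run([model.cost, model.optimize], feed_dict)
        average.append(cost)
        if index % 100 == 0:
            print('{}: {:5.1f}'.format(index, sum(average) / len(average)))
    # Save the learned embedding matrix for later use.
    embeddings = sess.run(model.embeddings)
    np.save('embeddings.npy', embeddings)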

