如何在Python中快速進行語料庫搜尋:近似最近鄰演算法

路雪發表於2018-01-24

最近,我一直在研究在 GloVe 詞嵌入中做加減法。例如,我們可以把「king」的詞嵌入向量減去「man」的詞嵌入向量,隨後加入「woman」的詞嵌入得到一個結果向量。隨後,如果我們有這些詞嵌入對應的語料庫,那麼我們可以透過搜尋找到最相似的嵌入並檢索相應的詞。如果我們做了這樣的查詢,我們會得到:


  1. King + (Woman - Man) = Queen


我們有很多方法來搜尋語料庫中詞嵌入對作為最近鄰查詢方式。絕對可以確保找到最優向量的方式是遍歷你的語料庫,比較每個對與查詢需求的相似程度——這當然是耗費時間且不推薦的。一個更好的技術是使用向量化餘弦距離方式,如下所示:

  1. vectors = np.array(embeddingmodel.embeddings)

  2. ranks = np.dot(query,vectors.T)/np.sqrt(np.sum(vectors**2,1))

  3. mostSimilar = []

  4. [mostSimilar.append(idx) for idx in ranks.argsort()[::-1]]



想要了解餘弦距離,可以看看這篇文章:http://masongallo.github.io/machine/learning,/python/2016/07/29/cosine-similarity.html


向量化的餘弦距離比迭代法快得多,但速度可能太慢。是近似最近鄰搜尋演算法該出現時候了:它可以快速返回近似結果。很多時候你並不需要準確的最佳結果,例如:「Queen」這個單詞的同義詞是什麼?在這種情況下,你只需要快速得到足夠好的結果,你需要使用近似最近鄰搜尋演算法。


在本文中,我們將會介紹一個簡單的 Python 指令碼來快速找到近似最近鄰。我們會使用的 Python 庫是 Annoy 和 Imdb。對於我的語料庫,我會使用詞嵌入對,但該說明實際上適用於任何型別的嵌入:如音樂推薦引擎需要用到的歌曲嵌入,甚至以圖搜圖中的圖片嵌入。


製作一個索引


讓我們建立一個名為:「make_annoy_index」的 Python 指令碼。首先我們需要加入用得到的依賴項:


  1. '''

  2. Usage: python2 make_annoy_index.py \

  3. --embeddings=<embedding path> \

  4. --num_trees=<int> \

  5. --verbose

  6. Generate an Annoy index and lmdb map given an embedding file

  7. Embedding file can be

  8. 1. A .bin file that is compatible with word2vec binary formats.

  9. There are pre-trained vectors to download at https://code.google.com/p/word2vec/

  10. 2. A .gz file with the GloVe format (item then a list of floats in plaintext)

  11. 3. A plain text file with the same format as above

  12. '''

  13. import annoy

  14. import lmdb

  15. import os

  16. import sys

  17. import argparse

  18. from vector_utils import get_vectors



最後一行裡非常重要的是「vector_utils」。稍後我們會寫「vector_utils」,所以不必擔心。


接下來,讓我們豐富這個指令碼:加入「creat_index」函式。這裡我們將生成 lmdb 圖和 Annoy 索引。


1. 首先需要找到嵌入的長度,它會被用來做例項化 Annoy 的索引。

2. 接下來例項化一個 Imdb 圖,使用:「env = lmdb.open(fn_lmdb, map_size=int(1e9))」。

3. 確保我們在當前路徑中沒有 Annoy 索引或 lmdb 圖。

4. 將嵌入檔案中的每一個 key 和向量新增至 lmdb 圖和 Annoy 索引。

5. 構建和儲存 Annoy 索引。


  1. '''

  2. function create_index(fn, num_trees=30, verbose=False)

  3. -------------------------------

  4. Creates an Annoy index and lmdb map given an embedding file fn

  5. Input:

  6. fn - filename of the embedding file

  7. num_trees - number of trees to build Annoy index with

  8. verbose - log status

  9. Return:

  10. Void

  11. '''

  12. def create_index(fn, num_trees=30, verbose=False):

  13. fn_annoy = fn + '.annoy'

  14. fn_lmdb = fn + '.lmdb' # stores word <-> id mapping

  15. word, vec = get_vectors(fn).next()

  16. size = len(vec)

  17. if verbose:

  18. print("Vector size: {}".format(size))

  19. env = lmdb.open(fn_lmdb, map_size=int(1e9))

  20. if not os.path.exists(fn_annoy) or not os.path.exists(fn_lmdb):

  21. i = 0

  22. a = annoy.AnnoyIndex(size)

  23. with env.begin(write=True) as txn:

  24. for word, vec in get_vectors(fn):

  25. a.add_item(i, vec)

  26. id = 'i%d' % i

  27. word = 'w' + word

  28. txn.put(id, word)

  29. txn.put(word, id)

  30. i += 1

  31. if verbose:

  32. if i % 1000 == 0:

  33. print(i, '...')

  34. if verbose:

  35. print("Starting to build")

  36. a.build(num_trees)

  37. if verbose:

  38. print("Finished building")

  39. a.save(fn_annoy)

  40. if verbose:

  41. print("Annoy index saved to: {}".format(fn_annoy))

  42. print("lmdb map saved to: {}".format(fn_lmdb))

  43. else:

  44. print("Annoy index and lmdb map already in path")



我已經推斷出 argparse,因此,我們可以利用命令列啟用我們的指令碼:


  1. '''

  2. private function _create_args()

  3. -------------------------------

  4. Creates an argeparse object for CLI for create_index() function

  5. Input:

  6. Void

  7. Return:

  8. args object with required arguments for threshold_image() function

  9. '''

  10. def _create_args():

  11. parser = argparse.ArgumentParser()

  12. parser.add_argument("--embeddings", help="filename of the embeddings", type=str)

  13. parser.add_argument("--num_trees", help="number of trees to build index with", type=int)

  14. parser.add_argument("--verbose", help="print logging", action="store_true")

  15. args = parser.parse_args()

  16. return args



新增主函式以啟用指令碼,得到 make_annoy_index.py:

  1. if __name__ == '__main__':

  2. args = _create_args()

  3. create_index(args.embeddings, num_trees=args.num_trees, verbose=args.verbose)



現在我們可以僅利用命令列啟用新指令碼,以生成 Annoy 索引和對應的 lmdb 圖!

  1. python2 make_annoy_index.py \

  2. --embeddings=<embedding path> \

  3. --num_trees=<int> \

  4. --verbose



寫向量Utils


我們在 make_annoy_index.py 中推匯出 Python 指令碼 vector_utils。現在要寫該指令碼,Vector_utils 用於幫助讀取.txt, .bin 和 .pkl 檔案中的向量。


寫該指令碼與我們現在在做的不那麼相關,因此我已經推匯出整個指令碼,如下:

  1. '''

  2. Vector Utils

  3. Utils to read in vectors from txt, .bin, or .pkl.

  4. Taken from Erik Bernhardsson

  5. Source: https://github.com/erikbern/ann-presentation/blob/master/util.py

  6. '''

  7. import gzip

  8. import struct

  9. import cPickle

  10. def _get_vectors(fn):

  11. if fn.endswith('.gz'):

  12. f = gzip.open(fn)

  13. fn = fn[:-3]

  14. else:

  15. f = open(fn)

  16. if fn.endswith('.bin'): # word2vec format

  17. words, size = (int(x) for x in f.readline().strip().split())

  18. t = 'f' * size

  19. while True:

  20. pos = f.tell()

  21. buf = f.read(1024)

  22. if buf == '' or buf == '\n': return

  23. i = buf.index(' ')

  24. word = buf[:i]

  25. f.seek(pos + i + 1)

  26. vec = struct.unpack(t, f.read(4 * size))

  27. yield word.lower(), vec

  28. elif fn.endswith('.txt'): # Assume simple text format

  29. for line in f:

  30. items = line.strip().split()

  31. yield items[0], [float(x) for x in items[1:]]

  32. elif fn.endswith('.pkl'): # Assume pickle (MNIST)

  33. i = 0

  34. for pics, labels in cPickle.load(f):

  35. for pic in pics:

  36. yield i, pic

  37. i += 1

  38. def get_vectors(fn, n=float('inf')):

  39. i = 0

  40. for line in _get_vectors(fn):

  41. yield line

  42. i += 1

  43. if i >= n:

  44. break



測試 Annoy 索引和 lmdb 圖


我們已經生成了 Annoy 索引和 lmdb 圖,現在我們來寫一個指令碼使用它們進行推斷。


將我們的檔案命名為 annoy_inference.py,得到下列依賴項:


  1. '''

  2. Usage: python2 annoy_inference.py \

  3. --token='hello' \

  4. --num_results=<int> \

  5. --verbose

  6. Query an Annoy index to find approximate nearest neighbors

  7. '''

  8. import annoy

  9. import lmdb

  10. import argparse



現在我們需要在 Annoy 索引和 lmdb 圖中載入依賴項,我們將進行全域性載入,以方便訪問。注意,這裡設定的 VEC_LENGTH 為 50。確保你的 VEC_LENGTH 與嵌入長度匹配,否則 Annoy 會不開心的哦~


  1. VEC_LENGTH = 50

  2. FN_ANNOY = 'glove.6B.50d.txt.annoy'

  3. FN_LMDB = 'glove.6B.50d.txt.lmdb'

  4. a = annoy.AnnoyIndex(VEC_LENGTH)

  5. a.load(FN_ANNOY)

  6. env = lmdb.open(FN_LMDB, map_size=int(1e9))



有趣的部分在於「calculate」函式。


1. 從 lmdb 圖中獲取查詢索引;

2. 用 get_item_vector(id) 獲取 Annoy 對應的向量;

3. 用 a.get_nns_by_vector(v, num_results) 獲取 Annoy 的最近鄰。


  1. '''

  2. private function calculate(query, num_results)

  3. -------------------------------

  4. Queries a given Annoy index and lmdb map for num_results nearest neighbors

  5. Input:

  6. query - query to be searched

  7. num_results - the number of results

  8. Return:

  9. ret_keys - list of num_results nearest neighbors keys

  10. '''

  11. def calculate(query, num_results, verbose=False):

  12. ret_keys = []

  13. with env.begin() as txn:

  14. id = int(txn.get('w' + query)[1:])

  15. if verbose:

  16. print("Query: {}, with id: {}".format(query, id))

  17. v = a.get_item_vector(id)

  18. for id in a.get_nns_by_vector(v, num_results):

  19. key = txn.get('i%d' % id)[1:]

  20. ret_keys.append(key)

  21. if verbose:

  22. print("Found: {} results".format(len(ret_keys)))

  23. return ret_keys



再次,這裡使用 argparse 來使讀取命令列引數更加簡單。


  1. '''

  2. private function _create_args()

  3. -------------------------------

  4. Creates an argeparse object for CLI for calculate() function

  5. Input:

  6. Void

  7. Return:

  8. args object with required arguments for threshold_image() function

  9. '''

  10. def _create_args():

  11. parser = argparse.ArgumentParser()

  12. parser.add_argument("--token", help="query word", type=str)

  13. parser.add_argument("--num_results", help="number of results to return", type=int)

  14. parser.add_argument("--verbose", help="print logging", action="store_true")

  15. args = parser.parse_args()

  16. return args



主函式從命令列中啟用 annoy_inference.py。


  1. if __name__ == '__main__':

  2. args = _create_args()

  3. print(calculate(args.token, args.num_results, args.verbose))


現在我們可以使用 Annoy 索引和 lmdb 圖,獲取查詢的最近鄰!


  1. python2 annoy_inference.py --token="test" --num_results=30

  2. ['test', 'tests', 'determine', 'for', 'crucial', 'only', 'preparation', 'needed', 'positive', 'guided', 'time', 'performance', 'one', 'fitness', 'replacement', 'stages', 'made', 'both', 'accuracy', 'deliver', 'put', 'standardized', 'best', 'discovery', '.', 'a', 'diagnostic', 'delayed', 'while', 'side']


程式碼


本教程所有程式碼的 GitHub 地址:https://github.com/kyang6/annoy_tutorial

相關文章