基於深度學習技術的AI輸入法引擎

weixin_34208283發表於2018-01-29
本文由 「AI前線」原創,原文連結:基於深度學習技術的AI輸入法引擎
作者| 姚從磊
編輯|Emily、Debra

AI 前線導讀:”目前,幾乎所有的輸入法基本能實現在使用者輸入第一個字後預測使用者接下來輸入的文字,並進行推薦。AI 在輸入法中的應用,能夠通過大量的語言、語義的學習,瞭解人的用語習慣,甚至是性格和思維。

當輸入法可以更加準確地瞭解使用者之後,機器與人類之間的交流的以加深,人機互動的體驗得以提升,進而提高使用者粘度。輸入法引擎 AI 化已成為輸入法產品發展的趨勢。

Kika 早在 2014 年就推出了首款輸入法產品 Kika Keyboard,在海外進入了 140 多個國家,支援 173 種語言,全球使用者數量達 4 億,月活 6000 多萬,這與此產品背後的輸入法引擎有著密不可分的聯絡。


以下內容整理自 AI 前線 2018 年 1 月 25 日社群分享

各位好,我是姚從磊,非常高興能夠有這樣一個跟大家交流的機會。今天主要想為大家介紹一下手機輸入法最核心的模組 - 輸入法引擎的技術方案,為什麼以及如何從傳統 N-gram 引擎演化到深度神經網路引擎的。

主要的內容分為五個部分:

  • 什麼是輸入法引擎;
  • 基於傳統 N-gram 語言模型的輸入法引擎;
  • 為什麼要轉向深度神經網路引擎;
  • 深度神經網路輸入法引擎的那些坑;
  • 高階預測功能。

先用一張圖介紹一下我們公司的情況。

基於深度學習技術的AI輸入法引擎


作為一家面向全球使用者提供 173 種語言輸入法的公司,Kika 利用 AI 技術,為使用者提供了一流的輸入體驗,也在全球獲得了大量的使用者。


基於深度學習技術的AI輸入法引擎


這張圖中列出了目前全球輸入法市場上使用者量較大的產品,背後的公司既包括 Kika、百度、搜狗、Go 以及觸寶這樣國內的公司,也包括 Google(產品為 GBoard)、微軟 (Swiftkey) 等國外大公司。大家都在輸入法引擎的核心技術上投入大量研發精力,期望為全球各國使用者提供一流的輸入體驗。

什麼是輸入法引擎

輸入法 (Input Method,簡稱 IME) 是最常用的工具軟體之一,也常被稱為 Keyboard、鍵盤等。對每種語言,輸入法會提供一個字母佈局 (Layout),上面按照使用者習慣將對應語言的基礎字母放置在合適的位置,比如英文鍵盤的 QWERT、漢語鍵盤的九宮格等。使用者輸入文字其實就是按照順序來敲擊 Layout 上的字母,字母敲擊序列稱為鍵碼序列;在使用者敲擊字母的過程中,鍵碼序列以及之前使用者輸入的詞會被傳入 Layout 下層的「輸入法引擎」,引擎會根據從大規模資料中訓練得到的語言模型,來預測使用者當前以及接下來可能輸入的詞 / 詞序列,並將最可能輸入的詞 / 詞序列在鍵盤的候選區上展示給使用者,供使用者選擇。

例如,如果一位使用者期望輸入的完整文字內容為「What’s the weather today?」,當前輸入到了「weather」的第三個字母「a」,此時詞序列「What’s the 」和鍵碼序列「W h a t ’ s SPACE t h e SAPCE w e a」(SPACE 表示空格) 作為輸入傳送至輸入法引擎,引擎基於訓練好的語言模型進行預測,並將最有可能的候選詞「weather」、「weapon」等展示給使用者,供使用者選擇。在這個 case 中,如果「weather」排在第一位,則可以認為引擎是合格的,可以打 60 分。如果僅輸入到「weather」的第一個字母「w」,就可以將「weather」排在第一位,則可以打 70 分。如果在輸入到「weather」的第一個字母「w」後,就可以直接預測使用者接下來要輸入的詞序列為「weather today?」,那就會更好,可以認為是 90 分。

總的來講,輸入法引擎的功能可以細分為「糾錯」、「補全」和「預測」三類。

  • 所謂糾錯,指的是在使用者輸入一個錯誤的詞,比如「westher」,會自動建議改為「weather」;
  • 所謂補全,指的是輸入一個詞的一部分即預測整體,比如「w」預測「weather」;
  • 所謂「預測」,指的是使用者沒有輸入任何字母時直接預測使用者接下來會輸入什麼,比如輸入「What’s the 」,預測出使用者會輸入「weather today?」。

同時,在拉丁等語系的輸入法中,會提供滑行輸入的功能。

使用者在鍵盤上快速滑行詞的字母序列,即便滑行軌跡有所偏差(因為滑行速度很快,使用者較難準確定位各個字母的位置),也可以準確預測使用者所想輸入的詞。在滑行輸入中,引擎的輸入是滑行點的軌跡,輸出是預測的詞。在本文中,我們不會深入探討滑行輸入的引擎實現邏輯。

更進一步,隨著使用者越來越多地傾向於利用 Emoji、表情圖片等非文字內容表達自己的情感,引擎也需要能夠根據使用者輸入詞 / 鍵碼序列來預測 Emoji 或表情圖片。而 Emoji 往往具有多義性(表情圖片也類似),此類預測的複雜度會更高,我們已經利用基於深度學習的建模技術較好地解決了這一問題。本文不會深入探討,有興趣的小夥伴可以單獨談討。

本文主要討論在手機的按鍵文字輸入場景下,輸入法引擎高效準確地預測的相關技術。

此類技術的演化可以分兩個階段:

1)N-gram 統計語言模型階段;

2)深度神經網路語言模型階段。

前者主要基於大規模語料進行統計,獲取一個詞在 N-1 個片語成的序列 (N-gram) 之後緊鄰出現的條件概率;但由於手機記憶體和 CPU 的限制,僅能對 N 較小 (N <= 3) 的 N-gram 進行計算,預測效果存在明顯天花板。後者通過構建深度神經網路,利用大規模語料資料集進行訓練,不僅可以突破 N-gram 中 N 的限制,且可利用詞與詞的語義關係,準確預測在訓練語料中未出現的詞序列,達到遠超統計語言模型的預測效果。

基於傳統 N-gram 語言模型的輸入法引擎

在輸入「What’s the weather today?」的這個 case 中,當使用者輸入到「weather」的第一個字母「w」時,引擎要做的事情,就是根據前面輸入的詞序列「What’s the」來預測下一個最可能的以「w」打頭的單詞,而其中最關鍵的就是如何預測下一個最可能的詞。

假設輸入詞序列為 w1,⋯,wN-1,預測下一個詞的問題實際上變成了 argmaxWNP(WN |w1,⋯,wN-1 ),這個簡單的模型稱為輸入法引擎的語言模型。

根據條件概率計算公式,P(WN |w1,⋯,wN-1 )=P(w1,⋯,wN-1, WN )/P(w1,⋯,wN-1, WN ),,根據最大似然估計原則,只有在語料資料規模足夠大以至於具備統計意義時,上述概率計算才會具有意義。

但事實上,如果 N 值過大,並不存在「足夠大」的語料資料可以支撐所有概率值的計算;並且,由於 WN 實際上僅同 w1,⋯,wN-1 中的部分詞相關,上述計算會造成大量的計算資源浪費。

因此,實際計算中,一種方式是引入馬爾科夫假設:當前詞出現的概率只與它前面有限的幾個詞有關,來簡化計算。如果當前詞出現的概率只與它前面的 N-1 個詞相關,我們就稱得到的語言模型為 N-gram 模型。常用的 N-gram 模型有 Unigram (N=1),Bigram (N=2),Trigram (N=3)。顯然,隨著 N 的增大,語言模型的資訊量會指數級增加。

為了得到有效的 N-gram 語言模型,一方面需要確保語料資料規模足夠大且有統計意義,另一方面也需要處理「資料稀疏」問題。所謂資料稀疏,指的是詞序列 w1,⋯,wN 並沒有在語料資料中出現,所以導致條件概率 P(WN |w1,⋯,wN-1 ) 為 0 的情況出現。這顯然是不合理的,如果資料規模繼續擴大,這些詞序列可能就會出現。我們可以引入平滑技術來解決資料稀疏問題。平滑技術通過把在訓練語料集中出現過的 N-gram 概率適當減小,而把未出現的概率適當增大,使得全部的 N-gram 概率之和為 1,且全部的 N-gram 概率都不為 0。經典的平滑演算法有很多種,個人推薦 Laplace 平滑和 Good-Turing 平滑技術。

在利用 N-gram 語言模型完成下一個詞的預測後,還需要根據使用者的按鍵序列來對預測的結果進行調整,可以利用編輯距離等衡量序列相似度的方法,將按鍵序列同預測詞的字母序列進行對比,細節不再贅述。

利用 N-gram 語言模型構建的輸入法引擎,在手機端執行時,存在著如下問題:

  1. 不能充分利用詞序列資訊進行預測:受制於手機有限的 CPU 和記憶體資源,N-gram 中的 N 通常都不能太大,基本上 N 為 3 已經是極限。這意味著只能根據最近的 1 到 2 個詞來進行預測,會丟失大量的關鍵資訊;
  2. 不能準確預測語料資料集中未出現的單詞序列。例如,如果在語料資料中出現過「go to work」,而沒有出現過「go to school」。即使使用者輸入「A parents guide to go to s」,引擎也不能準確地將「school」排在候選區靠前的位置。

上述問題,利用深度神經網路技術可以很好地解決。

為什麼要轉向深度神經網路引擎

深度神經網路 (Neep Neural Networks, DNN) 是一種具備至少一個隱層的神經網路,通過調整神經元的連線方式以及網路的層數,可以提供任意複雜度的非線性模型建模能力。基於強大的非線性建模能力,深度神經網路已經在影象識別、語音識別、機器翻譯等領域取得了突破性的進展,並正在自然語言處理、內容推薦等領域得到廣泛的應用。

典型的深度神經網路技術有卷積神經網路 (Convolutional Neural Networks, CNN) 、遞迴神經網路 (Recurrent Neural Network,RNN) 、生成對抗網路 (Generative Adversarial Nets, GAN) 等,分別適用於不同的應用場景。其中,RNN(如圖 1 所示)特別適合序列到序列的預測場

基於深度學習技術的AI輸入法引擎

圖1:RNN網路結構

傳統的神經網路中層與層之間是全連線的,但層間的神經元是沒有連線的(其實是假設各個資料之間是獨立的),這種結構不善於處理序列化預測的問題。在輸入法引擎的場景中,下一個詞往往與前面的詞序列是密切相關的。RNN 通過新增跨越時序的自連線隱藏層,對序列關係進行建模;也就是說,前一個狀態隱藏層的反饋,不僅僅作為本狀態的輸出,而且還進入下一狀態隱層中作為輸入,這樣的網路可以打破獨立假設,得以刻畫序列相關性。

RNN 的優點是可以考慮足夠長的輸入詞序列資訊,每一個輸入詞狀態的資訊可以作為下一個狀態的輸入發揮作用,但這些資訊不一定都是有用的,需要過濾以準確使用。為了實現這個目標,我們使用長短期記憶網路 (Long-Short Memory Networks, LSTM) 對資料進行建模,以實現更準確的預測。

LSTM(圖 2)是一種特殊的 RNN,能夠有選擇性地學習長期的依賴關係。 LSTM 也具有 RNN 鏈結構,但具有不同的網路結構。在 LSTM 中,每個單元都有三個門(輸入門,輸出門和遺忘門)來控制哪部分資訊應該被考慮進行預測。利用 LSTM,不僅可以考慮更長的輸入序列,並且可以利用三種門的引數訓練來自動學習篩選出真正對於預測有價值的輸入詞,而非同等對待整個序列中所有的詞。

基於深度學習技術的AI輸入法引擎

圖 2 LSTM 網路結構

並且,可以在 LSTM 的網路結構中新增嵌入層 (Word Embedding Layer) 來將詞與詞的語義關係加入到訓練和預測過程中。通過 Word Embedding,雖然在語料資料中沒有出現過「go to school」,但是因為「go to work」出現在語料庫中,而通過 Word Embedding 可以發現「work」和「school」具有強烈的語義關聯;這樣,當使用者輸入「A parents guide to go to s」時,引擎會根據「work」和「school」的語義關聯,以及 LSTM 中學習到的「parents」同「school」間存在的預測關係,而準確地向使用者推薦「school」,而非「swimming」。

深度神經網路輸入法引擎的那些坑

從理論上來講,LSTM 可以完美解決 N-gram 語言模型的問題:不僅能夠充分利用詞序列資訊進行預測,還可以準確預測語料資料集中未出現的單詞序列。但是,在實際利用 LSTM 技術實現在手機上可以準確流暢執行的輸入法引擎時,在雲端和客戶端都存在一些坑需要解決。

在雲端,有兩個問題需要重點解決:

充分利用詞序列和鍵碼序列資訊。如前所述,在輸入法引擎的預測過程中,LSTM 的輸入包含詞序列和鍵碼序列兩類不同的序列資訊,需要設計一個完備的 LSTM 網路可以充分利用這兩類資訊。對此,我們經過若干實驗,最終設計了圖 3 所示的兩階段網路結構。在第一個階段,詞序列資訊被充分利用,然後將最後一個詞對應隱層的輸出作為下一個階段的輸入,並和鍵碼序列一起來進行計算,最後通過 Softmax 計算來生成最終結果。同時,在兩階段間加入「Start Flag」,以區隔詞序列和鍵碼序列。

基於深度學習技術的AI輸入法引擎

圖 3 詞序列 / 鍵碼序列混合

高質量訓練資料的生成:在訓練 LSTM 語言模型時,訓練語料的質量和覆蓋度是關鍵因素。從質量角度來講,必須確保其中沒有亂碼、其他語言以及過短的句子等資料。從覆蓋度角度來講,一方面要確保訓練語料的規模,使其能夠覆蓋語言中的大部分詞彙,並足以支援語言模型的統計有效性,一般來講訓練語料的量級應該在千萬或者億;覆蓋度的另一個角度為文字型別,需要確保訓練語料中文字型別 (比如新聞、聊天、搜尋等) 的分佈同目標應用場景一致,對於手機輸入法來講,日常聊天型別的資料應該佔足夠大的比例;覆蓋度的第三個維度為時間維度,需要確保訓練語料可以覆蓋對應國家 / 語言固定時間週期 (通常為年) 中各個時間段的資料,尤其是大型節日的資料。

在客戶端,效能和記憶體是必須解決的關鍵問題。一個優秀的輸入法引擎,在手機端執行時,需要始終穩定地保持低記憶體佔用,確保在 Android Oreo (Go edition) 系統上也可以穩定執行,且保持良好的效能 (每次按鍵響應時間小於 60ms)。而 LSTM 原始模型通常較大 (例如美式英語的模型超過 1G),在手機端執行時響應時間也遠超 1s,需進行大幅優化。可以利用稀疏表示與學習的技術,來壓縮圖 3 LSTM 網路中的 word/ch embedding 矩陣及輸出端 softmax 向量矩陣,同時基於 Kmeans 聚類對模型引數進行自適應量化學習,最終可以將超過 1G 的模型量化壓縮到小於 5M。效能優化則意味著需要控制手機端的計算量,需要在保證效果的前提下優化模型結構,減少不必要的層數和神經元;同時,可基於 TensorFlow Lite (而非 TensorFlow Mobile) 進行手機端計算模組的開發,大幅提升效能和記憶體佔用,唯一的成本是需要自己實現一些必備的 operators。我們採用該方案可以將執行時記憶體佔用控制在 25M 以內,且響應時間保持在 20ms。圖 4 是 TensorFlow Mobile 和 TensorFlow Lite 在相同 benchmark 上的對比資料。

基於深度學習技術的AI輸入法引擎

圖 4 TensorFlow Mobile 和 TensorFlow Lite 在相同 Benchmark 上的對比資料

基於以上的雲端建模和客戶端預測技術,我們完成了基於深度神經網路 (LSTM) 的輸入法引擎方案的整體部署,並在大量語言上,同基於 N-gram 的語言模型進行了對比測試。在對比測試中,我們關注的關鍵指標為輸入效率 (Input Efficiency):

輸入效率 = # 輸入的文字長度 / # 完成文字輸入所需的按鍵次數

我們期望輸入效率越高越好;同時,我們也會關注每種語言對應的線上使用者的回刪率,在此不再贅述。

下圖是在一些語言上 LSTM 引擎相對 N-gram 引擎輸入效率的提升幅度。


基於深度學習技術的AI輸入法引擎

圖 5 LSTM 引擎相對 N-gram 引擎輸入效率的提升幅度

高階預測功能

高階預測功能在第一部分提到過,輸入法引擎也需要能夠準確預測使用者可能輸入的 Emoji 或表情圖片,而這些內容往往具有多義性,因此此類預測的複雜度會更高。

同時,對於 Emoji 來講,使用者往往會創造出一些有趣的 Emoji 組合,例如「?❤️?」,如何自動挖掘出這樣的 Emoji 組合,並將其整合進 LSTM 的模型框架中,也是一個很有趣的問題。

另一方面,從文字輸入效率的角度看,如果每次預測能夠準確預測不只一個詞,而是片語,對使用者的體驗會是一個很大的提升。如何發現有意義的儘可能長的片語,並且整合進 LSTM 的模型框架中,也會是一個很有挑戰的工作。

Q/A 環節

Q1: 請問姚老師,這裡 softmax 輸出是?

基於深度學習技術的AI輸入法引擎

A1: 因為這個網路的目的是預測下一個詞,所以 softmax 的輸出是預測詞的 id 和概率值。在實際產品中,我們會選擇 Top 3 的預測詞,按照概率值從高到底顯示在候選區。

Q2: 輸入法本質上就是根據使用者前面的輸入預測他接下來要輸入什麼。訓練資料集和預測模型是在雲端完成然後定期更新到手機端的嗎?還是完全在手機上端完成的?

A2: 我們的方案是包含兩個部分。對每種語言,我們會在雲端迭代訓練一個新 general language model,在新 model 效果得到離線評測驗證後,下發到手機端。並且,在手機端,會根據每個使用者的個人輸入歷史,來訓練 personalized model (這個 model 的訓練頻率會更高)。在實際預測時,會將這兩個 model 的 inference 結果 merge 起來得到最終結果。在手機端的訓練,需要尤其注意訓練的時機,不能在使用者手機負載高的時候執行訓練。

Q3:TF Lite 對手機有要求嗎?

A3:TF Lite 對手機沒有要求。但是 TF Lite 為了效能的考慮,砍掉了很多的 operators,我們在實現的過程中實現了自己模型 inference 必需的 operators.

Q4: 接上 Q1 的問題,針對英文情況,假如詞表為 1w,softmax 層 1w 個節點,怎麼優化 softmax 層呢?

A4:softmax 層的壓縮,本質上就是 softmax 向量矩陣的壓縮,其原理就是將巨大的向量矩陣轉換為少量的過完備基向量組合,而過完備基向量可以自動學習獲得。

Q5: 請問:每訓練一次模型 LSTM 需要多少個 cell 這個是由什麼決定的?

A5: 每訓練一次模型 LSTM 需要多少個 cell,決定因素大體有兩類:1)我們可以接受的模型的複雜度,這直接決定了最終量化壓縮後的模型的大小;2)我們期望達到的效果。最終的決定主要是在這兩類之間進行平衡。當然,也同語言本身的複雜度有關,比如德語同英語對比,會更加複雜,因此 cell 的數量多一些會更好。如果不考慮這個限制,我們可以通過雲端 service 的方式來進行 inference.

Q6:未來輸入法會支援語音輸入嗎?

A6: 我們正在開發 Kika 的語音識別和語義理解引擎,目前在英文上的語音識別水平接近 Google 的水平,所以會逐步上線 Kika 的語音輸入功能。同時,我們基於 Kika 的一系列語音技術,已經在 CES 2018 釋出了 KikaGO 車載語音解決方案,獲得了很多好評和 CES 的四項大獎,並正在準備產品的正式釋出。我們的全語音解決方案除了為車載場景下提供服務外,還會在場景上做出更多的嘗試。

Q7:”可以在 LSTM 的網路結構中新增嵌入層 (Word Embedding Layer) 來將詞與詞的語義關係加入到訓練和預測過程中” 能具體解釋下,在實際的資料預處理中,是如何新增的?簡單拼接還是啥?LSTM cell 的輸入又是啥?

A7:Word Embedding Layer 的作用是將高維的詞空間對映到低維的向量空間,確保在低維空間上語義相似的詞的向量距離比較小。Word Embedding Layer 的輸出作為 LSTM 的輸入。TensorFlow 本身自帶 Word Embedding Layer,其實就是一個簡單的 lookup table。但如果所處理的問題領域不是 general domain,建議用所在 domain 的語料資料利用 Word2vector 來訓練得到對應的 domain specific word embedding,用來替換掉 TensorFlow 自帶的 Word Embedding Layer。

Q8: 輸入法的研究應該是很有挑戰的領域,尤其是人類語言太多了。能談談這方面的發展趨勢嗎?

A8:「趨勢」都是大拿才可以談的事情,我只談一些自己的淺薄想法。輸入法本質為了解決人類通過機器 (手機、電腦、智慧家居等) 和網路達成的人與人的溝通問題。而這樣的溝通問題,最重要的是能夠達到或超越人與人在現實世界中利用語音、表情和肢體動作面對面溝通的效果,能夠全面、準確、快速地達到意圖、資訊和情感的溝通。

因此,我們認為輸入法的下一步發展一定是圍繞著「全」、「準」、「快」三個方面進行的。「全」是指能夠提供文字、語音、表情、多媒體內容等的溝通方式,「準」指的是使得接收方能夠準確接收到表達方的意圖、資訊和情感,不會產生誤解,「快」指的是表達放產生到接收方接收的時間足夠短。

而「全」、「準」、「快」這三方面的使用者體驗,通過 AI 的技術,都可以大幅提升。

Q9: 我還想提問這樣的 ai 輸入法和百度輸入法怎麼競爭。謝謝。

A9: 產品之間的競爭,本質就是如何為使用者創造價值。Kika 輸入法的定位,就是解決全世界使用者在人與人溝通之間產生的問題,就是為使用者提供極致的「全」、「準」、「快」的溝通方式,為使用者創造更多的價值。從市場劃分上來看,兩款產品也不在一個維度上競爭—國內的輸入法巨頭像搜狗、百度主要解決的是中文,而 kika 則重點為中文之外的其他語種的使用者創造價值。

更多幹貨內容,可關注AI前線,ID:ai-front,後臺回覆「AI」、「TF」、「大資料」可獲得《AI前線》系列PDF迷你書和技能圖譜。


相關文章