基於keras的BiLstm與CRF實現命名實體標註

帥蟲哥發表於2018-03-26

眾所周知,通過Bilstm已經可以實現分詞或命名實體標註了,同樣地單獨的CRF也可以很好的實現。既然LSTM都已經可以預測了,為啥要搞一個LSTM+CRF的hybrid model? 因為單獨LSTM預測出來的標註可能會出現(I-Organization->I-Person,B-Organization ->I-Person)這樣的問題序列。

但這種錯誤在CRF中是不存在的,因為CRF的特徵函式的存在就是為了對輸入序列觀察、學習各種特徵,這些特徵就是在限定視窗size下的各種詞之間的關係。

將CRF接在LSTM網路的輸出結果後,讓LSTM負責在CRF的特徵限定下,依照新的loss function,學習出新的模型。

基於字的模型標註:

假定我們使用Bakeoff-3評測中所採用的的BIO標註集,即B-PER、I-PER代表人名首字、人名非首字,B-ORG、I-ORG代表組織機構名首字、組織機構名非首字,O代表該字不屬於命名實體的一部分

  • B-Person
  • I- Person
  • B-Organization
  • I-Organization
  • O

加入CRF layer對LSTM網路輸出結果的影響

為直觀的看到加入後的區別我們可以借用網路中的圖來表示:其中(x)表示輸入的句子,包含5個字分別用(w_1),(w_2),(w_3),(w_4),(w_5)表示

沒有CRF layer的網路示意圖

Figure 1.3: The BiLSTM model with out CRF layer output correct labels

含有CRF layer的網路輸出示意圖

Figure 1.2: The meaning of outputs of BiLSTM layer

上圖可以看到在沒有CRF layer的情況下出現了 B-Person->I-Person 的序列,而在有CRF layer層的網路中,我們將 LSTM 的輸出再次送入CRF layer中計算新的結果。而在CRF layer中會加入一些限制,以排除可能會出現上文所提及的不合法的情況

CRF loss function

CRF loss function 如下:
Loss Function = (frac{P_{RealPath}}{P_1 + P_2 + … + P_N})

主要包括兩個部分Real path score 和 total path scroe

1、Real path score

(P_{RealPath}) =(e^{S_i})

因此重點在於求出:

(S_i) = EmissionScore + TransitionScore

EmissionScore=(x_{0,START}+x_{1,B-Person}+x_{2,I-Person}+x_{3,O}+x_{4,B-Organization}+x_{5,O}+x_{6,END})

2018-03-26-16-32-18

因此根據轉移概率和發射概率很容易求出(P_{RealPath})

2、total score

total scroe的計算相對比較複雜,可參看https://createmomo.github.io/2017/11/11/CRF-Layer-on-the-Top-of-BiLSTM-5/

實現程式碼(keras版本)

1、搭建網路模型

使用2.1.4版本的keras,在keras版本里面已經包含bilstm模型,但crf的loss function還沒有,不過可以從keras contribute中獲得,具體可參看:https://github.com/keras-team/keras-contrib

構建網路模型程式碼如下:

    model = Sequential()
    model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # Random embedding
    model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))
    crf = CRF(len(chunk_tags), sparse_target=True)
    model.add(crf)
    model.summary()
    model.compile(`adam`, loss=crf.loss_function, metrics=[crf.accuracy])

2、清洗資料

清晰資料是最麻煩的一步,首先我們採用網上開源的語料庫作為訓練和測試資料。語料庫中已經做好了標記,其格式如下:

月 O
油 O
印 O
的 O
《 O
北 B-LOC
京 I-LOC
文 O
物 O
保 O
存 O
保 O
管 O

語料庫中對每一個字分別進行標記,比較包括如下幾種:

`O`, `B-PER`, `I-PER`, `B-LOC`, `I-LOC`, "B-ORG", "I-ORG"

分別表示,其他,人名第一個,人名非第一個,位置第一個,位置非第一個,組織第一個,非組織第一個

    train = _parse_data(open(`data/train_data.data`, `rb`))
    test = _parse_data(open(`data/test_data.data`, `rb`))

    word_counts = Counter(row[0].lower() for sample in train for row in sample)
    vocab = [w for w, f in iter(word_counts.items()) if f >= 2]
    chunk_tags = [`O`, `B-PER`, `I-PER`, `B-LOC`, `I-LOC`, "B-ORG", "I-ORG"]

    # save initial config data
    with open(`model/config.pkl`, `wb`) as outp:
        pickle.dump((vocab, chunk_tags), outp)

    train = _process_data(train, vocab, chunk_tags)
    test = _process_data(test, vocab, chunk_tags)
    return train, test, (vocab, chunk_tags)

3、訓練資料

在處理好資料後可以訓練資料,本文中將batch-size=16獲得較為高的accuracy(99%左右),進行了10個epoch的訓練。

import bilsm_crf_model

EPOCHS = 10
model, (train_x, train_y), (test_x, test_y) = bilsm_crf_model.create_model()
# train model
model.fit(train_x, train_y,batch_size=16,epochs=EPOCHS, validation_data=[test_x, test_y])
model.save(`model/crf.h5`)

4、驗證資料

import bilsm_crf_model
import process_data
import numpy as np

model, (vocab, chunk_tags) = bilsm_crf_model.create_model(train=False)
predict_text = `中華人民共和國國務院總理周恩來在外交部長陳毅的陪同下,連續訪問了衣索比亞等非洲10國以及阿爾巴尼亞`
str, length = process_data.process_data(predict_text, vocab)
model.load_weights(`model/crf.h5`)
raw = model.predict(str)[0][-length:]
result = [np.argmax(row) for row in raw]
result_tags = [chunk_tags[i] for i in result]

per, loc, org = ``, ``, ``

for s, t in zip(predict_text, result_tags):
    if t in (`B-PER`, `I-PER`):
        per += ` ` + s if (t == `B-PER`) else s
    if t in (`B-ORG`, `I-ORG`):
        org += ` ` + s if (t == `B-ORG`) else s
    if t in (`B-LOC`, `I-LOC`):
        loc += ` ` + s if (t == `B-LOC`) else s

print([`person:` + per, `location:` + loc, `organzation:` + org])

輸出結果如下:

[`person: 周恩來 陳毅, 王東`, `location: 衣索比亞 非洲 阿爾巴尼亞`, `organzation: 中華人民共和國國務院 外交部`]

原始碼地址:https://github.com/stephen-v/zh-NER-keras

相關文章