如何載入本地下載下來的BERT模型,pytorch踩坑!!

ZhangHT97發表於2022-01-28

近期做實驗頻繁用到BERT,所以想著下載下來使用,結果各種問題,網上一搜也是簡單一句:xxx.from_pretrained("改為自己的路徑")
我只想說,大坑!!!
廢話不多說:

1.下載模型檔案:

不管你是從hugging-face還是哪裡下載來的模型(pytorch版)資料夾,應該包含以下三個檔案:

  • config.json
  • vocab.txt
  • pytorch_model.bin

具體都是什麼內容,不做介紹,你也不需要知道

2.更改檔名!!(坑點1)

很多下載的模型資料夾裡面上述三個檔名字可能會有不同,一定要注意!以清華OpenCLaP上下載下來的民事BERT為例,其中包含了三個檔案對應的名字為:

  • bert_config.json 看到沒有!!這個前面多了個bert_,一定要改掉!bert_config.json
  • vocab.txt
  • pytorch_model.bin

三個檔案一定要與第一步中的結構一樣,名字也必須一樣

3.將檔案放入自己的資料夾

這裡我們在自己的工程目錄裡新建一個資料夾:bert_localpath,將三個檔案放入其中,最終結構如下:

bert_localpath

config.json
vocab.txt
pytorch_model.bin

4.載入(坑點2)

使用 .from_pretrained("xxxxx")方法載入,本地載入bert需要修改兩個地方,一是tokenizer部分,二是model部分:
step1、導包: from transformers import BertModel,BertTokenizer
step2、載入詞表: tokenizer = BertTokenizer.from_pretrained("./bert_localpath/") 這裡要注意!!除了你自己建的資料夾名外,後面一定要加個/,才能保證該方法找到你的vocab.txt
step3、載入模型: bert = BertModel.from_pretrained("./bert_localpath") 然後,這個地方又不需要加上/

5.使用

至此,你就能夠使用你的本地bert了!!例如~outputs = bert(input_ids, token_type_ids, attention_mask)來獲得token的編碼輸出output

over,網上很多教程對小白很不友好,記錄一下自己的踩坑,希望能幫到你,如果覺得我寫的有問題的或者太簡單的,可以去看看其他人的

相關文章