作者|huggingface 編譯|VK 來源|Github
本節說明如何儲存和重新載入微調模型(BERT,GPT,GPT-2和Transformer-XL)。你需要儲存三種檔案型別才能重新載入經過微調的模型:
- 模型本身應該是PyTorch序列化儲存的模型(pytorch.org/docs/stable…)
- 模型的配置檔案是儲存為JSON檔案
- 詞彙表(以及基於GPT和GPT-2合併的BPE的模型)。
這些檔案的預設檔名如下:
- 模型權重檔案:
pytorch_model.bin
- 配置檔案:
config.json
- 詞彙檔案:
vocab.txt
代表BERT和Transformer-XL,vocab.json
代表GPT/GPT-2(BPE詞彙), - 代表GPT/GPT-2(BPE詞彙)額外的合併檔案:
merges.txt
。
如果使用這些預設檔名儲存模型,則可以使用from_pretrained()方法重新載入模型和tokenizer。
這是儲存模型,配置和配置檔案的推薦方法。詞彙到output_dir
目錄,然後重新載入模型和tokenizer:
from transformers import WEIGHTS_NAME, CONFIG_NAME
output_dir = "./models/"
# 步驟1:儲存一個經過微調的模型、配置和詞彙表
#如果我們有一個分散式模型,只儲存封裝的模型
#它包裝在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model
#如果使用預定義的名稱儲存,則可以使用`from_pretrained`載入
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_dir)
# 步驟2: 重新載入儲存的模型
#Bert模型示例
model = BertForQuestionAnswering.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case) # Add specific options if needed
#GPT模型示例
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)複製程式碼
如果要為每種型別的檔案使用特定路徑,則可以使用另一種方法儲存和重新載入模型:
output_model_file = "./models/my_own_model_file.bin"
output_config_file = "./models/my_own_config_file.bin"
output_vocab_file = "./models/my_own_vocab_file.bin"
# 步驟1:儲存一個經過微調的模型、配置和詞彙表
#如果我們有一個分散式模型,只儲存封裝的模型
#它包裝在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file)
# 步驟2: 重新載入儲存的模型
# 我們沒有使用預定義權重名稱、配置名稱進行儲存,無法使用`from_pretrained`進行載入。
# 下面是在這種情況下的操作方法:
#Bert模型示例
config = BertConfig.from_json_file(output_config_file)
model = BertForQuestionAnswering(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)
#GPT模型示例
config = OpenAIGPTConfig.from_json_file(output_config_file)
model = OpenAIGPTDoubleHeadsModel(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = OpenAIGPTTokenizer(output_vocab_file)複製程式碼
原文連結:huggingface.co/transformer…
歡迎關注磐創AI部落格站: panchuang.net/
OpenCV中文官方文件: woshicver.com/
歡迎關注磐創部落格資源彙總站: docs.panchuang.net/