PyTorch預訓練Bert模型

數學程式設計發表於2020-11-17

本文介紹以下內容:

  1. 使用transformers框架做預訓練的bert-base模型;
  2. 開發平臺使用Google的Colab平臺,白嫖GPU加速;
  3. 使用datasets模組下載IMDB影評資料作為訓練資料。

transformers模組簡介

transformers框架為Huggingface開源的深度學習框架,支援幾乎所有的Transformer架構的預訓練模型。使用非常的方便,本文基於此框架,嘗試一下預訓練模型的使用,簡單易用。

本來打算預訓練bert-large模型,發現colab上GPU視訊記憶體不夠用,只能使用base版本了。開啟colab,並且設定好GPU加速,接下來開始介紹程式碼。

程式碼實現

首先安裝資料下載模組和transformers包。

!pip install datasets
!pip install transformers

使用datasets下載IMDB資料,返回DatasetDict型別的資料.返回的資料是文字型別,需要進行編碼。下面會使用tokenizer進行編碼。

from datasets import load_dataset

imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 列印前3條訓練資料

接下來載入tokenizer和模型.從transformers匯入AutoModelForSequenceClassificationAutoTokenizer,建立模型和tokenizer

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

對原始資料進行編碼,並且分批次(batch)


def preprocessing_func(examples):
    return tokenizer(examples['text'], 
                     padding=True,
                     truncation=True, max_length=300)
                     
batch_size = 16

encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)

上面得到編碼資料,每個批次設定為16.接下來需要指定訓練的引數,訓練引數的指定使用transformers給出的介面類TrainingArguments,模型的訓練可以使用Trainer

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    'out',
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=5e-5,
    evaluation_strategy='epoch',
    num_train_epochs=10,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model,
    args=args,
    train_dataset=encoded_data['train'],
    eval_dataset=encoded_data['test'],
    tokenizer=tokenizer
)

訓練模型使用trainer物件的train方法

trainer.train()

截圖2020-11-15 下午7.03.55

評估模型使用trainer物件的evaluate方法

trainer.evaluate()

總結

本文介紹了基於transformers框架實現的bert預訓練模型,此框架提供了非常友好的介面,可以方便讀者嘗試各種預訓練模型。同時datasets也提供了很多資料集,便於學習NLP的各種問題。加上Google提供的colab環境,資料下載和預訓練模型下載都非常快,建議讀者自行去煉丹。本文完整的案例下載

相關文章