用 Sentence Transformers v3 訓練和微調嵌入模型

HuggingFace發表於2024-06-07

Sentence Transformers 是一個 Python 庫,用於使用和訓練各種應用的嵌入模型,例如檢索增強生成 (RAG)、語義搜尋、語義文字相似度、釋義挖掘 (paraphrase mining) 等等。其 3.0 版本的更新是該工程自建立以來最大的一次,引入了一種新的訓練方法。在這篇部落格中,我將向你展示如何使用它來微調 Sentence Transformer 模型,以提高它們在特定任務上的效能。你也可以使用這種方法從頭開始訓練新的 Sentence Transformer 模型。

現在,微調 Sentence Transformers 涉及幾個組成部分,包括資料集、損失函式、訓練引數、評估器以及新的訓練器本身。我將詳細講解每個組成部分,並提供如何使用它們來訓練有效模型的示例。

為什麼進行微調?

微調 Sentence Transformer 模型可以顯著提高它們在特定任務上的效能。這是因為每個任務都需要獨特的相似性概念。讓我們以幾個新聞文章標題為例:

  • “Apple 釋出新款 iPad”
  • “NVIDIA 正在為下一代 GPU 做準備 “

根據用例的不同,我們可能希望這些文字具有相似或不相似的嵌入。例如,一個針對新聞文章的分類模型可能會將這些文字視為相似,因為它們都屬於技術類別。另一方面,一個語義文字相似度或檢索模型應該將它們視為不相似,因為它們具有不同的含義。

訓練元件

訓練 Sentence Transformer 模型涉及以下元件:

  1. 資料集 : 用於訓練和評估的資料。
  2. 損失函式 : 一個量化模型效能並指導最佳化過程的函式。
  3. 訓練引數 (可選): 影響訓練效能和跟蹤/除錯的引數。
  4. 評估器 (可選): 一個在訓練前、中或後評估模型的工具。
  5. 訓練器 : 將模型、資料集、損失函式和其他元件整合在一起進行訓練。

現在,讓我們更詳細地瞭解這些元件。

資料集

SentenceTransformerTrainer 使用 datasets.Datasetdatasets.DatasetDict 例項進行訓練和評估。你可以從 Hugging Face 資料集中心載入資料,或使用各種格式的本地資料,如 CSV、JSON、Parquet、Arrow 或 SQL。

注意: 許多開箱即用的 Sentence Transformers 的 Hugging Face 資料集已經標記為 sentence-transformers ,你可以透過瀏覽 https://huggingface.co/datasets?other=sentence-transformers 輕鬆找到它們。我們強烈建議你瀏覽這些資料集,以找到可能對你任務有用的訓練資料集。

Hugging Face Hub 上的資料

要從 Hugging Face Hub 中的資料集載入資料,請使用 load_dataset 函式:

from datasets import load_dataset

train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev")

print(train_dataset)
"""
Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942069
})
"""

一些資料集,如 sentence-transformers/all-nli,具有多個子集,不同的資料格式。你需要指定子集名稱以及資料集名稱。

本地資料 (CSV, JSON, Parquet, Arrow, SQL)

如果你有常見檔案格式的本地資料,你也可以使用 load_dataset 輕鬆載入:

from datasets import load_dataset

dataset = load_dataset("csv", data_files="my_file.csv")
# or
dataset = load_dataset("json", data_files="my_file.json")

需要預處理的本地資料

如果你的本地資料需要預處理,你可以使用 datasets.Dataset.from_dict 用列表字典初始化你的資料集:

from datasets import Dataset

anchors = []
positives = []
# Open a file, perform preprocessing, filtering, cleaning, etc.
# and append to the lists

dataset = Dataset.from_dict({
    "anchor": anchors,
    "positive": positives,
})

字典中的每個鍵都成為結果資料集中的列。

資料集格式

確保你的資料集格式與你選擇的 損失函式 相匹配至關重要。這包括檢查兩件事:

  1. 如果你的損失函式需要 標籤 (如 損失概覽 表中所指示),你的資料集必須有一個名為“label” “score”的列。
  2. “label”“score” 之外的所有列都被視為 輸入 (如 損失概覽 表中所指示)。這些列的數量必須與你選擇的損失函式的有效輸入數量相匹配。列的名稱無關緊要, 只有它們的順序重要

例如,如果你的損失函式接受 (anchor, positive, negative) 三元組,那麼你的資料集的第一、第二和第三列分別對應於 anchorpositivenegative 。這意味著你的第一和第二列必須包含應該緊密嵌入的文字,而你的第一和第三列必須包含應該遠距離嵌入的文字。這就是為什麼根據你的損失函式,你的資料集列順序很重要的原因。
考慮一個帶有 ["text1", "text2", "label"] 列的資料集,其中 "label" 列包含浮點數相似性得分。這個資料集可以用 CoSENTLossAnglELossCosineSimilarityLoss ,因為:

  1. 資料集有一個“label”列,這是這些損失函式所必需的。
  2. 資料集有 2 個非標籤列,與這些損失函式所需的輸入數量相匹配。

如果你的資料集中的列沒有正確排序,請使用 Dataset.select_columns 來重新排序。此外,使用 Dataset.remove_columns 移除任何多餘的列 (例如, sample_idmetadatasourcetype ),因為否則它們將被視為輸入。

損失函式

損失函式衡量模型在給定資料批次上的表現,並指導最佳化過程。損失函式的選擇取決於你可用的資料和目標任務。請參閱 損失概覽 以獲取完整的選擇列表。

大多數損失函式可以使用你正在訓練的 SentenceTransformer model 來初始化:

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss

# Load a model to train/finetune
model = SentenceTransformer("FacebookAI/xlm-roberta-base")

# Initialize the CoSENTLoss
# This loss requires pairs of text and a floating point similarity score as a label
loss = CoSENTLoss(model)

# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train")
"""
Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 942069
})
"""

訓練引數

SentenceTransformersTrainingArguments 類允許你指定影響訓練效能和跟蹤/除錯的引數。雖然這些引數是可選的,但實驗這些引數可以幫助提高訓練效率,併為訓練過程提供洞察。

在 Sentence Transformers 的文件中,我概述了一些最有用的訓練引數。我建議你閱讀 訓練概覽 > 訓練引數 部分。

以下是如何初始化 SentenceTransformersTrainingArguments 的示例:

from sentence_transformers.training_args import SentenceTransformerTrainingArguments

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    fp16=True, # Set to False if your GPU can't handle FP16
    bf16=False, # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES, # Losses using "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet", # Used in W&B if `wandb` is installed
)

注意 eval_strategy 是在 transformers 版本 4.41.0 中引入的。之前的版本應該使用 evaluation_strategy 代替。

評估器

你可以為 SentenceTransformerTrainer 提供一個 eval_dataset 以便在訓練過程中獲取評估損失,但在訓練過程中獲取更具體的指標也可能很有用。為此,你可以使用評估器來在訓練前、中或後評估模型的效能,並使用有用的指標。你可以同時使用 eval_dataset 和評估器,或者只使用其中一個,或者都不使用。它們根據 eval_strategyeval_steps 訓練引數 進行評估。

以下是 Sentence Tranformers 隨附的已實現的評估器:

評估器 所需資料
BinaryClassificationEvaluator 帶有類別標籤的句子對
EmbeddingSimilarityEvaluator 帶有相似性得分的句子對
InformationRetrievalEvaluator 查詢(qid => 問題),語料庫 (cid => 文件),以及相關文件 (qid => 集合[cid])
MSEEvaluator 需要由教師模型嵌入的源句子和需要由學生模型嵌入的目標句子。可以是相同的文字。
ParaphraseMiningEvaluator ID 到句子的對映以及帶有重複句子 ID 的句子對。
RerankingEvaluator {'query': '..', 'positive': [...], 'negative': [...]} 字典的列表。
TranslationEvaluator 兩種不同語言的句子對。
TripletEvaluator (錨點,正面,負面) 三元組。

此外,你可以使用 SequentialEvaluator 將多個評估器組合成一個,然後將其傳遞給 SentenceTransformerTrainer

如果你沒有必要的評估資料但仍然想跟蹤模型在常見基準上的效能,你可以使用 Hugging Face 上的資料與這些評估器一起使用。

使用 STSb 的 Embedding Similarity Evaluator

STS 基準測試 (也稱為 STSb) 是一種常用的基準資料集,用於衡量模型對短文字 (如 “A man is feeding a mouse to a snake.”) 的語義文字相似性的理解。

你可以自由瀏覽 Hugging Face 上的 sentence-transformers/stsb 資料集。

from datasets import load_dataset
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction

# Load the STSB dataset
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")

# Initialize the evaluator
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    scores=eval_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)
# Run evaluation manually:
# print(dev_evaluator(model))

# Later, you can provide this evaluator to the trainer to get results during training

使用 AllNLI 的 Triplet Evaluator

AllNLI 是 SNLIMultiNLI 資料集的合併,這兩個資料集都是用於自然語言推理的。這個任務的傳統目的是確定兩段文字是否是蘊含、矛盾還是兩者都不是。它後來被採用用於訓練嵌入模型,因為蘊含和矛盾的句子構成了有用的 (anchor, positive, negative) 三元組: 這是訓練嵌入模型的一種常見格式。

在這個片段中,它被用來評估模型認為錨文字和蘊含文字比錨文字和矛盾文字更相似的頻率。一個示例文字是 “An older man is drinking orange juice at a restaurant.”。

你可以自由瀏覽 Hugging Face 上的 sentence-transformers/all-nli 資料集。

from datasets import load_dataset
from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction

# Load triplets from the AllNLI dataset
max_samples = 1000
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split=f"dev[:{max_samples}]")

# Initialize the evaluator
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    main_distance_function=SimilarityFunction.COSINE,
    name=f"all-nli-{max_samples}-dev",
)
# Run evaluation manually:
# print(dev_evaluator(model))

# Later, you can provide this evaluator to the trainer to get results during training

訓練器

SentenceTransformerTrainer 將模型、資料集、損失函式和其他元件整合在一起進行訓練:

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "microsoft/mpnet-base",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on AllNLI triplets",
    )
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"].select(range(100_000))
eval_dataset = dataset["dev"]
test_dataset = dataset["test"]

# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    fp16=True, # Set to False if GPU can't handle FP16
    bf16=False, # Set to True if GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet", # Used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    name="all-nli-dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the test set, after training completes
test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="all-nli-test",
)
test_evaluator(model)

# 8. Save the trained model
model.save_pretrained("models/mpnet-base-all-nli-triplet/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub("mpnet-base-all-nli-triplet")

在這個示例中,我從一個尚未成為 Sentence Transformer 模型的基礎模型 microsoft/mpnet-base 開始進行微調。這需要比微調現有的 Sentence Transformer 模型,如 all-mpnet-base-v2,更多的訓練資料。

執行此指令碼後,tomaarsen/mpnet-base-all-nli-triplet 模型被上傳了。使用餘弦相似度的三元組準確性,即 cosine_similarity(anchor, positive) > cosine_similarity(anchor, negative) 的百分比為開發集上的 90.04% 和測試集上的 91.5% !作為參考,microsoft/mpnet-base 模型在訓練前在開發集上的得分為 68.32%。

所有這些資訊都被自動生成的模型卡儲存,包括基礎模型、語言、許可證、評估結果、訓練和評估資料集資訊、超引數、訓練日誌等。無需任何努力,你上傳的模型應該包含潛在使用者判斷你的模型是否適合他們的所有資訊。

回撥函式

Sentence Transformers 訓練器支援各種 transformers.TrainerCallback 子類,包括:

  • WandbCallback: 如果已安裝 wandb ,則將訓練指標記錄到 W&B
  • TensorBoardCallback: 如果可訪問 tensorboard ,則將訓練指標記錄到 TensorBoard
  • CodeCarbonCallback: 如果已安裝 codecarbon ,則跟蹤訓練期間的碳排放

這些回撥函式會自動使用,無需你進行任何指定,只要安裝了所需的依賴項即可。

有關這些回撥函式的更多資訊以及如何建立你自己的回撥函式,請參閱 Transformers 回撥文件

多資料集訓練

通常情況下,表現最好的模型是透過同時使用多個資料集進行訓練的。SentenceTransformerTrainer 透過允許你使用多個資料集進行訓練,而不需要將它們轉換為相同的格式,簡化了這一過程。你甚至可以為每個資料集應用不同的損失函式。以下是多資料集訓練的步驟:

  1. 使用一個 datasets.Dataset 例項的字典 (或 datasets.DatasetDict) 作為 train_dataseteval_dataset
  2. (可選) 如果你希望為不同的資料集使用不同的損失函式,請使用一個損失函式的字典,其中資料集名稱對映到損失。

每個訓練/評估批次將僅包含來自一個資料集的樣本。從多個資料集中取樣批次的順序由 MultiDatasetBatchSamplers 列舉確定,該列舉可以透過 multi_dataset_batch_sampler 傳遞給 SentenceTransformersTrainingArguments。有效的選項包括:

  • MultiDatasetBatchSamplers.ROUND_ROBIN : 以輪詢方式從每個資料集取樣,直到一個資料集用盡。這種策略可能不會使用每個資料集中的所有樣本,但它確保了每個資料集的平等取樣。
  • MultiDatasetBatchSamplers.PROPORTIONAL (預設): 按比例從每個資料集取樣。這種策略確保了每個資料集中的所有樣本都被使用,並且較大的資料集被更頻繁地取樣。

多工訓練已被證明是高度有效的。例如,Huang et al. 2024 使用了 MultipleNegativesRankingLossCoSENTLossMultipleNegativesRankingLoss 的一個變體 (不包含批次內的負樣本,僅包含硬負樣本),以在中國取得最先進的表現。他們還應用了 MatryoshkaLoss 以使模型能夠產生 Matryoshka Embeddings

以下是多資料集訓練的一個示例:

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss, SoftmaxLoss

# 1. Load a model to finetune
model = SentenceTransformer("bert-base-uncased")

# 2. Loadseveral Datasets to train with
# (anchor, positive)
all_nli_pair_train = load_dataset("sentence-transformers/all-nli", "pair", split="train[:10000]")
# (premise, hypothesis) + label
all_nli_pair_class_train = load_dataset("sentence-transformers/all-nli", "pair-class", split="train[:10000]")
# (sentence1, sentence2) + score
all_nli_pair_score_train = load_dataset("sentence-transformers/all-nli", "pair-score", split="train[:10000]")
# (anchor, positive, negative)
all_nli_triplet_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:10000]")
# (sentence1, sentence2) + score
stsb_pair_score_train = load_dataset("sentence-transformers/stsb", split="train[:10000]")
# (anchor, positive)
quora_pair_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:10000]")
# (query, answer)
natural_questions_train = load_dataset("sentence-transformers/natural-questions", split="train[:10000]")

# Combine all datasets into a dictionary with dataset names to datasets
train_dataset = {
    "all-nli-pair": all_nli_pair_train,
    "all-nli-pair-class": all_nli_pair_class_train,
    "all-nli-pair-score": all_nli_pair_score_train,
    "all-nli-triplet": all_nli_triplet_train,
    "stsb": stsb_pair_score_train,
    "quora": quora_pair_train,
    "natural-questions": natural_questions_train,
}

# 3. Load several Datasets to evaluate with
# (anchor, positive, negative)
all_nli_triplet_dev = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
# (sentence1, sentence2, score)
stsb_pair_score_dev = load_dataset("sentence-transformers/stsb", split="validation")
# (anchor, positive)
quora_pair_dev = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[10000:11000]")
# (query, answer)
natural_questions_dev = load_dataset("sentence-transformers/natural-questions", split="train[10000:11000]")

# Use a dictionary for the evaluation dataset too, or just use one dataset or none at all
eval_dataset = {
    "all-nli-triplet": all_nli_triplet_dev,
    "stsb": stsb_pair_score_dev,
    "quora": quora_pair_dev,
    "natural-questions": natural_questions_dev,
}

# 4. Load several loss functions to train with
# (anchor, positive), (anchor, positive, negative)
mnrl_loss = MultipleNegativesRankingLoss(model)
# (sentence_A, sentence_B) + class
softmax_loss = SoftmaxLoss(model)
# (sentence_A, sentence_B) + score
cosent_loss = CoSENTLoss(model)

# Create a mapping with dataset names to loss functions, so the trainer knows which loss to apply where
# Note: You can also just use one loss if all your training/evaluation datasets use the same loss
losses = {
    "all-nli-pair": mnrl_loss,
    "all-nli-pair-class": softmax_loss,
    "all-nli-pair-score": cosent_loss,
    "all-nli-triplet": mnrl_loss,
    "stsb": cosent_loss,
    "quora": mnrl_loss,
    "natural-questions": mnrl_loss,
}

# 5. Define a simple trainer, although it's recommended to use one with args & evaluators
trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=losses,
)
trainer.train()

# 6. Save the trained model and optionally push it to the Hugging Face Hub
model.save_pretrained("bert-base-all-nli-stsb-quora-nq")
model.push_to_hub("bert-base-all-nli-stsb-quora-nq")

棄用

在 Sentence Transformer v3 釋出之前,所有模型都會使用 SentenceTransformer.fit 方法進行訓練。從 v3.0 開始,該方法將使用 SentenceTransformerTrainer 作為後端。這意味著你的舊訓練程式碼仍然應該可以工作,甚至可以升級到新的特性,如多 GPU 訓練、損失記錄等。然而,新的訓練方法更加強大,因此建議使用新的方法編寫新的訓練指令碼。

附加資源

訓練示例

以下頁面包含帶有解釋的訓練示例以及程式碼連結。我們建議你瀏覽這些頁面以熟悉訓練迴圈:

  • 語義文字相似度
  • 自然語言推理
  • 釋義
  • Quora 重複問題
  • Matryoshka Embeddings
  • 自適應層模型
  • 多語言模型
  • 模型蒸餾
  • 增強的句子轉換器

文件

此外,以下頁面可能有助於你瞭解 Sentence Transformers 的更多資訊:

  • 安裝
  • 快速入門
  • 使用
  • 預訓練模型
  • 訓練概覽 (本部落格是訓練概覽文件的提煉)
  • 資料集概覽
  • 損失概覽
  • API 參考

最後,以下是一些高階頁面,你可能會感興趣:

  • 超引數最佳化
  • 分散式訓練

英文原文: https://hf.co/blog/train-sentence-transformers

原文作者: Tom Aarsen

譯者: innovation64

相關文章