分類任務
import numpy as np import evaluate metric = evaluate.load("accuracy") def compute_metrics(eval_pred): logits, labels = eval_pred predictions = np.argmax(logits, axis=-1) return metric.compute(predictions=predictions, references=labels)
logits
是模型的輸出
labels是真實標籤
用 numpy
的 argmax
函式沿著最後一個維度(即每個樣本的類別維度)找到分數最大的索引,這些索引即為模型的預測類別
返回準確率
生成任務
BLEU、ROUGE、METEOR 等,這些指標用於比較生成的文字和參考文字