support batch_eval_metrics, fix #4826
Former-commit-id: 3fe1df17188825f8a32fbe6a1294b4b532ce0c85
This commit is contained in:
@@ -23,7 +23,7 @@ from ...extras.misc import get_logits_processor
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
|
||||
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||
from .trainer import CustomSeq2SeqTrainer
|
||||
|
||||
|
||||
@@ -46,15 +46,12 @@ def run_sft(
|
||||
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
tokenizer.padding_side = "left" # use left-padding in generation
|
||||
|
||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||
|
||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
|
||||
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
block_diag_attn=model_args.block_diag_attn,
|
||||
attn_implementation=getattr(model.config, "_attn_implementation", None),
|
||||
@@ -66,6 +63,14 @@ def run_sft(
|
||||
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
|
||||
training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
|
||||
|
||||
# Metric utils
|
||||
metric_module = {}
|
||||
if training_args.predict_with_generate:
|
||||
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
|
||||
elif finetuning_args.compute_accuracy:
|
||||
metric_module["compute_metrics"] = ComputeAccuracy()
|
||||
metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
model=model,
|
||||
@@ -73,10 +78,9 @@ def run_sft(
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
|
||||
preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**metric_module,
|
||||
)
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
@@ -95,6 +99,9 @@ def run_sft(
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
tokenizer.padding_side = "left" # use left-padding in generation
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user