mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 08:33:38 +00:00
support RM metrics, add generating Args
Former-commit-id: c461c6190bc124e98dde7f3cf96a59ce40b26fb0
This commit is contained in:
@@ -6,13 +6,14 @@
|
||||
|
||||
|
||||
from utils import (
|
||||
prepare_args,
|
||||
prepare_data,
|
||||
load_pretrained,
|
||||
preprocess_data,
|
||||
PairwiseDataCollatorWithPadding,
|
||||
PairwisePeftTrainer,
|
||||
LogCallback,
|
||||
load_pretrained,
|
||||
prepare_args,
|
||||
prepare_data,
|
||||
preprocess_data,
|
||||
compute_accuracy,
|
||||
plot_loss
|
||||
)
|
||||
|
||||
@@ -23,7 +24,7 @@ def main():
|
||||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model)
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer)
|
||||
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
|
||||
@@ -45,6 +46,7 @@ def main():
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=[LogCallback()],
|
||||
compute_metrics=compute_accuracy,
|
||||
**trainer_kwargs
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user