Former-commit-id: e6120a937ddb4f3c0b9bcb2466742f5cf4f77f8c
This commit is contained in:
hiyouga
2023-08-23 20:21:15 +08:00
parent 4606340f0f
commit eb9ac9ee1f
4 changed files with 19 additions and 13 deletions

View File

@@ -33,10 +33,12 @@ def run_sft(
)
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \
training_args.generation_max_length is not None else data_args.max_target_length
training_args.generation_num_beams = data_args.eval_num_beams if \
data_args.eval_num_beams is not None else training_args.generation_num_beams
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(
generation_max_length=training_args.generation_max_length or data_args.max_target_length,
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = Seq2SeqPeftTrainer(