[train] fix denominator of ga in ksft loss (#9409)
This commit is contained in:
@@ -83,6 +83,7 @@ def run_sft(
|
||||
**dataset_module,
|
||||
**metric_module,
|
||||
)
|
||||
trainer.model_accepts_loss_kwargs = False
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
|
||||
Reference in New Issue
Block a user