Former-commit-id: e6120a937ddb4f3c0b9bcb2466742f5cf4f77f8c
hiyouga
2023-08-23 20:21:15 +08:00
parent 4606340f0f
commit eb9ac9ee1f
4 changed files with 19 additions and 13 deletions

@@ -156,10 +156,9 @@ def get_train_args(
         and finetuning_args.finetuning_type == "lora"
     ):
         logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
-        training_args.ddp_find_unused_parameters = False
-
-    if training_args.optim == "adamw_hf":
-        training_args.optim = "adamw_torch" # suppress warning
+        training_args_dict = training_args.to_dict()
+        training_args_dict.update(dict(ddp_find_unused_parameters=False))
+        training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     if (
         training_args.resume_from_checkpoint is None
@@ -172,7 +171,9 @@ def get_train_args(
             raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
 
         if last_checkpoint is not None:
-            training_args.resume_from_checkpoint = last_checkpoint
+            training_args_dict = training_args.to_dict()
+            training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
+            training_args = Seq2SeqTrainingArguments(**training_args_dict)
             logger.info(
                 "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
             )
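
Both hunks apply the same pattern: instead of mutating the already-parsed `training_args` in place, the commit round-trips the object through `to_dict()`, updates the dict, and constructs a fresh `Seq2SeqTrainingArguments`, presumably so that `__post_init__` processing is re-applied to the new values. A minimal standalone sketch of that pattern follows; the `output_dir` value and the final assert are placeholders for illustration and are not part of this commit.

# Sketch of the dict round-trip pattern used in this commit.
# Assumes transformers is installed; "placeholder_output_dir" is a dummy value.
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(output_dir="placeholder_output_dir")

# Serialize the parsed arguments, override one field, and rebuild the object
# instead of assigning the attribute directly on the existing instance.
# Note: to_dict() may mask secret fields such as hub_token, which is
# acceptable for this local override.
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(ddp_find_unused_parameters=False))
training_args = Seq2SeqTrainingArguments(**training_args_dict)

assert training_args.ddp_find_unused_parameters is False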