@@ -156,10 +156,9 @@ def get_train_args(
         and finetuning_args.finetuning_type == "lora"
     ):
         logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
-        training_args.ddp_find_unused_parameters = False
-
-    if training_args.optim == "adamw_hf":
-        training_args.optim = "adamw_torch" # suppress warning
+        training_args_dict = training_args.to_dict()
+        training_args_dict.update(dict(ddp_find_unused_parameters=False))
+        training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     if (
         training_args.resume_from_checkpoint is None
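Note on the pattern used in both hunks: rather than assigning to a field of the already-parsed Seq2SeqTrainingArguments, the arguments are dumped with to_dict(), the single field is overridden, and the object is rebuilt, which re-runs the dataclass __post_init__ validation on the new value. Below is a minimal, self-contained sketch of that pattern, not the repository's code; the output_dir value is a dummy placeholder and is assumed here, since the real values come from the project's argument parser.

    # Sketch only: rebuild Seq2SeqTrainingArguments instead of mutating it in place.
    from transformers import Seq2SeqTrainingArguments

    # Placeholder arguments; in the real code these are parsed from the CLI/JSON config.
    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")

    # Dump to a plain dict, override one field, then rebuild so validation runs again.
    training_args_dict = training_args.to_dict()
    training_args_dict.update(dict(ddp_find_unused_parameters=False))
    training_args = Seq2SeqTrainingArguments(**training_args_dict)

    print(training_args.ddp_find_unused_parameters)  # False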
@@ -172,7 +171,9 @@ def get_train_args(
             raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
 
         if last_checkpoint is not None:
-            training_args.resume_from_checkpoint = last_checkpoint
+            training_args_dict = training_args.to_dict()
+            training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
+            training_args = Seq2SeqTrainingArguments(**training_args_dict)
             logger.info(
                 "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
             )
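For context on the second hunk: last_checkpoint is not defined in the lines shown here. In transformers-based training code it typically comes from the get_last_checkpoint helper, which returns the newest checkpoint-* subfolder of output_dir or None; that origin is inferred, not shown in this diff. A small sketch of that detection step, with a placeholder directory name:

    import os
    from transformers.trainer_utils import get_last_checkpoint

    output_dir = "outputs"  # placeholder path, not taken from this diff
    last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None

    if last_checkpoint is not None:
        # The rebuilt training_args would then carry resume_from_checkpoint=last_checkpoint.
        print(f"Would resume from {last_checkpoint}")
    else:
        print("No checkpoint-* folder found; training would start from scratch.")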