[parser] support omegaconf (#7793)
@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
 import torch
 import transformers
 import yaml
+from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
@@ -59,10 +60,14 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
     if args is not None:
         return args
 
-    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
+    elif sys.argv[1].endswith(".json"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]
 
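
With this change, any arguments after the config file path are parsed as OmegaConf dotlist overrides and merged over the file contents, with the command-line value taking precedence; previously a config path followed by extra arguments fell through to the plain `sys.argv[1:]` branch. A minimal sketch of the merge semantics (not part of this commit; the inline YAML string and the hand-written list stand in for the config file and `sys.argv[2:]`):

    import yaml
    from omegaconf import OmegaConf

    # Stand-in for the contents of the YAML file passed as sys.argv[1].
    base_yaml = "learning_rate: 5.0e-5\nnum_train_epochs: 3.0\n"

    # Stand-in for sys.argv[2:], i.e. extra key=value arguments after the config path.
    cli_overrides = ["learning_rate=0.0001", "num_train_epochs=1.0"]

    dict_config = yaml.safe_load(base_yaml)               # plain dict, as read_args builds from the file
    override_config = OmegaConf.from_cli(cli_overrides)   # dotlist arguments -> DictConfig
    merged = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
    print(merged)  # {'learning_rate': 0.0001, 'num_train_epochs': 1.0}

OmegaConf.merge gives later arguments precedence, so the CLI overrides win over the file values.
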
@@ -330,12 +335,20 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
 
     # Post-process training arguments
+    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+    training_args.remove_unused_columns = False  # important for multimodal dataset
+
+    if finetuning_args.finetuning_type == "lora":
+        # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
+        training_args.label_names = training_args.label_names or ["labels"]
+
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
         training_args.ddp_find_unused_parameters = False
 
     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
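
As a quick end-to-end check of the override path, a hypothetical snippet along these lines should work (not part of the commit; it assumes the package is installed and that `read_args` is importable from `llamafactory.hparams.parser`):

    import sys
    import tempfile

    from llamafactory.hparams.parser import read_args  # assumed import path

    # Write a throwaway YAML config, then simulate `<cli> config.yaml key=value ...`.
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        f.write("stage: sft\nlearning_rate: 5.0e-5\n")
        config_path = f.name

    sys.argv = ["llamafactory-cli", config_path, "learning_rate=0.0001"]
    print(read_args())  # expected: {'stage': 'sft', 'learning_rate': 0.0001}

The same dotlist override syntax applies to .json configs, since both branches go through OmegaConf.merge.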