[parser] support omegaconf (#7793)

This commit is contained in:
hoshi-hiyouga
2025-04-21 23:30:30 +08:00
committed by GitHub
parent bd7bc31c79
commit 416853dd25
25 changed files with 62 additions and 94 deletions

View File

@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
import torch
import transformers
import yaml
from omegaconf import OmegaConf
from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
@@ -59,10 +60,14 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
if args is not None:
return args
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return json.loads(Path(sys.argv[1]).absolute().read_text())
if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
elif sys.argv[1].endswith(".json"):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
else:
return sys.argv[1:]
@@ -330,12 +335,20 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False # important for multimodal dataset
if finetuning_args.finetuning_type == "lora":
# https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
training_args.label_names = training_args.label_names or ["labels"]
if (
training_args.parallel_mode == ParallelMode.DISTRIBUTED
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
training_args.ddp_find_unused_parameters = False
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]: