drafting ray integration

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>

Former-commit-id: 19c12ddae9350f6e25a270fe3372f5b9094cf960
This commit is contained in:
Kourosh Hakhamaneshi
2024-12-30 16:48:52 -08:00
committed by hiyouga
parent 5ccc607222
commit 8683582300
9 changed files with 143 additions and 21 deletions

View File

@@ -19,6 +19,10 @@ import os
import sys
from typing import Any, Dict, Optional, Tuple
import json
import yaml
from pathlib import Path
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
@@ -37,39 +41,51 @@ from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from ..integrations.ray.ray_train_args import RayTrainArguments
logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
if args is not None:
return parser.parse_dict(args)
return args
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
# read yaml file
return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# read json file
return json.loads(Path(sys.argv[1]).absolute().read_text())
else:
return {}
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
if unknown_args:
print(parser.format_help())
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
args_dict = _read_args(args)
if args_dict:
return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
else:
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
return (*parsed_args,)
if unknown_args:
print(parser.format_help())
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
return (*parsed_args,)
def _set_transformers_logging() -> None:
@@ -161,9 +177,17 @@ def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
return _parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
parser = HfArgumentParser(RayTrainArguments)
ray_args = _parse_args(parser, args, allow_extra_keys=True)[0]
if ray_args.use_ray:
require_version("ray", "To fix: pip install ray")
return ray_args
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
# Setup logging
if training_args.should_log:
_set_transformers_logging()