drafting ray integration
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Former-commit-id: 19c12ddae9350f6e25a270fe3372f5b9094cf960
This commit is contained in:
committed by
hiyouga
parent
5ccc607222
commit
8683582300
@@ -19,6 +19,10 @@ import os
|
||||
import sys
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import json
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
@@ -37,39 +41,51 @@ from .finetuning_args import FinetuningArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
|
||||
from ..integrations.ray.ray_train_args import RayTrainArguments
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
|
||||
|
||||
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
return args
|
||||
|
||||
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
# read yaml file
|
||||
return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# read json file
|
||||
return json.loads(Path(sys.argv[1]).absolute().read_text())
|
||||
else:
|
||||
return {}
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
|
||||
|
||||
if unknown_args:
|
||||
print(parser.format_help())
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
args_dict = _read_args(args)
|
||||
|
||||
if args_dict:
|
||||
return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
|
||||
else:
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
|
||||
|
||||
return (*parsed_args,)
|
||||
if unknown_args:
|
||||
print(parser.format_help())
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
return (*parsed_args,)
|
||||
|
||||
|
||||
|
||||
def _set_transformers_logging() -> None:
|
||||
@@ -161,9 +177,17 @@ def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
||||
def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
|
||||
parser = HfArgumentParser(RayTrainArguments)
|
||||
ray_args = _parse_args(parser, args, allow_extra_keys=True)[0]
|
||||
if ray_args.use_ray:
|
||||
require_version("ray", "To fix: pip install ray")
|
||||
return ray_args
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
_set_transformers_logging()
|
||||
|
||||
Reference in New Issue
Block a user