Refactor Ray integration and add support for saving checkpoints
Former-commit-id: 2f50b27e608b2092bfceab6c6e84e6631e973ee2
This commit is contained in:
@@ -19,12 +19,12 @@ import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import yaml
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.training_args import ParallelMode
|
||||
@@ -34,12 +34,12 @@ from transformers.utils.versions import require_version
|
||||
from ..extras import logging
|
||||
from ..extras.constants import CHECKPOINT_NAMES
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from ..integrations.ray.ray_train_args import RayTrainArguments
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
from .training_args import RayArguments, TrainingArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -47,60 +47,41 @@ logger = logging.get_logger(__name__)
|
||||
check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
RayTrainArguments,
|
||||
]
|
||||
_TRAIN_CLS = Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
RayTrainArguments,
|
||||
]
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
|
||||
|
||||
def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
|
||||
if args is not None:
|
||||
return args
|
||||
|
||||
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
|
||||
# read yaml file
|
||||
return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# read json file
|
||||
return json.loads(Path(sys.argv[1]).absolute().read_text())
|
||||
else:
|
||||
return {}
|
||||
return sys.argv[1:]
|
||||
|
||||
|
||||
def _parse_args(
|
||||
parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
|
||||
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
|
||||
) -> Tuple[Any]:
|
||||
args_dict = _read_args(args)
|
||||
args = read_args(args)
|
||||
if isinstance(args, dict):
|
||||
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
if args_dict:
|
||||
return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
|
||||
else:
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
|
||||
args=args_dict, return_remaining_strings=True
|
||||
)
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
|
||||
|
||||
if unknown_args:
|
||||
print(parser.format_help())
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
if unknown_args:
|
||||
print(parser.format_help())
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
return (*parsed_args,)
|
||||
return (*parsed_args,)
|
||||
|
||||
|
||||
def _set_transformers_logging() -> None:
|
||||
@@ -141,7 +122,7 @@ def _verify_model_args(
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
training_args: Optional["Seq2SeqTrainingArguments"] = None,
|
||||
training_args: Optional["TrainingArguments"] = None,
|
||||
) -> None:
|
||||
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
@@ -177,31 +158,29 @@ def _check_extra_dependencies(
|
||||
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
|
||||
parser = HfArgumentParser(RayTrainArguments)
|
||||
ray_args = _parse_args(parser, args, allow_extra_keys=True)[0]
|
||||
if ray_args.use_ray:
|
||||
require_version("ray", "To fix: pip install ray")
|
||||
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
|
||||
parser = HfArgumentParser(RayArguments)
|
||||
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
|
||||
return ray_args
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
|
||||
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
@@ -410,7 +389,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
@@ -443,7 +422,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
|
||||
Reference in New Issue
Block a user