run style check
Former-commit-id: 5ec33baf5f95df9fa2afe5523c825d3eda8a076b
This commit is contained in:
@@ -15,16 +15,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import json
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import yaml
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
@@ -35,21 +34,35 @@ from transformers.utils.versions import require_version
|
||||
from ..extras import logging
|
||||
from ..extras.constants import CHECKPOINT_NAMES
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from ..integrations.ray.ray_train_args import RayTrainArguments
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
|
||||
from ..integrations.ray.ray_train_args import RayTrainArguments
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
|
||||
_TRAIN_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
RayTrainArguments,
|
||||
]
|
||||
_TRAIN_CLS = Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
RayTrainArguments,
|
||||
]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
@@ -70,14 +83,17 @@ def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
|
||||
|
||||
def _parse_args(
|
||||
parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
|
||||
) -> Tuple[Any]:
|
||||
args_dict = _read_args(args)
|
||||
|
||||
|
||||
if args_dict:
|
||||
return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
|
||||
else:
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
|
||||
args=args_dict, return_remaining_strings=True
|
||||
)
|
||||
|
||||
if unknown_args:
|
||||
print(parser.format_help())
|
||||
@@ -85,7 +101,6 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
return (*parsed_args,)
|
||||
|
||||
|
||||
|
||||
def _set_transformers_logging() -> None:
|
||||
@@ -187,7 +202,7 @@ def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
|
||||
|
||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
|
||||
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
_set_transformers_logging()
|
||||
|
||||
Reference in New Issue
Block a user