run style check

Former-commit-id: 5ec33baf5f95df9fa2afe5523c825d3eda8a076b
This commit is contained in:
Eric Tang
2025-01-06 23:55:56 +00:00
committed by hiyouga
parent 8683582300
commit 4f31ad997c
7 changed files with 54 additions and 35 deletions

View File

@@ -15,16 +15,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
from typing import Any, Dict, Optional, Tuple
import json
import yaml
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import torch
import transformers
import yaml
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
@@ -35,21 +34,35 @@ from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.misc import check_dependencies, get_current_device
from ..integrations.ray.ray_train_args import RayTrainArguments
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from ..integrations.ray.ray_train_args import RayTrainArguments
logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
_TRAIN_ARGS = [
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
RayTrainArguments,
]
_TRAIN_CLS = Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
RayTrainArguments,
]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
@@ -70,14 +83,17 @@ def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
return {}
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
def _parse_args(
parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
args_dict = _read_args(args)
if args_dict:
return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
else:
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
args=args_dict, return_remaining_strings=True
)
if unknown_args:
print(parser.format_help())
@@ -85,7 +101,6 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
return (*parsed_args,)
def _set_transformers_logging() -> None:
@@ -187,7 +202,7 @@ def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
# Setup logging
if training_args.should_log:
_set_transformers_logging()