[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -16,14 +16,12 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
from typing import Any, Literal, Optional
@dataclass
class DataArguments:
r"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
template: Optional[str] = field(
default=None,
@@ -162,5 +160,5 @@ class DataArguments:
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)

View File

@@ -21,9 +21,7 @@ from datasets import DownloadMode
@dataclass
class EvaluationArguments:
r"""
Arguments pertaining to specify the evaluation parameters.
"""
r"""Arguments pertaining to specify the evaluation parameters."""
task: str = field(
metadata={"help": "Name of the evaluation task."},

View File

@@ -13,14 +13,12 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional
@dataclass
class FreezeArguments:
r"""
Arguments pertaining to the freeze (partial-parameter) training.
"""
r"""Arguments pertaining to the freeze (partial-parameter) training."""
freeze_trainable_layers: int = field(
default=2,
@@ -56,9 +54,7 @@ class FreezeArguments:
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
r"""Arguments pertaining to the LoRA training."""
additional_target: Optional[str] = field(
default=None,
@@ -128,9 +124,7 @@ class LoraArguments:
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO, DPO and KTO training.
"""
r"""Arguments pertaining to the PPO, DPO and KTO training."""
pref_beta: float = field(
default=0.1,
@@ -212,9 +206,7 @@ class RLHFArguments:
@dataclass
class GaloreArguments:
r"""
Arguments pertaining to the GaLore algorithm.
"""
r"""Arguments pertaining to the GaLore algorithm."""
use_galore: bool = field(
default=False,
@@ -253,9 +245,7 @@ class GaloreArguments:
@dataclass
class ApolloArguments:
r"""
Arguments pertaining to the APOLLO algorithm.
"""
r"""Arguments pertaining to the APOLLO algorithm."""
use_apollo: bool = field(
default=False,
@@ -306,9 +296,7 @@ class ApolloArguments:
@dataclass
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
r"""Arguments pertaining to the BAdam optimizer."""
use_badam: bool = field(
default=False,
@@ -393,9 +381,7 @@ class SwanLabArguments:
class FinetuningArguments(
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
pure_bf16: bool = field(
default=False,
@@ -452,13 +438,13 @@ class FinetuningArguments(
return [item.strip() for item in arg.split(",")]
return arg
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.lora_target: list[str] = split_arg(self.lora_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@@ -499,7 +485,7 @@ class FinetuningArguments(
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args

View File

@@ -13,16 +13,14 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
from typing import Any, Optional
from transformers import GenerationConfig
@dataclass
class GeneratingArguments:
r"""
Arguments pertaining to specify the decoding parameters.
"""
r"""Arguments pertaining to specify the decoding parameters."""
do_sample: bool = field(
default=True,
@@ -35,7 +33,9 @@ class GeneratingArguments:
top_p: float = field(
default=0.7,
metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
"help": (
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
)
},
)
top_k: int = field(
@@ -71,7 +71,7 @@ class GeneratingArguments:
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)

View File

@@ -17,7 +17,7 @@
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
from typing import Any, Literal, Optional, Union
import torch
from transformers.training_args import _convert_str_dict
@@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling
@dataclass
class BaseModelArguments:
r"""
Arguments pertaining to the model.
"""
r"""Arguments pertaining to the model."""
model_name_or_path: Optional[str] = field(
default=None,
@@ -184,9 +182,7 @@ class BaseModelArguments:
@dataclass
class QuantizationArguments:
r"""
Arguments pertaining to the quantization method.
"""
r"""Arguments pertaining to the quantization method."""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
@@ -212,9 +208,7 @@ class QuantizationArguments:
@dataclass
class ProcessorArguments:
r"""
Arguments pertaining to the image processor.
"""
r"""Arguments pertaining to the image processor."""
image_max_pixels: int = field(
default=768 * 768,
@@ -244,9 +238,7 @@ class ProcessorArguments:
@dataclass
class ExportArguments:
r"""
Arguments pertaining to the model export.
"""
r"""Arguments pertaining to the model export."""
export_dir: Optional[str] = field(
default=None,
@@ -292,9 +284,7 @@ class ExportArguments:
@dataclass
class VllmArguments:
r"""
Arguments pertaining to the vLLM worker.
"""
r"""Arguments pertaining to the vLLM worker."""
vllm_maxlen: int = field(
default=4096,
@@ -324,8 +314,7 @@ class VllmArguments:
@dataclass
class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
The class on the most right will be displayed first.
"""
@@ -335,7 +324,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, Dict[str, Any]]] = field(
device_map: Optional[Union[str, dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
@@ -372,7 +361,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
return result
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args

View File

@@ -19,7 +19,7 @@ import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
import transformers
@@ -47,17 +47,15 @@ check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
r"""
Gets arguments from the command line or a config file.
"""
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
r"""Get arguments from the command line or a config file."""
if args is not None:
return args
@@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[
def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
) -> tuple[Any]:
args = read_args(args)
if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
@@ -161,31 +159,31 @@ def _check_extra_dependencies(
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
@@ -364,9 +362,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
and training_args.resume_from_checkpoint is not None
):
logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
)
# Post-process model arguments
@@ -382,20 +378,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
# Log on each process the small summary
logger.info(
"Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
training_args.process_index,
training_args.world_size,
training_args.device,
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
f"Process rank: {training_args.process_index}, "
f"world size: {training_args.world_size}, device: {training_args.device}, "
f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
f"compute dtype: {str(model_args.compute_dtype)}"
)
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
@@ -426,7 +419,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()

View File

@@ -10,9 +10,7 @@ from ..extras.misc import use_ray
@dataclass
class RayArguments:
r"""
Arguments pertaining to the Ray training.
"""
r"""Arguments pertaining to the Ray training."""
ray_run_name: Optional[str] = field(
default=None,
@@ -43,9 +41,7 @@ class RayArguments:
@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r"""
Arguments pertaining to the trainer.
"""
r"""Arguments pertaining to the trainer."""
def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)