[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -16,14 +16,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
r"""
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
|
||||
|
||||
template: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -162,5 +160,5 @@ class DataArguments:
|
||||
if self.mask_history and self.train_on_prompt:
|
||||
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@@ -21,9 +21,7 @@ from datasets import DownloadMode
|
||||
|
||||
@dataclass
|
||||
class EvaluationArguments:
|
||||
r"""
|
||||
Arguments pertaining to specify the evaluation parameters.
|
||||
"""
|
||||
r"""Arguments pertaining to specify the evaluation parameters."""
|
||||
|
||||
task: str = field(
|
||||
metadata={"help": "Name of the evaluation task."},
|
||||
|
||||
@@ -13,14 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class FreezeArguments:
|
||||
r"""
|
||||
Arguments pertaining to the freeze (partial-parameter) training.
|
||||
"""
|
||||
r"""Arguments pertaining to the freeze (partial-parameter) training."""
|
||||
|
||||
freeze_trainable_layers: int = field(
|
||||
default=2,
|
||||
@@ -56,9 +54,7 @@ class FreezeArguments:
|
||||
|
||||
@dataclass
|
||||
class LoraArguments:
|
||||
r"""
|
||||
Arguments pertaining to the LoRA training.
|
||||
"""
|
||||
r"""Arguments pertaining to the LoRA training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -128,9 +124,7 @@ class LoraArguments:
|
||||
|
||||
@dataclass
|
||||
class RLHFArguments:
|
||||
r"""
|
||||
Arguments pertaining to the PPO, DPO and KTO training.
|
||||
"""
|
||||
r"""Arguments pertaining to the PPO, DPO and KTO training."""
|
||||
|
||||
pref_beta: float = field(
|
||||
default=0.1,
|
||||
@@ -212,9 +206,7 @@ class RLHFArguments:
|
||||
|
||||
@dataclass
|
||||
class GaloreArguments:
|
||||
r"""
|
||||
Arguments pertaining to the GaLore algorithm.
|
||||
"""
|
||||
r"""Arguments pertaining to the GaLore algorithm."""
|
||||
|
||||
use_galore: bool = field(
|
||||
default=False,
|
||||
@@ -253,9 +245,7 @@ class GaloreArguments:
|
||||
|
||||
@dataclass
|
||||
class ApolloArguments:
|
||||
r"""
|
||||
Arguments pertaining to the APOLLO algorithm.
|
||||
"""
|
||||
r"""Arguments pertaining to the APOLLO algorithm."""
|
||||
|
||||
use_apollo: bool = field(
|
||||
default=False,
|
||||
@@ -306,9 +296,7 @@ class ApolloArguments:
|
||||
|
||||
@dataclass
|
||||
class BAdamArgument:
|
||||
r"""
|
||||
Arguments pertaining to the BAdam optimizer.
|
||||
"""
|
||||
r"""Arguments pertaining to the BAdam optimizer."""
|
||||
|
||||
use_badam: bool = field(
|
||||
default=False,
|
||||
@@ -393,9 +381,7 @@ class SwanLabArguments:
|
||||
class FinetuningArguments(
|
||||
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
|
||||
):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
|
||||
|
||||
pure_bf16: bool = field(
|
||||
default=False,
|
||||
@@ -452,13 +438,13 @@ class FinetuningArguments(
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target: List[str] = split_arg(self.lora_target)
|
||||
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: List[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: List[str] = split_arg(self.apollo_target)
|
||||
self.lora_target: list[str] = split_arg(self.lora_target)
|
||||
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: list[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: list[str] = split_arg(self.apollo_target)
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
@@ -499,7 +485,7 @@ class FinetuningArguments(
|
||||
if self.pissa_init:
|
||||
raise ValueError("`pissa_init` is only valid for LoRA training.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
args = asdict(self)
|
||||
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
|
||||
return args
|
||||
|
||||
@@ -13,16 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from transformers import GenerationConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratingArguments:
|
||||
r"""
|
||||
Arguments pertaining to specify the decoding parameters.
|
||||
"""
|
||||
r"""Arguments pertaining to specify the decoding parameters."""
|
||||
|
||||
do_sample: bool = field(
|
||||
default=True,
|
||||
@@ -35,7 +33,9 @@ class GeneratingArguments:
|
||||
top_p: float = field(
|
||||
default=0.7,
|
||||
metadata={
|
||||
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
|
||||
"help": (
|
||||
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
|
||||
)
|
||||
},
|
||||
)
|
||||
top_k: int = field(
|
||||
@@ -71,7 +71,7 @@ class GeneratingArguments:
|
||||
metadata={"help": "Whether or not to remove special tokens in the decoding."},
|
||||
)
|
||||
|
||||
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
|
||||
def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
|
||||
args = asdict(self)
|
||||
if args.get("max_new_tokens", -1) > 0:
|
||||
args.pop("max_length", None)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling
|
||||
|
||||
@dataclass
|
||||
class BaseModelArguments:
|
||||
r"""
|
||||
Arguments pertaining to the model.
|
||||
"""
|
||||
r"""Arguments pertaining to the model."""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -184,9 +182,7 @@ class BaseModelArguments:
|
||||
|
||||
@dataclass
|
||||
class QuantizationArguments:
|
||||
r"""
|
||||
Arguments pertaining to the quantization method.
|
||||
"""
|
||||
r"""Arguments pertaining to the quantization method."""
|
||||
|
||||
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
|
||||
default="bitsandbytes",
|
||||
@@ -212,9 +208,7 @@ class QuantizationArguments:
|
||||
|
||||
@dataclass
|
||||
class ProcessorArguments:
|
||||
r"""
|
||||
Arguments pertaining to the image processor.
|
||||
"""
|
||||
r"""Arguments pertaining to the image processor."""
|
||||
|
||||
image_max_pixels: int = field(
|
||||
default=768 * 768,
|
||||
@@ -244,9 +238,7 @@ class ProcessorArguments:
|
||||
|
||||
@dataclass
|
||||
class ExportArguments:
|
||||
r"""
|
||||
Arguments pertaining to the model export.
|
||||
"""
|
||||
r"""Arguments pertaining to the model export."""
|
||||
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -292,9 +284,7 @@ class ExportArguments:
|
||||
|
||||
@dataclass
|
||||
class VllmArguments:
|
||||
r"""
|
||||
Arguments pertaining to the vLLM worker.
|
||||
"""
|
||||
r"""Arguments pertaining to the vLLM worker."""
|
||||
|
||||
vllm_maxlen: int = field(
|
||||
default=4096,
|
||||
@@ -324,8 +314,7 @@ class VllmArguments:
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
|
||||
r"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
|
||||
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
|
||||
|
||||
The class on the most right will be displayed first.
|
||||
"""
|
||||
@@ -335,7 +324,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
|
||||
init=False,
|
||||
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
|
||||
)
|
||||
device_map: Optional[Union[str, Dict[str, Any]]] = field(
|
||||
device_map: Optional[Union[str, dict[str, Any]]] = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
|
||||
@@ -372,7 +361,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
|
||||
|
||||
return result
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
args = asdict(self)
|
||||
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
|
||||
return args
|
||||
|
||||
@@ -19,7 +19,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -47,17 +47,15 @@ check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
|
||||
|
||||
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
|
||||
r"""
|
||||
Gets arguments from the command line or a config file.
|
||||
"""
|
||||
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
|
||||
r"""Get arguments from the command line or a config file."""
|
||||
if args is not None:
|
||||
return args
|
||||
|
||||
@@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[
|
||||
|
||||
|
||||
def _parse_args(
|
||||
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
|
||||
) -> Tuple[Any]:
|
||||
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
|
||||
) -> tuple[Any]:
|
||||
args = read_args(args)
|
||||
if isinstance(args, dict):
|
||||
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
|
||||
@@ -161,31 +159,31 @@ def _check_extra_dependencies(
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
|
||||
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
|
||||
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
|
||||
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
|
||||
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
|
||||
parser = HfArgumentParser(RayArguments)
|
||||
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
|
||||
return ray_args
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
|
||||
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
@@ -364,9 +362,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
and training_args.resume_from_checkpoint is not None
|
||||
):
|
||||
logger.warning_rank0(
|
||||
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
)
|
||||
f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
|
||||
)
|
||||
|
||||
# Post-process model arguments
|
||||
@@ -382,20 +378,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
|
||||
# Log on each process the small summary
|
||||
logger.info(
|
||||
"Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
|
||||
training_args.process_index,
|
||||
training_args.world_size,
|
||||
training_args.device,
|
||||
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
|
||||
str(model_args.compute_dtype),
|
||||
)
|
||||
f"Process rank: {training_args.process_index}, "
|
||||
f"world size: {training_args.world_size}, device: {training_args.device}, "
|
||||
f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
|
||||
f"compute dtype: {str(model_args.compute_dtype)}"
|
||||
)
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
|
||||
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
@@ -426,7 +419,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
|
||||
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
|
||||
@@ -10,9 +10,7 @@ from ..extras.misc import use_ray
|
||||
|
||||
@dataclass
|
||||
class RayArguments:
|
||||
r"""
|
||||
Arguments pertaining to the Ray training.
|
||||
"""
|
||||
r"""Arguments pertaining to the Ray training."""
|
||||
|
||||
ray_run_name: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -43,9 +41,7 @@ class RayArguments:
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
|
||||
r"""
|
||||
Arguments pertaining to the trainer.
|
||||
"""
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
def __post_init__(self):
|
||||
Seq2SeqTrainingArguments.__post_init__(self)
|
||||
|
||||
Reference in New Issue
Block a user