[misc] upgrade format to py39 (#7256)

Author: hoshi-hiyouga
Date: 2025-03-12 00:08:41 +08:00
Committed by: GitHub
Parent: 5995800bce
Commit: 264538cb26
113 changed files with 984 additions and 1407 deletions
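The pattern applied across the 113 files is visible in the hunks below: built-in generics (PEP 585) such as dict[...] and tuple[...] replace the typing.Dict and typing.Tuple aliases, Sequence now comes from collections.abc rather than typing, and multi-line docstrings are collapsed into single-line imperative summaries. A minimal before/after sketch of the same migration (the summarize function is hypothetical, not part of this commit):

# Before (pre-3.9 style): generic aliases imported from typing.
from typing import Any, Dict, Sequence, Tuple

def summarize(rows: Sequence[Dict[str, Any]]) -> Tuple[int, float]:
    r"""
    Computes the number of rows and the mean row length.
    """
    return len(rows), sum(len(r) for r in rows) / max(len(rows), 1)

# After (py39 style): built-in generics, collections.abc, single-line docstring.
from collections.abc import Sequence
from typing import Any

def summarize(rows: Sequence[dict[str, Any]]) -> tuple[int, float]:
    r"""Compute the number of rows and the mean row length."""
    return len(rows), sum(len(r) for r in rows) / max(len(rows), 1)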


@@ -17,7 +17,8 @@
import gc
import os
-from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Literal, Union
import torch
import torch.distributed as dist
@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
class AverageMeter:
r"""
Computes and stores the average and current value.
"""
r"""Compute and store the average and current value."""
def __init__(self):
self.reset()
@@ -75,9 +74,7 @@ class AverageMeter:
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""
Optionally checks the package version.
"""
r"""Optionally check the package version."""
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
@@ -91,9 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
r"""Check the version of the required packages."""
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.2.0")
check_version("accelerate>=0.34.0,<=1.2.1")
@@ -103,10 +98,8 @@ def check_dependencies() -> None:
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
-def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
-r"""
-Calculates effective tokens per second.
-"""
+def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
+r"""Calculate effective tokens per second."""
effective_token_num = 0
for data in dataset:
if stage == "sft":
@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
return result / dist.get_world_size() if dist.is_initialized() else result
-def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
-r"""
-Returns the number of trainable parameters and number of all parameters in the model.
-"""
+def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
+r"""Return the number of trainable parameters and number of all parameters in the model."""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
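
For reference, a self-contained sketch of count_parameters with the new tuple[int, int] annotation; the hunk above is truncated after param.numel(), so the remainder of the loop is assumed to be the standard counting logic:

import torch

def count_parameters_sketch(model: torch.nn.Module) -> tuple[int, int]:
    r"""Return the number of trainable parameters and number of all parameters."""
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param

print(count_parameters_sketch(torch.nn.Linear(10, 2)))  # prints (22, 22): 20 weights + 2 biases
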
@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
r"""Get the current available device."""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available():
@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
def get_device_count() -> int:
r"""
Gets the number of available GPU or NPU devices.
"""
r"""Get the number of available GPU or NPU devices."""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
@@ -180,18 +167,14 @@ def get_device_count() -> int:
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
r"""Get logits processor that removes NaN and Inf logits."""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
-def get_peak_memory() -> Tuple[int, int]:
-r"""
-Gets the peak memory usage for the current device (in Bytes).
-"""
+def get_peak_memory() -> tuple[int, int]:
+r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
r"""Check if the path has a tokenized dataset."""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
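
A self-contained sketch of the infer_optim_dtype logic in this hunk; the module-level availability flags are modeled as parameters, and the fp16/fp32 fallback branches are assumptions since the excerpt ends at the elif:

import torch

def infer_optim_dtype_sketch(model_dtype: torch.dtype, bf16_available: bool, fp16_available: bool) -> torch.dtype:
    r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
    if bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    elif fp16_available:
        return torch.float16  # assumed fallback: the hunk above is cut off here
    else:
        return torch.float32  # assumed final fallback

print(infer_optim_dtype_sketch(torch.bfloat16, bf16_available=False, fp16_available=True))  # torch.float16
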
@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def is_gpu_or_npu_available() -> bool:
r"""
Checks if the GPU or NPU is available.
"""
r"""Check if the GPU or NPU is available."""
return is_torch_npu_available() or is_torch_cuda_available()
def is_env_enabled(env_var: str, default: str = "0") -> bool:
r"""
Checks if the environment variable is enabled.
"""
r"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
"""
r"""Cast a torch tensor or a numpy array to a numpy array."""
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
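
A self-contained sketch of numpify as shown in this hunk; the excerpt ends inside the bfloat16 branch, so the upcast to float32 and the final .numpy() call are assumptions:

from typing import Union

import numpy as np
import torch

def numpify_sketch(inputs: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    r"""Cast a torch tensor or a numpy array to a numpy array."""
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.cpu()
        if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16
            inputs = inputs.to(torch.float32)  # assumed upcast target
        inputs = inputs.numpy()  # assumed final conversion
    return inputs

print(numpify_sketch(torch.tensor([1.0, 2.0], dtype=torch.bfloat16)))  # [1. 2.]
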
@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
r"""Avoid flash attention import error in custom model files."""
if not is_env_enabled("FORCE_CHECK_IMPORTS"):
transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None:
r"""
Collects GPU or NPU memory.
"""
r"""Collect GPU or NPU memory."""
gc.collect()
if is_torch_xpu_available():
torch.xpu.empty_cache()