[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -17,7 +17,8 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
r"""
|
||||
Computes and stores the average and current value.
|
||||
"""
|
||||
r"""Compute and store the average and current value."""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -75,9 +74,7 @@ class AverageMeter:
|
||||
|
||||
|
||||
def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
r"""
|
||||
Optionally checks the package version.
|
||||
"""
|
||||
r"""Optionally check the package version."""
|
||||
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
|
||||
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
@@ -91,9 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""
|
||||
Checks the version of the required packages.
|
||||
"""
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("datasets>=2.16.0,<=3.2.0")
|
||||
check_version("accelerate>=0.34.0,<=1.2.1")
|
||||
@@ -103,10 +98,8 @@ def check_dependencies() -> None:
|
||||
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
|
||||
|
||||
|
||||
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
r"""
|
||||
Calculates effective tokens per second.
|
||||
"""
|
||||
def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
r"""Calculate effective tokens per second."""
|
||||
effective_token_num = 0
|
||||
for data in dataset:
|
||||
if stage == "sft":
|
||||
@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
|
||||
return result / dist.get_world_size() if dist.is_initialized() else result
|
||||
|
||||
|
||||
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the number of trainable parameters and number of all parameters in the model.
|
||||
"""
|
||||
def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
|
||||
r"""Return the number of trainable parameters and number of all parameters in the model."""
|
||||
trainable_params, all_param = 0, 0
|
||||
for param in model.parameters():
|
||||
num_params = param.numel()
|
||||
@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
|
||||
|
||||
|
||||
def get_current_device() -> "torch.device":
|
||||
r"""
|
||||
Gets the current available device.
|
||||
"""
|
||||
r"""Get the current available device."""
|
||||
if is_torch_xpu_available():
|
||||
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif is_torch_npu_available():
|
||||
@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
|
||||
|
||||
|
||||
def get_device_count() -> int:
|
||||
r"""
|
||||
Gets the number of available GPU or NPU devices.
|
||||
"""
|
||||
r"""Get the number of available GPU or NPU devices."""
|
||||
if is_torch_xpu_available():
|
||||
return torch.xpu.device_count()
|
||||
elif is_torch_npu_available():
|
||||
@@ -180,18 +167,14 @@ def get_device_count() -> int:
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
r"""Get logits processor that removes NaN and Inf logits."""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def get_peak_memory() -> Tuple[int, int]:
|
||||
r"""
|
||||
Gets the peak memory usage for the current device (in Bytes).
|
||||
"""
|
||||
def get_peak_memory() -> tuple[int, int]:
|
||||
r"""Get the peak memory usage for the current device (in Bytes)."""
|
||||
if is_torch_npu_available():
|
||||
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
|
||||
elif is_torch_cuda_available():
|
||||
@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
|
||||
|
||||
|
||||
def has_tokenized_data(path: "os.PathLike") -> bool:
|
||||
r"""
|
||||
Checks if the path has a tokenized dataset.
|
||||
"""
|
||||
r"""Check if the path has a tokenized dataset."""
|
||||
return os.path.isdir(path) and len(os.listdir(path)) > 0
|
||||
|
||||
|
||||
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
|
||||
r"""
|
||||
Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||
"""
|
||||
r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
|
||||
if _is_bf16_available and model_dtype == torch.bfloat16:
|
||||
return torch.bfloat16
|
||||
elif _is_fp16_available:
|
||||
@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
|
||||
|
||||
|
||||
def is_gpu_or_npu_available() -> bool:
|
||||
r"""
|
||||
Checks if the GPU or NPU is available.
|
||||
"""
|
||||
r"""Check if the GPU or NPU is available."""
|
||||
return is_torch_npu_available() or is_torch_cuda_available()
|
||||
|
||||
|
||||
def is_env_enabled(env_var: str, default: str = "0") -> bool:
|
||||
r"""
|
||||
Checks if the environment variable is enabled.
|
||||
"""
|
||||
r"""Check if the environment variable is enabled."""
|
||||
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
|
||||
|
||||
|
||||
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
||||
r"""
|
||||
Casts a torch tensor or a numpy array to a numpy array.
|
||||
"""
|
||||
r"""Cast a torch tensor or a numpy array to a numpy array."""
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = inputs.cpu()
|
||||
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
|
||||
@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
||||
|
||||
|
||||
def skip_check_imports() -> None:
|
||||
r"""
|
||||
Avoids flash attention import error in custom model files.
|
||||
"""
|
||||
r"""Avoid flash attention import error in custom model files."""
|
||||
if not is_env_enabled("FORCE_CHECK_IMPORTS"):
|
||||
transformers.dynamic_module_utils.check_imports = get_relative_imports
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU or NPU memory.
|
||||
"""
|
||||
r"""Collect GPU or NPU memory."""
|
||||
gc.collect()
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user