Former-commit-id: 8c92d268903c00392c8bd75a731daa1f107d6202
This commit is contained in:
hiyouga
2024-06-30 21:28:51 +08:00
parent 889c042ecd
commit 188b4be64d
4 changed files with 24 additions and 9 deletions

View File

@@ -20,7 +20,9 @@ import os
from typing import TYPE_CHECKING, Tuple
import torch
import transformers.dynamic_module_utils
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.dynamic_module_utils import get_relative_imports
from transformers.utils import (
is_torch_bf16_gpu_available,
is_torch_cuda_available,
@@ -69,6 +71,9 @@ class AverageMeter:
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
@@ -79,7 +84,7 @@ def check_dependencies() -> None:
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
@@ -108,7 +113,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def get_current_device() -> torch.device:
def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
@@ -147,6 +152,13 @@ def get_logits_processor() -> "LogitsProcessorList":
return logits_processor
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
@@ -166,11 +178,9 @@ def is_gpu_or_npu_available() -> bool:
return is_torch_npu_available() or is_torch_cuda_available()
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def skip_check_imports() -> None:
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None: