Former-commit-id: 8c92d268903c00392c8bd75a731daa1f107d6202
@@ -20,7 +20,9 @@ import os
 from typing import TYPE_CHECKING, Tuple
 
 import torch
+import transformers.dynamic_module_utils
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+from transformers.dynamic_module_utils import get_relative_imports
 from transformers.utils import (
     is_torch_bf16_gpu_available,
     is_torch_cuda_available,
@@ -69,6 +71,9 @@ class AverageMeter:
 
 
 def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
+    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+    else:
@@ -79,7 +84,7 @@ def check_dependencies() -> None:
         require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
 
 
-def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
+def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
     """
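
Note: taken together, the two hunks above make DISABLE_VERSION_CHECK short-circuit every require_version call, logging a warning instead. A minimal sketch of exercising the new branch, assuming check_dependencies is imported from this module:

    import os

    os.environ["DISABLE_VERSION_CHECK"] = "1"  # "true" also works; compared case-insensitively
    check_dependencies()  # only logs a warning; no package versions are enforced
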
@@ -108,7 +113,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
-def get_current_device() -> torch.device:
+def get_current_device() -> "torch.device":
     r"""
     Gets the current available device.
     """
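
Note: the quoted return annotation ("torch.device", like "torch.nn.Module" above) is a plain forward reference, so it is never evaluated at runtime; since torch is still imported at the top of the file, this looks like a consistency change toward the TYPE_CHECKING style rather than a behavioral one. A standalone sketch of that idiom, not code from this file:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import torch  # visible to type checkers only, skipped at runtime

    def device_index(device: "torch.device") -> int:
        # the string annotation is resolved by static checkers, never at runtime
        return device.index if device.index is not None else 0
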
@@ -147,6 +152,13 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor
 
 
+def has_tokenized_data(path: "os.PathLike") -> bool:
+    r"""
+    Checks if the path has a tokenized dataset.
+    """
+    return os.path.isdir(path) and len(os.listdir(path)) > 0
+
+
 def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
     r"""
     Infers the optimal dtype according to the model_dtype and device compatibility.
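
The relocated helper is unchanged in behavior: a path counts as tokenized data exactly when it is an existing, non-empty directory. A quick usage sketch, with a hypothetical cache path:

    tokenized_path = "saves/tokenized_dataset"  # hypothetical cache directory

    if has_tokenized_data(tokenized_path):
        pass  # load the cached dataset here instead of tokenizing again
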
@@ -166,11 +178,9 @@ def is_gpu_or_npu_available() -> bool:
     return is_torch_npu_available() or is_torch_cuda_available()
 
 
-def has_tokenized_data(path: "os.PathLike") -> bool:
-    r"""
-    Checks if the path has a tokenized dataset.
-    """
-    return os.path.isdir(path) and len(os.listdir(path)) > 0
+def skip_check_imports() -> None:
+    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+        transformers.dynamic_module_utils.check_imports = get_relative_imports
 
 
 def torch_gc() -> None:
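
Note: transformers' check_imports scans a trust_remote_code module for its third-party imports and raises ImportError when any are missing; rebinding it to get_relative_imports keeps relative-import resolution but skips that installation check, and setting FORCE_CHECK_IMPORTS restores the stock behavior. A minimal usage sketch, with an illustrative model id:

    from transformers import AutoConfig

    skip_check_imports()  # patch before any remote code is loaded
    config = AutoConfig.from_pretrained("org/remote-code-model", trust_remote_code=True)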