support batch_eval_metrics, fix #4826
Former-commit-id: 3fe1df17188825f8a32fbe6a1294b4b532ce0c85
@@ -17,7 +17,7 @@
 import gc
 import os
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Tuple, Union
 
 import torch
 import transformers.dynamic_module_utils
@@ -43,6 +43,8 @@ except Exception:
 
 
 if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
     from ..hparams import ModelArguments
 
 
@@ -178,6 +180,17 @@ def is_gpu_or_npu_available() -> bool:
     return is_torch_npu_available() or is_torch_cuda_available()
 
 
+def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
+    if isinstance(inputs, torch.Tensor):
+        inputs = inputs.cpu()
+        if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4
+            inputs = inputs.to(torch.float32)
+
+        inputs = inputs.numpy()
+
+    return inputs
+
+
 def skip_check_imports() -> None:
     if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
         transformers.dynamic_module_utils.check_imports = get_relative_imports
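For context, a minimal usage sketch of the new helper: `numpify` lets a batched metric hook accept either numpy arrays or torch tensors (including bfloat16 tensors, which cannot be converted to numpy directly) and always get a plain array back. The import path and the accuracy metric below are assumptions for illustration, not part of this commit.

from typing import Union

import numpy as np
import torch

# Assumed import path -- adjust to wherever numpify lives in your checkout.
from llamafactory.extras.misc import numpify


def batch_accuracy(
    logits: Union["np.ndarray", "torch.Tensor"], labels: Union["np.ndarray", "torch.Tensor"]
) -> float:
    # numpify moves tensors to CPU, casts bfloat16 to float32, and converts to numpy,
    # so the same metric code works for arrays and for raw model outputs.
    preds = numpify(logits).argmax(-1)
    return float(np.mean(preds == numpify(labels)))


if __name__ == "__main__":
    logits = torch.randn(4, 10, dtype=torch.bfloat16)  # .numpy() would fail on bf16 without the cast
    labels = torch.randint(0, 10, (4,))
    print(batch_accuracy(logits, labels))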