support batch_eval_metrics, fix #4826

Former-commit-id: 3fe1df17188825f8a32fbe6a1294b4b532ce0c85
This commit is contained in:
hiyouga
2024-07-17 00:33:00 +08:00
parent 45367105fc
commit 8c93921952
7 changed files with 85 additions and 36 deletions

View File

@@ -17,7 +17,7 @@
import gc
import os
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Tuple, Union
import torch
import transformers.dynamic_module_utils
@@ -43,6 +43,8 @@ except Exception:
if TYPE_CHECKING:
from numpy.typing import NDArray
from ..hparams import ModelArguments
@@ -178,6 +180,17 @@ def is_gpu_or_npu_available() -> bool:
return is_torch_npu_available() or is_torch_cuda_available()
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
inputs = inputs.to(torch.float32)
inputs = inputs.numpy()
return inputs
def skip_check_imports() -> None:
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports