fix mixed mm inputs and rlhf-v
Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
This commit is contained in:
@@ -195,6 +195,9 @@ def is_gpu_or_npu_available() -> bool:
|
||||
|
||||
|
||||
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
||||
r"""
|
||||
Casts a torch tensor or a numpy array to a numpy array.
|
||||
"""
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = inputs.cpu()
|
||||
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
|
||||
@@ -206,6 +209,9 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
||||
|
||||
|
||||
def skip_check_imports() -> None:
|
||||
r"""
|
||||
Avoids flash attention import error in custom model files.
|
||||
"""
|
||||
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
||||
transformers.dynamic_module_utils.check_imports = get_relative_imports
|
||||
|
||||
|
||||
Reference in New Issue
Block a user