Former-commit-id: 3f9192dbbbafdc2171d2eb80282d5cae47565b7b
This commit is contained in:
hiyouga
2023-12-03 22:35:47 +08:00
parent 5fe3cce5a3
commit fb4c5f3c91
2 changed files with 20 additions and 5 deletions

View File

@@ -68,6 +68,20 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def get_current_device() -> torch.device:
import accelerate
if accelerate.utils.is_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif accelerate.utils.is_npu_available():
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif torch.cuda.is_available():
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
else:
device = "cpu"
return torch.device(device)
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.