Former-commit-id: 243a596518ad69cf1eec20a082534b9e94353ce4
This commit is contained in:
hiyouga
2023-12-03 11:33:12 +08:00
parent b052574ddf
commit 4a14099cfd
2 changed files with 4 additions and 15 deletions

View File

@@ -68,18 +68,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def get_current_device() -> str:
import accelerate
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif accelerate.utils.is_npu_available():
return "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif torch.cuda.is_available():
return "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
else:
return "cpu"
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.