Former-commit-id: 263b2b24c8a649b51fa5ae768a24e67def8e0e96
This commit is contained in:
hiyouga
2023-11-19 14:15:47 +08:00
parent 3d1ee27ccd
commit 6889f044fb
8 changed files with 35 additions and 31 deletions

View File

@@ -66,8 +66,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def get_current_device() -> str:
import accelerate
from accelerate import Accelerator
dummy_accelerator = Accelerator()
dummy_accelerator = accelerate.Accelerator()
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(dummy_accelerator.local_process_index)
else: