@@ -66,8 +66,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
|
||||
def get_current_device() -> str:
|
||||
import accelerate
|
||||
from accelerate import Accelerator
|
||||
dummy_accelerator = Accelerator()
|
||||
dummy_accelerator = accelerate.Accelerator()
|
||||
if accelerate.utils.is_xpu_available():
|
||||
return "xpu:{}".format(dummy_accelerator.local_process_index)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user