mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-03 08:53:38 +00:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -25,12 +25,12 @@ class InitPlugin(BasePlugin):
|
||||
return super().__call__()
|
||||
|
||||
|
||||
@InitPlugin("init_on_meta").register
|
||||
@InitPlugin("init_on_meta").register()
|
||||
def init_on_meta() -> torch.device:
|
||||
return torch.device(DeviceType.META.value)
|
||||
|
||||
|
||||
@InitPlugin("init_on_rank0").register
|
||||
@InitPlugin("init_on_rank0").register()
|
||||
def init_on_rank0() -> torch.device:
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
return torch.device(DeviceType.CPU.value)
|
||||
@@ -38,6 +38,6 @@ def init_on_rank0() -> torch.device:
|
||||
return torch.device(DeviceType.META.value)
|
||||
|
||||
|
||||
@InitPlugin("init_on_default").register
|
||||
@InitPlugin("init_on_default").register()
|
||||
def init_on_default() -> torch.device:
|
||||
return DistributedInterface().current_accelerator
|
||||
return DistributedInterface().current_device
|
||||
|
||||
Reference in New Issue
Block a user