mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 08:33:38 +00:00
[model] clean obsolete models (#9736)
This commit is contained in:
@@ -35,7 +35,7 @@ from torch.distributed import barrier, destroy_process_group, init_process_group
|
||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||
|
||||
from ..utils import logging
|
||||
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
|
||||
from ..utils.types import DistributedConfig, ProcessGroup, TensorLike
|
||||
from . import helper
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ class DistributedInterface:
|
||||
"""Get parallel local world size."""
|
||||
return self._local_world_size
|
||||
|
||||
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
|
||||
def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
|
||||
"""Gather tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
|
||||
|
||||
Reference in New Issue
Block a user