[model] clean obsolete models (#9736)

This commit is contained in:
Yaowei Zheng
2026-01-09 14:08:18 +08:00
committed by hiyouga
parent 5fb5d7ebd3
commit 5cccaeec82
6 changed files with 17 additions and 795 deletions

View File

@@ -35,7 +35,7 @@ from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils import logging
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from ..utils.types import DistributedConfig, ProcessGroup, TensorLike
from . import helper
@@ -214,7 +214,7 @@ class DistributedInterface:
"""Get parallel local world size."""
return self._local_world_size
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
"""Gather tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))