[v1] add models & accelerator (#9579)

This commit is contained in:
Yaowei Zheng
2025-12-08 02:30:25 +08:00
committed by GitHub
parent 739954910a
commit 5744f1ea94
27 changed files with 335 additions and 105 deletions

View File

@@ -19,7 +19,6 @@ from typing import Any, Literal, Optional, Union
from datasets import load_dataset
from ...config.data_args import DataArguments
from ...extras.types import DatasetInfo, HFDataset
@@ -27,9 +26,6 @@ from ...extras.types import DatasetInfo, HFDataset
class DataLoaderPlugin:
"""Plugin for loading dataset."""
args: DataArguments
"""Data arguments."""
def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
"""Get dataset builder name.
@@ -42,7 +38,7 @@ class DataLoaderPlugin:
return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir)
dataset_dir = dataset_info.get("dataset_dir", ".")
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
if "file_name" in dataset_info:

View File

@@ -18,9 +18,9 @@ import torch
import torch.nn.functional as F
import torch_npu
from .....accelerator.helper import is_torch_npu_available
from .....extras.packages import is_transformers_version_greater_than
from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import MetaMoEKernel

View File

@@ -17,8 +17,8 @@ import types
import torch
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import MetaSwiGluKernel

View File

@@ -15,8 +15,8 @@
from abc import ABC, ABCMeta, abstractmethod
from typing import Any, Callable, Optional
from ....accelerator.helper import get_current_accelerator
from ....extras.types import HFModel
from ...trainer_plugins.distributed.accelerate import get_available_accelerator
from .constants import DeviceType, KernelType
@@ -206,7 +206,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
discovered_kernels: list[type[MetaKernel]] = []
# Detect current device type
accelerator = get_available_accelerator()
accelerator = get_current_accelerator()
try:
device_type = DeviceType(accelerator.type)
except ValueError:
@@ -238,11 +238,11 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
model = apply_kernel(model, NpuRMSNormKernel)
"""
if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
return kernel.apply(model, **kwargs)
raise ValueError(
f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
)

View File

@@ -14,8 +14,8 @@
import re
import types
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import MetaRMSNormKernel

View File

@@ -16,8 +16,8 @@ import sys
import torch
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import MetaRoPEKernel

View File

@@ -11,37 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import lru_cache
import torch
def get_available_accelerator():
"""Get available accelerator in current environment.
Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
"""
accelerator = torch.accelerator.current_accelerator()
if accelerator is None:
return torch.device("cpu")
return accelerator
@lru_cache
def is_torch_npu_available():
    """Check whether the active accelerator is an NPU device."""
    device_type = get_available_accelerator().type
    return device_type == "npu"
@lru_cache
def is_torch_cuda_available():
    """Check whether the active accelerator is a CUDA device."""
    device_type = get_available_accelerator().type
    return device_type == "cuda"
@lru_cache
def is_torch_xpu_available():
    """Check whether the active accelerator is an XPU device."""
    device_type = get_available_accelerator().type
    return device_type == "xpu"
@lru_cache
def is_torch_mps_available():
    """Check whether the active accelerator is an MPS device."""
    device_type = get_available_accelerator().type
    return device_type == "mps"