[v1] add accelerator (#9607)

This commit is contained in:
Yaowei Zheng
2025-12-12 19:22:06 +08:00
committed by GitHub
parent 4fd94141a4
commit 203069e11c
36 changed files with 941 additions and 443 deletions

View File

@@ -13,12 +13,13 @@
# limitations under the License.
from typing import Callable, TypedDict
from typing import Any, Literal, TypedDict
from typing_extensions import NotRequired, Required
from typing_extensions import NotRequired
from ....extras import logging
from ...extras.types import DPOSample, Sample, SFTSample
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample
logger = logging.get_logger(__name__)
@@ -26,35 +27,48 @@ logger = logging.get_logger(__name__)
class AlpacaSample(TypedDict, total=False):
system: NotRequired[str]
instruction: NotRequired[str]
instruction: str
input: NotRequired[str]
output: NotRequired[str]
output: str
ShareGPTMessage = TypedDict(
"ShareGPTMessage",
{
"from": Required[str], # Role of the message sender (e.g., "human", "gpt", "system")
"value": Required[str], # Content of the message
},
SharegptMessage = TypedDict(
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
)
class ShareGPTSample(TypedDict, total=False):
"""Type definition for raw ShareGPT sample."""
class SharegptSample(TypedDict, total=False):
conversations: list[SharegptMessage]
tools: NotRequired[str]
conversations: Required[list[ShareGPTMessage]]
class OpenaiMessage(TypedDict, total=False):
role: Literal["user", "assistant", "tool"]
content: str
class OpenaiSample(TypedDict, total=False):
messages: list[OpenaiMessage]
class PairSample(TypedDict, total=False):
prompt: NotRequired[str]
chosen: NotRequired[list[dict]]
rejected: NotRequired[list[dict]]
chosen: list[OpenaiMessage]
rejected: list[OpenaiMessage]
class DataConverterPlugin(BasePlugin):
"""Plugin for data converters."""
def __call__(self, raw_sample: dict[str, Any]) -> Sample:
return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample.
See raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en
Args:
raw_sample (AlpacaSample): Alpaca sample.
@@ -67,20 +81,6 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
{"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
)
if "history" in raw_sample:
for idx, item in enumerate(raw_sample["history"]):
if len(item) != 2:
logger.warning_rank0(
f"Warning: History item at index {idx} has invalid length (expected 2, got {len(item)}). Skipping."
)
continue
old_prompt, old_response = item
messages.append({"role": "user", "content": [{"type": "text", "value": old_prompt}], "loss_weight": 0.0})
messages.append(
{"role": "assistant", "content": [{"type": "text", "value": old_response}], "loss_weight": 1.0}
)
if "instruction" in raw_sample or "input" in raw_sample:
messages.append(
{
@@ -100,149 +100,85 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages}
def sharegpt_converter(raw_sample: ShareGPTSample) -> SFTSample:
"""Converts a raw ShareGPT sample into a formatted SFT (Supervised Fine-Tuning) sample.
@DataConverterPlugin("sharegpt").register
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample.
Retains only SFT-relevant scenarios and removes parity checks.
See raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en
Args:
raw_sample (ShareGPTSample): A raw sample in ShareGPT format.
raw_sample (SharegptSample): ShareGPT sample.
Returns:
dict: A dictionary containing the formatted 'messages' list for SFT training.
Returns an empty list if the input data is invalid.
SFTSample: SFT sample.
"""
tag_mapping = {
"system": "system",
"human": "user",
"gpt": "assistant",
"observation": "observation",
"function_call": "function",
"observation": "tool",
"function_call": "assistant",
}
messages = raw_sample.get("conversations", [])
aligned_messages = []
system_content = ""
messages = []
tools = raw_sample.get("tools", "")
# Extract system message if present (typically the first message)
if messages and messages[0]["from"] == "system":
system_content = messages[0]["value"]
messages = messages[1:]
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call":
messages.append(
{
"role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}],
"loss_weight": 1.0,
}
)
else:
messages.append(
{
"role": tag_mapping[tag],
"content": [{"type": "text", "value": message["value"]}],
"loss_weight": 1.0 if tag == "gpt" else 0.0,
}
)
if system_content:
aligned_messages.append(
{"role": "system", "content": [{"type": "text", "value": system_content}], "loss_weight": 0.0}
)
if tools:
if messages and messages[0]["role"] == "system":
messages[0]["content"].append({"type": "tools", "value": tools})
else:
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
has_invalid_role = False
for message in messages:
sender = message["from"]
# validate sender is in supported tags
if sender not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag '{sender}' in message: {message}")
has_invalid_role = True
break
aligned_messages.append(
{
"role": tag_mapping[sender],
"content": [{"type": "text", "value": message["value"]}],
"loss_weight": 0.0 if sender in ("human", "observation") else 1.0,
}
)
if has_invalid_role:
logger.warning_rank0("Skipping invalid example due to unsupported role tags.")
return {"messages": []}
return {"messages": aligned_messages}
return {"messages": messages}
@DataConverterPlugin("pair").register
def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to standard DPO sample.
"""Convert Pair sample to DPO sample.
See raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
Args:
raw_sample (PairSample): pair sample with prompt, chosen, rejected fields.
see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
raw_sample (PairSample): pair sample with chosen, rejected fields.
Returns:
DPOSample: DPO sample with chosen_messages and rejected_messages.
see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl
"""
chosen_messages = []
assert "chosen" in raw_sample, "chosen field is required in pair sample."
assert "rejected" in raw_sample, "rejected field is required in pair sample."
assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), (
"chosen and rejected field should be a list[dict], or you may need to implement your custom converter."
)
if "chosen" in raw_sample:
value = raw_sample.get("chosen", "")
for item in value:
if item.get("role", "") == "system":
chosen_messages.append(
{
"role": "system",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "user":
chosen_messages.append(
{
"role": "user",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "assistant":
chosen_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 1.0,
}
)
def process_message(raw_messages: list[OpenaiMessage]):
messages = []
for message in raw_messages:
messages.append(
{
"role": message["role"],
"content": [{"type": "text", "value": message["content"]}],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
rejected_messages = []
if "rejected" in raw_sample:
value = raw_sample.get("rejected", "")
for item in value:
if item.get("role", "") == "system":
rejected_messages.append(
{
"role": "system",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "user":
rejected_messages.append(
{
"role": "user",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "assistant":
rejected_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 1.0,
}
)
return messages
chosen_messages = process_message(raw_sample.get("chosen", []))
rejected_messages = process_message(raw_sample.get("rejected", []))
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
CONVERTERS = {
"alpaca": alpaca_converter,
"pair": pair_converter,
"sharegpt": sharegpt_converter,
}
def get_converter(converter_name: str) -> Callable[[dict], Sample]:
if converter_name not in CONVERTERS:
raise ValueError(f"Converter {converter_name} not found.")
return CONVERTERS[converter_name]

View File

@@ -14,57 +14,59 @@
import os
from dataclasses import dataclass
import random
from typing import Any, Literal, Optional, Union
from datasets import load_dataset
from ...extras.types import DatasetInfo, HFDataset
from ...utils.plugin import BasePlugin
from ...utils.types import DatasetInfo, HFDataset
@dataclass
class DataLoaderPlugin:
class DataLoaderPlugin(BasePlugin):
"""Plugin for loading dataset."""
def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
"""Get dataset builder name.
Args:
path (str): Dataset path.
Returns:
Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
"""
return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
dataset_dir = dataset_info.get("dataset_dir", ".")
def load(self, dataset_info: DatasetInfo) -> HFDataset:
path = dataset_info["path"]
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
if "file_name" in dataset_info:
filepath = os.path.join(dataset_dir, dataset_info["file_name"])
return self.load_data_from_file(filepath, split, streaming)
else:
raise NotImplementedError()
def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath):
filetype = self._get_builder_name(os.listdir(filepath)[0])
dataset = load_dataset(filetype, data_dir=filepath, split=split)
elif os.path.isfile(filepath):
filetype = self._get_builder_name(filepath)
dataset = load_dataset(filetype, data_files=filepath, split=split)
else:
raise ValueError(f"Can not load dataset from {filepath}.")
if streaming:
dataset = dataset.to_iterable_dataset()
return dataset
return super().__call__(path, split, streaming)
@dataclass
class DataIndexPlugin:
def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
"""Get dataset builder name.
Args:
path (str): Dataset path.
Returns:
Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
"""
filetype = os.path.splitext(path)[-1][1:]
if filetype in ["arrow", "csv", "json", "jsonl", "parquet", "txt"]:
return filetype.replace("jsonl", "json").replace("txt", "text")
else:
raise ValueError(f"Unknown dataset filetype: {filetype}.")
@DataLoaderPlugin("local").register
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath):
filetype = _get_builder_name(os.listdir(filepath)[0])
dataset = load_dataset(filetype, data_dir=filepath, split=split)
elif os.path.isfile(filepath):
filetype = _get_builder_name(filepath)
dataset = load_dataset(filetype, data_files=filepath, split=split)
else:
raise ValueError(f"Can not load dataset from {filepath}.")
if streaming: # faster when data is streamed from local files
dataset = dataset.to_iterable_dataset()
return dataset
class DataIndexPlugin(BasePlugin):
"""Plugin for adjusting dataset index."""
def adjust_data_index(
@@ -81,39 +83,32 @@ class DataIndexPlugin:
list[tuple[str, int]]: Adjusted dataset index.
"""
if size is not None:
data_index = self.adjust_by_size(data_index, size)
data_index = random.choices(data_index, k=size)
if weight is not None:
data_index = self.adjust_by_weight(data_index, weight)
data_index = random.choices(data_index, k=int(len(data_index) * weight))
return data_index
def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
raise NotImplementedError()
def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
raise NotImplementedError()
@dataclass
class DataSelectorPlugin:
class DataSelectorPlugin(BasePlugin):
"""Plugin for selecting dataset samples."""
data_index: list[tuple[str, int]]
"""List of (dataset_name, sample_index)"""
def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
def select(
self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
) -> Union[tuple[str, int], list[tuple[str, int]]]:
"""Select dataset samples.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
index (Union[slice, list[int], Any]): Index of dataset samples.
Returns:
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
"""
if isinstance(index, slice):
return [self.data_index[i] for i in range(*index.indices(len(self.data_index)))]
return [data_index[i] for i in range(*index.indices(len(data_index)))]
elif isinstance(index, list):
return [self.data_index[i] for i in index]
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")

View File

@@ -21,10 +21,3 @@ class KernelType(str, Enum):
FLASH_ATTENTION = "flash_attention"
ROPE = "rope"
MOE = "moe"
class DeviceType(str, Enum):
CPU = "cpu"
CUDA = "cuda"
NPU = "npu"
XPU = "xpu"

View File

@@ -18,10 +18,10 @@ import torch
import torch.nn.functional as F
import torch_npu
from .....accelerator.helper import is_torch_npu_available
from .....extras.packages import is_transformers_version_greater_than
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.packages import is_transformers_version_greater_than
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaMoEKernel

View File

@@ -17,9 +17,9 @@ import types
import torch
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaSwiGluKernel

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from abc import ABC, ABCMeta, abstractmethod
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
from ....accelerator.helper import get_current_accelerator
from ....extras.types import HFModel
from .constants import DeviceType, KernelType
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
from .constants import KernelType
class KernelRegistry:
@@ -27,11 +27,13 @@ class KernelRegistry:
def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
self._initialized = True
@@ -218,7 +220,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
return discovered_kernels
# Iterate through registry and collect all kernels for current device
for kernel_type, devices in KERNEL_REGISTRY._registry.items():
for devices in KERNEL_REGISTRY._registry.values():
kernel_cls = devices.get(device_type)
if kernel_cls is not None:
discovered_kernels.append(kernel_cls)
@@ -226,7 +228,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
return discovered_kernels
def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
"""Call the MetaKernel's `apply` to perform the replacement.
Corresponding replacement logic is maintained inside each kernel; the only
@@ -238,16 +240,18 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
model = apply_kernel(model, NpuRMSNormKernel)
"""
if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
return kernel.apply(model, **kwargs)
if not issubclass(kernel, MetaKernel):
raise ValueError(f"{kernel} must be a MetaKernel instance.")
raise ValueError(
f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
)
if kernel.device != get_current_accelerator().type:
raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
return kernel.apply(model, **kwargs)
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
"""Apply all available kernels to the model."""
for kernel in discover_kernels(model):
model = apply_kernel(model, kernel, **kwargs)
return model

View File

@@ -14,9 +14,9 @@
import re
import types
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRMSNormKernel

View File

@@ -16,9 +16,9 @@ import sys
import torch
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRoPEKernel

View File

@@ -0,0 +1,42 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, TypedDict
from peft import LoraConfig, PeftModel, get_peft_model
from ...utils.plugin import BasePlugin
from ...utils.types import HFModel
class LoraConfigDict(TypedDict, total=False):
name: Literal["lora"]
"""Plugin name."""
r: int
"""Lora rank."""
lora_alpha: int
"""Lora alpha."""
target_modules: list[str]
"""Target modules."""
class PeftPlugin(BasePlugin):
pass
@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model

View File

@@ -1,13 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.