[v1] add accelerator (#9607)
@@ -13,12 +13,13 @@
 # limitations under the License.


-from typing import Callable, TypedDict
+from typing import Any, Literal, TypedDict

-from typing_extensions import NotRequired, Required
+from typing_extensions import NotRequired

-from ....extras import logging
-from ...extras.types import DPOSample, Sample, SFTSample
+from ...utils import logging
+from ...utils.plugin import BasePlugin
+from ...utils.types import DPOSample, Sample, SFTSample


 logger = logging.get_logger(__name__)
@@ -26,35 +27,48 @@ logger = logging.get_logger(__name__)


 class AlpacaSample(TypedDict, total=False):
     system: NotRequired[str]
-    instruction: NotRequired[str]
+    instruction: str
     input: NotRequired[str]
-    output: NotRequired[str]
+    output: str


-ShareGPTMessage = TypedDict(
-    "ShareGPTMessage",
-    {
-        "from": Required[str],  # Role of the message sender (e.g., "human", "gpt", "system")
-        "value": Required[str],  # Content of the message
-    },
+SharegptMessage = TypedDict(
+    "SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
 )


-class ShareGPTSample(TypedDict, total=False):
-    """Type definition for raw ShareGPT sample."""
+class SharegptSample(TypedDict, total=False):
+    conversations: list[SharegptMessage]
+    tools: NotRequired[str]

-    conversations: Required[list[ShareGPTMessage]]

+class OpenaiMessage(TypedDict, total=False):
+    role: Literal["user", "assistant", "tool"]
+    content: str
+
+
+class OpenaiSample(TypedDict, total=False):
+    messages: list[OpenaiMessage]
+
+
 class PairSample(TypedDict, total=False):
     prompt: NotRequired[str]
-    chosen: NotRequired[list[dict]]
-    rejected: NotRequired[list[dict]]
+    chosen: list[OpenaiMessage]
+    rejected: list[OpenaiMessage]


+class DataConverterPlugin(BasePlugin):
+    """Plugin for data converters."""
+
+    def __call__(self, raw_sample: dict[str, Any]) -> Sample:
+        return super().__call__(raw_sample)
+
+
+@DataConverterPlugin("alpaca").register
 def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
     """Convert Alpaca sample to SFT sample.

     See raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en

     Args:
         raw_sample (AlpacaSample): Alpaca sample.

@@ -67,20 +81,6 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
             {"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
         )

-    if "history" in raw_sample:
-        for idx, item in enumerate(raw_sample["history"]):
-            if len(item) != 2:
-                logger.warning_rank0(
-                    f"Warning: History item at index {idx} has invalid length (expected 2, got {len(item)}). Skipping."
-                )
-                continue
-
-            old_prompt, old_response = item
-            messages.append({"role": "user", "content": [{"type": "text", "value": old_prompt}], "loss_weight": 0.0})
-            messages.append(
-                {"role": "assistant", "content": [{"type": "text", "value": old_response}], "loss_weight": 1.0}
-            )
-
     if "instruction" in raw_sample or "input" in raw_sample:
         messages.append(
             {
@@ -100,149 +100,85 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
     return {"messages": messages}


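Reviewer note (sketch, not part of the diff): based only on the fields shown above, the registered alpaca converter is expected to behave roughly as follows; the record content is made up for illustration.

    # Hypothetical Alpaca record using the AlpacaSample fields above.
    raw = {
        "system": "You are a helpful assistant.",
        "instruction": "Summarize the text.",
        "input": "LLaMA Factory unifies fine-tuning recipes.",
        "output": "It is a unified fine-tuning toolkit.",
    }
    sft_sample = alpaca_converter(raw)
    # Expected shape (assumption): a system message with loss_weight 0.0, a user turn built
    # from instruction + input with loss_weight 0.0, and a final assistant turn with loss_weight 1.0.
    assert sft_sample["messages"][-1]["role"] == "assistant"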
-def sharegpt_converter(raw_sample: ShareGPTSample) -> SFTSample:
-    """Converts a raw ShareGPT sample into a formatted SFT (Supervised Fine-Tuning) sample.
+@DataConverterPlugin("sharegpt").register
+def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
+    """Convert ShareGPT sample to SFT sample.

-    Retains only SFT-relevant scenarios and removes parity checks.
+    See raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en

     Args:
-        raw_sample (ShareGPTSample): A raw sample in ShareGPT format.
+        raw_sample (SharegptSample): ShareGPT sample.

     Returns:
-        dict: A dictionary containing the formatted 'messages' list for SFT training.
-              Returns an empty list if the input data is invalid.
+        SFTSample: SFT sample.
     """
     tag_mapping = {
         "system": "system",
         "human": "user",
         "gpt": "assistant",
-        "observation": "observation",
-        "function_call": "function",
+        "observation": "tool",
+        "function_call": "assistant",
     }
-    messages = raw_sample.get("conversations", [])
-    aligned_messages = []
-    system_content = ""
+    messages = []
+    tools = raw_sample.get("tools", "")

-    # Extract system message if present (typically the first message)
-    if messages and messages[0]["from"] == "system":
-        system_content = messages[0]["value"]
-        messages = messages[1:]
+    for message in raw_sample.get("conversations", []):
+        tag = message["from"]
+        if tag not in tag_mapping:
+            logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
+        elif tag == "function_call":
+            messages.append(
+                {
+                    "role": "assistant",
+                    "content": [{"type": "tool_calls", "value": message["value"]}],
+                    "loss_weight": 1.0,
+                }
+            )
+        else:
+            messages.append(
+                {
+                    "role": tag_mapping[tag],
+                    "content": [{"type": "text", "value": message["value"]}],
+                    "loss_weight": 1.0 if tag == "gpt" else 0.0,
+                }
+            )

-    if system_content:
-        aligned_messages.append(
-            {"role": "system", "content": [{"type": "text", "value": system_content}], "loss_weight": 0.0}
-        )
+    if tools:
+        if messages and messages[0]["role"] == "system":
+            messages[0]["content"].append({"type": "tools", "value": tools})
+        else:
+            messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})

-    has_invalid_role = False
-    for message in messages:
-        sender = message["from"]
-        # validate sender is in supported tags
-        if sender not in tag_mapping:
-            logger.warning_rank0(f"Unsupported role tag '{sender}' in message: {message}")
-            has_invalid_role = True
-            break
-
-        aligned_messages.append(
-            {
-                "role": tag_mapping[sender],
-                "content": [{"type": "text", "value": message["value"]}],
-                "loss_weight": 0.0 if sender in ("human", "observation") else 1.0,
-            }
-        )
-
-    if has_invalid_role:
-        logger.warning_rank0("Skipping invalid example due to unsupported role tags.")
-        return {"messages": []}
-
-    return {"messages": aligned_messages}
+    return {"messages": messages}


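Reviewer note (sketch, not part of the diff): the rewritten sharegpt converter folds the old system-message extraction and validity checks into a single pass; a rough illustration of the new mapping, with a made-up record:

    # Hypothetical record following the SharegptSample fields above.
    raw = {
        "conversations": [
            {"from": "human", "value": "What is 2 + 2?"},
            {"from": "function_call", "value": '{"name": "calculator", "arguments": {"expr": "2 + 2"}}'},
            {"from": "observation", "value": "4"},
            {"from": "gpt", "value": "2 + 2 equals 4."},
        ],
        "tools": '[{"name": "calculator"}]',
    }
    sft_sample = sharegpt_converter(raw)
    # Per the new tag_mapping: function_call becomes an assistant turn with "tool_calls" content,
    # observation becomes role "tool", and the tools string is prepended as a system message.
    roles = [m["role"] for m in sft_sample["messages"]]
    # Expected (assumption): ["system", "user", "assistant", "tool", "assistant"]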
+@DataConverterPlugin("pair").register
 def pair_converter(raw_sample: PairSample) -> DPOSample:
-    """Convert Pair sample to standard DPO sample.
+    """Convert Pair sample to DPO sample.

+    See raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs

     Args:
-        raw_sample (PairSample): pair sample with prompt, chosen, rejected fields.
-            see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
+        raw_sample (PairSample): pair sample with chosen, rejected fields.

     Returns:
         DPOSample: DPO sample with chosen_messages and rejected_messages.
-            see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl
     """
-    chosen_messages = []
-    assert "chosen" in raw_sample, "chosen field is required in pair sample."
-    assert "rejected" in raw_sample, "rejected field is required in pair sample."
-    assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), (
-        "chosen and rejected field should be a list[dict], or you may need to implement your custom converter."
-    )
-
-    if "chosen" in raw_sample:
-        value = raw_sample.get("chosen", "")
-        for item in value:
-            if item.get("role", "") == "system":
-                chosen_messages.append(
-                    {
-                        "role": "system",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 0.0,
-                    }
-                )
-            if item.get("role", "") == "user":
-                chosen_messages.append(
-                    {
-                        "role": "user",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 0.0,
-                    }
-                )
-            if item.get("role", "") == "assistant":
-                chosen_messages.append(
-                    {
-                        "role": "assistant",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 1.0,
-                    }
-                )
+    def process_message(raw_messages: list[OpenaiMessage]):
+        messages = []
+        for message in raw_messages:
+            messages.append(
+                {
+                    "role": message["role"],
+                    "content": [{"type": "text", "value": message["content"]}],
+                    "loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
+                }
+            )

-    rejected_messages = []
-    if "rejected" in raw_sample:
-        value = raw_sample.get("rejected", "")
-        for item in value:
-            if item.get("role", "") == "system":
-                rejected_messages.append(
-                    {
-                        "role": "system",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 0.0,
-                    }
-                )
-            if item.get("role", "") == "user":
-                rejected_messages.append(
-                    {
-                        "role": "user",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 0.0,
-                    }
-                )
-            if item.get("role", "") == "assistant":
-                rejected_messages.append(
-                    {
-                        "role": "assistant",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 1.0,
-                    }
-                )
+        return messages
+
+    chosen_messages = process_message(raw_sample.get("chosen", []))
+    rejected_messages = process_message(raw_sample.get("rejected", []))

     return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}


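Reviewer note (sketch, not part of the diff): the pair converter drops the asserts and the role-by-role branches in favor of one process_message helper; illustrative call with a made-up record:

    # Hypothetical pair record with OpenAI-style chosen/rejected message lists.
    raw = {
        "chosen": [
            {"role": "user", "content": "Name a prime number."},
            {"role": "assistant", "content": "7"},
        ],
        "rejected": [
            {"role": "user", "content": "Name a prime number."},
            {"role": "assistant", "content": "9"},
        ],
    }
    dpo_sample = pair_converter(raw)
    # Both sides keep the whole conversation; only assistant turns carry loss_weight 1.0.
    assert dpo_sample["chosen_messages"][-1]["loss_weight"] == 1.0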
-CONVERTERS = {
-    "alpaca": alpaca_converter,
-    "pair": pair_converter,
-    "sharegpt": sharegpt_converter,
-}
-
-
-def get_converter(converter_name: str) -> Callable[[dict], Sample]:
-    if converter_name not in CONVERTERS:
-        raise ValueError(f"Converter {converter_name} not found.")
-
-    return CONVERTERS[converter_name]
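Reviewer note (sketch, not part of the diff): the module-level CONVERTERS dict and get_converter lookup are replaced by the plugin registry. The resolution API lives in BasePlugin, which this diff does not show, so the call below is only an assumption about how a registered converter would be looked up and invoked.

    # Assumed usage; BasePlugin's dispatch semantics are not part of this diff.
    converter = DataConverterPlugin("sharegpt")  # plugin instance keyed by converter name
    sft_sample = converter(raw_sample)           # __call__ forwards to the registered function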
@@ -14,57 +14,59 @@


 import os
-from dataclasses import dataclass
+import random
 from typing import Any, Literal, Optional, Union

 from datasets import load_dataset

-from ...extras.types import DatasetInfo, HFDataset
+from ...utils.plugin import BasePlugin
+from ...utils.types import DatasetInfo, HFDataset


-@dataclass
-class DataLoaderPlugin:
+class DataLoaderPlugin(BasePlugin):
     """Plugin for loading dataset."""

-    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
-        """Get dataset builder name.
-
-        Args:
-            path (str): Dataset path.
-
-        Returns:
-            Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
-        """
-        return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
-
-    def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
-        dataset_dir = dataset_info.get("dataset_dir", ".")
+    def load(self, dataset_info: DatasetInfo) -> HFDataset:
+        path = dataset_info["path"]
         split = dataset_info.get("split", "train")
         streaming = dataset_info.get("streaming", False)
-        if "file_name" in dataset_info:
-            filepath = os.path.join(dataset_dir, dataset_info["file_name"])
-            return self.load_data_from_file(filepath, split, streaming)
-        else:
-            raise NotImplementedError()
-
-    def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
-        if os.path.isdir(filepath):
-            filetype = self._get_builder_name(os.listdir(filepath)[0])
-            dataset = load_dataset(filetype, data_dir=filepath, split=split)
-        elif os.path.isfile(filepath):
-            filetype = self._get_builder_name(filepath)
-            dataset = load_dataset(filetype, data_files=filepath, split=split)
-        else:
-            raise ValueError(f"Can not load dataset from {filepath}.")
-
-        if streaming:
-            dataset = dataset.to_iterable_dataset()
-
-        return dataset
+        return super().__call__(path, split, streaming)


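Reviewer note (sketch, not part of the diff): DataLoaderPlugin.load now reads only path/split/streaming from DatasetInfo and defers to the plugin dispatch in super().__call__(); the dataset_dir/file_name branch is gone. The snippet below is an assumption about the intended call path, with the "local" loader registered further down handling plain files; the file name is made up.

    dataset_info = {"path": "data/alpaca_demo.jsonl", "split": "train", "streaming": False}
    dataset = DataLoaderPlugin("local").load(dataset_info)  # assumed to dispatch to load_data_from_file below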
-@dataclass
-class DataIndexPlugin:
+def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
+    """Get dataset builder name.
+
+    Args:
+        path (str): Dataset path.
+
+    Returns:
+        Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
+    """
+    filetype = os.path.splitext(path)[-1][1:]
+    if filetype in ["arrow", "csv", "json", "jsonl", "parquet", "txt"]:
+        return filetype.replace("jsonl", "json").replace("txt", "text")
+    else:
+        raise ValueError(f"Unknown dataset filetype: {filetype}.")
+
+
+@DataLoaderPlugin("local").register
+def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
+    if os.path.isdir(filepath):
+        filetype = _get_builder_name(os.listdir(filepath)[0])
+        dataset = load_dataset(filetype, data_dir=filepath, split=split)
+    elif os.path.isfile(filepath):
+        filetype = _get_builder_name(filepath)
+        dataset = load_dataset(filetype, data_files=filepath, split=split)
+    else:
+        raise ValueError(f"Can not load dataset from {filepath}.")
+
+    if streaming:  # faster when data is streamed from local files
+        dataset = dataset.to_iterable_dataset()
+
+    return dataset
+
+
+class DataIndexPlugin(BasePlugin):
     """Plugin for adjusting dataset index."""

     def adjust_data_index(
@@ -81,39 +83,32 @@ class DataIndexPlugin:
             list[tuple[str, int]]: Adjusted dataset index.
         """
         if size is not None:
-            data_index = self.adjust_by_size(data_index, size)
+            data_index = random.choices(data_index, k=size)

         if weight is not None:
-            data_index = self.adjust_by_weight(data_index, weight)
+            data_index = random.choices(data_index, k=int(len(data_index) * weight))

         return data_index

-    def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
-        raise NotImplementedError()
-
-    def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
-        raise NotImplementedError()
-

-@dataclass
-class DataSelectorPlugin:
+class DataSelectorPlugin(BasePlugin):
     """Plugin for selecting dataset samples."""

-    data_index: list[tuple[str, int]]
-    """List of (dataset_name, sample_index)"""
-
-    def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
+    def select(
+        self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
+    ) -> Union[tuple[str, int], list[tuple[str, int]]]:
         """Select dataset samples.

         Args:
+            data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
             index (Union[slice, list[int], Any]): Index of dataset samples.

         Returns:
             Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
         """
         if isinstance(index, slice):
-            return [self.data_index[i] for i in range(*index.indices(len(self.data_index)))]
+            return [data_index[i] for i in range(*index.indices(len(data_index)))]
         elif isinstance(index, list):
-            return [self.data_index[i] for i in index]
+            return [data_index[i] for i in index]
         else:
             raise ValueError(f"Invalid index type {type(index)}.")
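Reviewer note (sketch, not part of the diff): adjust_data_index now resamples with replacement via random.choices (so size may exceed the dataset length and weight acts as a resampling ratio), and select is stateless, taking the index list as an argument. A quick illustration of the slice path, assuming the plugin can be instantiated directly by name:

    data_index = [("alpaca", 0), ("alpaca", 1), ("sharegpt", 0), ("sharegpt", 1)]
    selector = DataSelectorPlugin("default")  # hypothetical plugin name
    picked = selector.select(data_index, slice(1, 3))
    # picked == [("alpaca", 1), ("sharegpt", 0)]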
@@ -21,10 +21,3 @@ class KernelType(str, Enum):
     FLASH_ATTENTION = "flash_attention"
     ROPE = "rope"
     MOE = "moe"
-
-
-class DeviceType(str, Enum):
-    CPU = "cpu"
-    CUDA = "cuda"
-    NPU = "npu"
-    XPU = "xpu"
@@ -18,10 +18,10 @@ import torch
 import torch.nn.functional as F
 import torch_npu

-from .....accelerator.helper import is_torch_npu_available
-from .....extras.packages import is_transformers_version_greater_than
-from .....extras.types import HFModel
-from ..constants import DeviceType, KernelType
+from .....accelerator.helper import DeviceType, is_torch_npu_available
+from .....utils.packages import is_transformers_version_greater_than
+from .....utils.types import HFModel
+from ..constants import KernelType
 from ..registry import MetaMoEKernel


@@ -17,9 +17,9 @@ import types

 import torch

-from .....accelerator.helper import is_torch_npu_available
-from .....extras.types import HFModel
-from ..constants import DeviceType, KernelType
+from .....accelerator.helper import DeviceType, is_torch_npu_available
+from .....utils.types import HFModel
+from ..constants import KernelType
 from ..registry import MetaSwiGluKernel


@@ -13,11 +13,11 @@
 # limitations under the License.

 from abc import ABC, ABCMeta, abstractmethod
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union

-from ....accelerator.helper import get_current_accelerator
-from ....extras.types import HFModel
-from .constants import DeviceType, KernelType
+from ....accelerator.helper import DeviceType, get_current_accelerator
+from ....utils.types import HFModel
+from .constants import KernelType


 class KernelRegistry:
@@ -27,11 +27,13 @@ class KernelRegistry:
     def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
         if cls._instance is None:
             cls._instance = super().__new__(cls)

         return cls._instance

     def __init__(self) -> None:
+        if self._initialized:
+            return

         self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
         self._initialized = True

@@ -218,7 +220,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
         return discovered_kernels

     # Iterate through registry and collect all kernels for current device
-    for kernel_type, devices in KERNEL_REGISTRY._registry.items():
+    for devices in KERNEL_REGISTRY._registry.values():
         kernel_cls = devices.get(device_type)
         if kernel_cls is not None:
             discovered_kernels.append(kernel_cls)
@@ -226,7 +228,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
     return discovered_kernels


-def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
+def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
     """Call the MetaKernel's `apply` to perform the replacement.

     Corresponding replacement logic is maintained inside each kernel; the only
@@ -238,16 +240,18 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
         model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
         model = apply_kernel(model, NpuRMSNormKernel)
     """
-    if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
-        return kernel.apply(model, **kwargs)
+    if not issubclass(kernel, MetaKernel):
+        raise ValueError(f"{kernel} must be a MetaKernel instance.")

-    raise ValueError(
-        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
-    )
+    if kernel.device != get_current_accelerator().type:
+        raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
+
+    return kernel.apply(model, **kwargs)


 def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
     """Apply all available kernels to the model."""
     for kernel in discover_kernels(model):
         model = apply_kernel(model, kernel, **kwargs)

     return model

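Reviewer note (sketch, not part of the diff): apply_kernel now raises two distinct errors (wrong type vs. wrong device) instead of one combined message, and apply_available_kernels only passes it kernels that discover_kernels already filtered for the current accelerator. Rough shape of the intended flow; the model name is purely illustrative, reusing the import from the docstring example.

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
    model = apply_available_kernels(model)  # discover_kernels() then apply_kernel() per matching kernel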
@@ -14,9 +14,9 @@
 import re
 import types

-from .....accelerator.helper import is_torch_npu_available
-from .....extras.types import HFModel
-from ..constants import DeviceType, KernelType
+from .....accelerator.helper import DeviceType, is_torch_npu_available
+from .....utils.types import HFModel
+from ..constants import KernelType
 from ..registry import MetaRMSNormKernel


@@ -16,9 +16,9 @@ import sys

 import torch

-from .....accelerator.helper import is_torch_npu_available
-from .....extras.types import HFModel
-from ..constants import DeviceType, KernelType
+from .....accelerator.helper import DeviceType, is_torch_npu_available
+from .....utils.types import HFModel
+from ..constants import KernelType
 from ..registry import MetaRoPEKernel


@@ -0,0 +1,42 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal, TypedDict
+
+from peft import LoraConfig, PeftModel, get_peft_model
+
+from ...utils.plugin import BasePlugin
+from ...utils.types import HFModel
+
+
+class LoraConfigDict(TypedDict, total=False):
+    name: Literal["lora"]
+    """Plugin name."""
+    r: int
+    """Lora rank."""
+    lora_alpha: int
+    """Lora alpha."""
+    target_modules: list[str]
+    """Target modules."""
+
+
+class PeftPlugin(BasePlugin):
+    pass
+
+
+@PeftPlugin("lora").register
+def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
+    peft_config = LoraConfig(**config)
+    model = get_peft_model(model, peft_config)
+    return model
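Reviewer note (sketch, not part of the diff): a rough picture of how the registered LoRA helper might be used. How BasePlugin handles the "name" key before the config reaches LoraConfig is not visible here, so the snippet calls get_lora_model directly and assumes the remaining keys map straight onto peft.LoraConfig; the model choice is illustrative.

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
    lora_cfg = {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
    peft_model = get_lora_model(model, lora_cfg)  # wraps the model via peft.get_peft_model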
@@ -1,13 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.