Mirror of https://github.com/hiyouga/LlamaFactory.git
Synced 2026-02-02 20:43:38 +00:00
[v1] add models & accelerator (#9579)
@@ -19,7 +19,6 @@ from typing import Any, Literal, Optional, Union
 
 from datasets import load_dataset
 
-from ...config.data_args import DataArguments
 from ...extras.types import DatasetInfo, HFDataset
 
 
@@ -27,9 +26,6 @@ from ...extras.types import DatasetInfo, HFDataset
 
 class DataLoaderPlugin:
     """Plugin for loading dataset."""
 
-    args: DataArguments
-    """Data arguments."""
-
     def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
         """Get dataset builder name.
@@ -42,7 +38,7 @@ class DataLoaderPlugin:
         return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
 
     def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
-        dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir)
+        dataset_dir = dataset_info.get("dataset_dir", ".")
         split = dataset_info.get("split", "train")
         streaming = dataset_info.get("streaming", False)
         if "file_name" in dataset_info:
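The builder name computed by _get_builder_name above is exactly what datasets.load_dataset takes as its first argument for local files. A minimal standalone sketch of that mapping, with a hypothetical load_dataset call (the real plugin also resolves dataset_dir, split, and streaming as shown in the hunk):

    import os

    from datasets import load_dataset


    def get_builder_name(path: str) -> str:
        # "a.jsonl" -> "json", "a.txt" -> "text"; csv/parquet/arrow pass through unchanged.
        return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")


    # Hypothetical usage mirroring the defaults after this commit:
    # dataset = load_dataset(get_builder_name("train.jsonl"), data_files="train.jsonl", split="train")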
@@ -18,9 +18,9 @@ import torch
 import torch.nn.functional as F
 import torch_npu
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.packages import is_transformers_version_greater_than
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaMoEKernel
 
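This and the three kernel hunks below swap is_torch_npu_available over to the new accelerator.helper module; the old trainer_plugins helper is deleted at the end of this diff. A hedged sketch of the guard these modules rely on, assuming the new helper keeps the accelerator.type semantics of the removed code:

    import torch


    def is_torch_npu_available() -> bool:
        # Assumed semantics, modeled on the deleted helper at the end of this diff;
        # torch.accelerator requires torch>=2.7.0.
        accelerator = torch.accelerator.current_accelerator()
        return accelerator is not None and accelerator.type == "npu"


    if is_torch_npu_available():
        import torch_npu  # noqa: F401  # makes Ascend fused ops importable for these kernels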
@@ -17,8 +17,8 @@ import types
 
 import torch
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaSwiGluKernel
 
@@ -15,8 +15,8 @@
 from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional
 
+from ....accelerator.helper import get_current_accelerator
 from ....extras.types import HFModel
-from ...trainer_plugins.distributed.accelerate import get_available_accelerator
 from .constants import DeviceType, KernelType
 
 
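The registry imports ABC, ABCMeta, and abstractmethod, so the Meta*Kernel classes referenced in this diff are presumably abstract bases tagged with a device and kernel type. A sketch under that assumption; everything beyond the names appearing in the diff is hypothetical:

    from abc import ABC, abstractmethod
    from typing import Any


    class MetaKernel(ABC):
        device: str  # matched against get_current_accelerator().type, e.g. "npu"
        kernel_type: str  # hypothetical tag, e.g. "rmsnorm" or "swiglu"

        @classmethod
        @abstractmethod
        def apply(cls, model: Any, **kwargs) -> Any:
            """Patch the matching modules of `model` and return it."""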
@@ -206,7 +206,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
     discovered_kernels: list[type[MetaKernel]] = []
 
     # Detect current device type
-    accelerator = get_available_accelerator()
+    accelerator = get_current_accelerator()
     try:
         device_type = DeviceType(accelerator.type)
     except ValueError:
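discover_kernels maps the runtime accelerator onto the project's DeviceType enum and handles unknown types through the except branch. A minimal sketch with a stand-in enum (the real one lives in .constants, and the hunk cuts off before the fallback):

    from enum import Enum

    import torch


    class DeviceType(str, Enum):  # stand-in; the real enum lives in .constants
        CPU = "cpu"
        CUDA = "cuda"
        NPU = "npu"


    accelerator = torch.accelerator.current_accelerator() or torch.device("cpu")
    try:
        device_type = DeviceType(accelerator.type)
    except ValueError:
        device_type = DeviceType.CPU  # assumed fallback; not shown in the hunk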
@@ -238,11 +238,11 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
         model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
         model = apply_kernel(model, NpuRMSNormKernel)
     """
-    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
+    if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
         return kernel.apply(model, **kwargs)
 
     raise ValueError(
-        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
+        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
     )
 
 
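apply_kernel dispatches only when the kernel subclasses MetaKernel and its declared device matches the runtime accelerator. A self-contained sketch of that gate, with simplified stand-ins for MetaKernel and get_current_accelerator:

    import torch


    def get_current_accelerator() -> torch.device:
        # Assumed to mirror the deleted get_available_accelerator below (torch>=2.7.0).
        accelerator = torch.accelerator.current_accelerator()
        return accelerator if accelerator is not None else torch.device("cpu")


    class MetaKernel:  # simplified stand-in for the registry's base class
        device = "cpu"

        @classmethod
        def apply(cls, model, **kwargs):
            return model


    def apply_kernel(model, kernel, /, **kwargs):
        # The gate from the hunk above: right base class and matching device type.
        if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
            return kernel.apply(model, **kwargs)
        raise ValueError(f"{kernel} does not match device {get_current_accelerator().type}.")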
@@ -14,8 +14,8 @@
 import re
 import types
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaRMSNormKernel
 
@@ -16,8 +16,8 @@ import sys
 
 import torch
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaRoPEKernel
 
@@ -11,37 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import lru_cache
-
-import torch
-
-
-def get_available_accelerator():
-    """Get available accelerator in current environment.
-
-    Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
-    """
-    accelerator = torch.accelerator.current_accelerator()
-    if accelerator is None:
-        return torch.device("cpu")
-    return accelerator
-
-
-@lru_cache
-def is_torch_npu_available():
-    return get_available_accelerator().type == "npu"
-
-
-@lru_cache
-def is_torch_cuda_available():
-    return get_available_accelerator().type == "cuda"
-
-
-@lru_cache
-def is_torch_xpu_available():
-    return get_available_accelerator().type == "xpu"
-
-
-@lru_cache
-def is_torch_mps_available():
-    return get_available_accelerator().type == "mps"
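The deleted module above was the old home of these helpers; their replacements in accelerator.helper are not shown in this diff. A sketch of what get_current_accelerator presumably looks like, modeled on the removed code, with a hasattr guard added as an assumption for torch<2.7.0:

    from functools import lru_cache

    import torch


    def get_current_accelerator() -> torch.device:
        # torch.accelerator landed in torch 2.7.0; fall back to CPU elsewhere.
        if not hasattr(torch, "accelerator"):
            return torch.device("cpu")
        accelerator = torch.accelerator.current_accelerator()
        return accelerator if accelerator is not None else torch.device("cpu")


    @lru_cache
    def is_torch_npu_available() -> bool:
        return get_current_accelerator().type == "npu"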