mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-04 09:13:10 +00:00
[v1] add init plugin (#9716)
This commit is contained in:
@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
|
||||
### model
|
||||
model: "llamafactory/tiny-random-qwen2.5"
|
||||
trust_remote_code: true
|
||||
use_fast_processor: true
|
||||
model_class: "llm"
|
||||
kernel_config:
|
||||
name: "auto"
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""LLaMA-Factory test configuration.
|
||||
"""LlamaFactory test configuration.
|
||||
|
||||
Contains shared fixtures, pytest configuration, and custom markers.
|
||||
"""
|
||||
@@ -22,6 +22,7 @@ import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from pytest import Config, FixtureRequest, Item, MonkeyPatch
|
||||
|
||||
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
|
||||
@@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]):
|
||||
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
||||
"""Modify test collection based on markers and environment."""
|
||||
# Handle version compatibility (from HEAD)
|
||||
if not is_transformers_version_greater_than("4.57.0"):
|
||||
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
||||
for item in items:
|
||||
if "tests_v1" in str(item.fspath):
|
||||
item.add_marker(skip_bc)
|
||||
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
||||
for item in items:
|
||||
if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
|
||||
item.add_marker(skip_bc)
|
||||
|
||||
_handle_slow_tests(items)
|
||||
_handle_runs_on(items)
|
||||
_handle_device_visibility(items)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_distributed_state():
|
||||
"""Cleanup distributed state after each test."""
|
||||
yield
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
|
||||
"""Set environment variables for distributed tests if specific devices are requested."""
|
||||
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
||||
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
||||
else:
|
||||
monkeypatch.setenv(env_key, "0")
|
||||
|
||||
if CURRENT_DEVICE == "cuda":
|
||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||
elif CURRENT_DEVICE == "npu":
|
||||
|
||||
56
tests_v1/plugins/model_plugins/test_init_plugin.py
Normal file
56
tests_v1/plugins/model_plugins/test_init_plugin.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||
from llamafactory.v1.config.arg_parser import get_args
|
||||
from llamafactory.v1.core.model_loader import ModelLoader
|
||||
|
||||
|
||||
def test_init_on_meta():
|
||||
_, model_args, *_ = get_args(
|
||||
dict(
|
||||
model="llamafactory/tiny-random-qwen2.5",
|
||||
init_config={"name": "init_on_meta"},
|
||||
)
|
||||
)
|
||||
model_loader = ModelLoader(model_args=model_args)
|
||||
assert model_loader.model.device.type == "meta"
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cuda", "npu"])
|
||||
def test_init_on_rank0():
|
||||
_, model_args, *_ = get_args(
|
||||
dict(
|
||||
model="llamafactory/tiny-random-qwen2.5",
|
||||
init_config={"name": "init_on_rank0"},
|
||||
)
|
||||
)
|
||||
model_loader = ModelLoader(model_args=model_args)
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
assert model_loader.model.device.type == "cpu"
|
||||
else:
|
||||
assert model_loader.model.device.type == "meta"
|
||||
|
||||
|
||||
def test_init_on_default():
|
||||
_, model_args, *_ = get_args(
|
||||
dict(
|
||||
model="llamafactory/tiny-random-qwen2.5",
|
||||
init_config={"name": "init_on_default"},
|
||||
)
|
||||
)
|
||||
model_loader = ModelLoader(model_args=model_args)
|
||||
assert model_loader.model.device.type == DistributedInterface().current_accelerator.type
|
||||
Reference in New Issue
Block a user