[v1] add init plugin (#9716)

Yaowei Zheng
2026-01-04 20:51:46 +08:00
committed by GitHub
parent 81b8a50aa5
commit f60a6e3d01
14 changed files with 307 additions and 74 deletions

View File

@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
use_fast_processor: true
model_class: "llm"
kernel_config:
  name: "auto"

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA-Factory test configuration.
"""LlamaFactory test configuration.
Contains shared fixtures, pytest configuration, and custom markers.
"""
@@ -22,6 +22,7 @@ import sys
import pytest
import torch
import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]):
def pytest_collection_modifyitems(config: Config, items: list[Item]):
    """Modify test collection based on markers and environment."""
    # Handle version compatibility (from HEAD)
    if not is_transformers_version_greater_than("4.57.0"):
        skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
        for item in items:
            if "tests_v1" in str(item.fspath):
                item.add_marker(skip_bc)
    skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
    for item in items:
        if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
            item.add_marker(skip_bc)

    _handle_slow_tests(items)
    _handle_runs_on(items)
    _handle_device_visibility(items)


@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
    """Cleanup distributed state after each test."""
    yield
    if dist.is_initialized():
        dist.destroy_process_group()


@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
    """Set environment variables for distributed tests if specific devices are requested."""
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
        monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
    else:
        monkeypatch.setenv(env_key, "0")

    if CURRENT_DEVICE == "cuda":
        monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
    elif CURRENT_DEVICE == "npu":

View File

@@ -0,0 +1,56 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_loader import ModelLoader


def test_init_on_meta():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_meta"},
        )
    )
    model_loader = ModelLoader(model_args=model_args)
    assert model_loader.model.device.type == "meta"


@pytest.mark.runs_on(["cuda", "npu"])
def test_init_on_rank0():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_rank0"},
        )
    )
    model_loader = ModelLoader(model_args=model_args)
    if DistributedInterface().get_rank() == 0:
        assert model_loader.model.device.type == "cpu"
    else:
        assert model_loader.model.device.type == "meta"


def test_init_on_default():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_default"},
        )
    )
    model_loader = ModelLoader(model_args=model_args)
    assert model_loader.model.device.type == DistributedInterface().current_accelerator.type
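
These tests expect init_config={"name": "init_on_meta"} to leave model parameters on the meta device. A minimal sketch of that idea in plain PyTorch/transformers terms, assuming construction under a torch.device("meta") context; this is not the plugin code added by the commit, and the helper name is hypothetical:

# Hypothetical sketch of an "init on meta" strategy: build the model inside a
# torch.device("meta") context so parameters are allocated without real memory.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

def init_on_meta_sketch(model_name: str):
    config = AutoConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)
    return model

model = init_on_meta_sketch("llamafactory/tiny-random-qwen2.5")
assert model.device.type == "meta"  # mirrors the check in test_init_on_meta

An init_on_rank0-style variant would presumably do the real CPU initialization only on rank 0 and keep the other ranks on meta, which is what the assertions in test_init_on_rank0 check.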