# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path

import pytest


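# Note: `runs_on` is a project-defined pytest marker (presumably registered in
# the suite's pytest configuration) that limits this test to CUDA or NPU hosts.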
@pytest.mark.xfail(reason="CI machines may OOM when heavily loaded.")
@pytest.mark.runs_on(["cuda", "npu"])
def test_fsdp2_sft_trainer(tmp_path: Path):
    """Test FSDP2 SFT trainer by simulating `llamafactory-cli sft config.yaml` behavior."""
    config_yaml = """\
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null

init_config:
  name: init_on_meta

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: {output_dir}
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 1

### sample
sample_backend: hf
max_new_tokens: 128
"""
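    # `{output_dir}` in the YAML above is a str.format placeholder; it is filled
    # in when the config is written out below. Everything else is literal YAML.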
    # Create output directory
    output_dir = tmp_path / "outputs"
    output_dir.mkdir(parents=True, exist_ok=True)
    config_file = tmp_path / "config.yaml"
    config_file.write_text(config_yaml.format(output_dir=str(output_dir)))

    # Set up environment variables
    env = os.environ.copy()
    env["USE_V1"] = "1"  # Use v1 launcher
    env["FORCE_TORCHRUN"] = "1"  # Force distributed training via torchrun
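    # FORCE_TORCHRUN makes the CLI relaunch itself under torchrun, which sets up
    # the distributed process group that FSDP2 training needs.
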
    # Run the CLI command via subprocess
    # This simulates: llamafactory-cli sft config.yaml
    result = subprocess.run(
        [sys.executable, "-m", "llamafactory.cli", "sft", str(config_file)],
        env=env,
        capture_output=True,
        cwd=str(Path(__file__).parent.parent.parent),  # LLaMA-Factory root
    )
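    # With capture_output=True and no text=True, stdout/stderr come back as bytes.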
    # Decode output with error handling (progress bars may contain non-UTF-8 bytes)
    stderr = result.stderr.decode("utf-8", errors="replace")

    # Check the result
    assert result.returncode == 0, f"Training failed with return code {result.returncode}\nSTDERR: {stderr}"

    # Verify output files exist (optional - adjust based on what run_sft produces)
    # assert (output_dir / "some_expected_file").exists()
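
# To run only this test (on a CUDA or NPU machine; elsewhere the `runs_on`
# marker should deselect it):
#   pytest -k test_fsdp2_sft_trainer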