[v1] add v1 launcher (#9236)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Yaowei Zheng
Date: 2025-10-07 22:34:48 +08:00 (committed by GitHub)
parent 95b7188090
commit 10146029ba
26 changed files with 661 additions and 452 deletions

View File: src/llamafactory/cli.py

@@ -12,145 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-import subprocess
 import sys
-from copy import deepcopy
-from functools import partial
-
-
-USAGE = (
-    "-" * 70
-    + "\n"
-    + "| Usage: |\n"
-    + "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
-    + "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
-    + "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
-    + "| llamafactory-cli train -h: train models |\n"
-    + "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
-    + "| llamafactory-cli webui: launch LlamaBoard |\n"
-    + "| llamafactory-cli env: show environment info |\n"
-    + "| llamafactory-cli version: show version info |\n"
-    + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
-    + "-" * 70
-)
 
 
 def main():
-    from .extras import logging
-    from .extras.env import VERSION, print_env
-    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
-
-    logger = logging.get_logger(__name__)
-
-    WELCOME = (
-        "-" * 58
-        + "\n"
-        + f"| Welcome to LLaMA Factory, version {VERSION}"
-        + " " * (21 - len(VERSION))
-        + "|\n|"
-        + " " * 56
-        + "|\n"
-        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
-        + "-" * 58
-    )
-
-    COMMAND_MAP = {
-        "api": launcher.run_api,
-        "chat": launcher.run_chat,
-        "env": print_env,
-        "eval": launcher.run_eval,
-        "export": launcher.export_model,
-        "train": launcher.run_exp,
-        "webchat": launcher.run_web_demo,
-        "webui": launcher.run_web_ui,
-        "version": partial(print, WELCOME),
-        "help": partial(print, USAGE),
-    }
-
-    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
-    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
-        # launch distributed training
-        nnodes = os.getenv("NNODES", "1")
-        node_rank = os.getenv("NODE_RANK", "0")
-        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
-        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
-        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
-        if int(nnodes) > 1:
-            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
-
-        # elastic launch support
-        max_restarts = os.getenv("MAX_RESTARTS", "0")
-        rdzv_id = os.getenv("RDZV_ID")
-        min_nnodes = os.getenv("MIN_NNODES")
-        max_nnodes = os.getenv("MAX_NNODES")
-
-        env = deepcopy(os.environ)
-        if is_env_enabled("OPTIM_TORCH", "1"):
-            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
-            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-        if rdzv_id is not None:
-            # launch elastic job with fault tolerant support when possible
-            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
-            rdzv_nnodes = nnodes
-            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
-            if min_nnodes is not None and max_nnodes is not None:
-                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
-
-            process = subprocess.run(
-                (
-                    "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
-                    "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
-                    "--max-restarts {max_restarts} {file_name} {args}"
-                )
-                .format(
-                    rdzv_nnodes=rdzv_nnodes,
-                    nproc_per_node=nproc_per_node,
-                    rdzv_id=rdzv_id,
-                    master_addr=master_addr,
-                    master_port=master_port,
-                    max_restarts=max_restarts,
-                    file_name=launcher.__file__,
-                    args=" ".join(sys.argv[1:]),
-                )
-                .split(),
-                env=env,
-                check=True,
-            )
-        else:
-            # NOTE: DO NOT USE shell=True to avoid security risk
-            process = subprocess.run(
-                (
-                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
-                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-                )
-                .format(
-                    nnodes=nnodes,
-                    node_rank=node_rank,
-                    nproc_per_node=nproc_per_node,
-                    master_addr=master_addr,
-                    master_port=master_port,
-                    file_name=launcher.__file__,
-                    args=" ".join(sys.argv[1:]),
-                )
-                .split(),
-                env=env,
-                check=True,
-            )
-
-        sys.exit(process.returncode)
-    elif command in COMMAND_MAP:
-        COMMAND_MAP[command]()
-    else:
-        print(f"Unknown command: {command}.\n{USAGE}")
+    from .extras.misc import is_env_enabled
+
+    if is_env_enabled("USE_V1"):
+        from .v1 import launcher
+    else:
+        from . import launcher
+
+    launcher.launch()
 
 
 if __name__ == "__main__":
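The `USE_V1` gate relies on `is_env_enabled` from extras/misc.py. A minimal sketch of that helper, assuming it treats "true"/"y"/"1" (case-insensitive) as enabled:

import os


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # assumed semantics; the real helper lives in src/llamafactory/extras/misc.py
    return os.getenv(env_var, default).lower() in ("true", "y", "1")


# e.g. `USE_V1=1 llamafactory-cli sft config.yaml` routes into the v1 launcher,
# while plain `llamafactory-cli train config.yaml` keeps the legacy path.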

View File: src/llamafactory/extras/env.py

@@ -16,6 +16,9 @@
 # limitations under the License.
 
+from collections import OrderedDict
+
+
 VERSION = "0.9.4.dev0"
@@ -28,20 +31,20 @@ def print_env() -> None:
     import peft
     import torch
     import transformers
-    import trl
     from transformers.utils import is_torch_cuda_available, is_torch_npu_available
 
-    info = {
-        "`llamafactory` version": VERSION,
-        "Platform": platform.platform(),
-        "Python version": platform.python_version(),
-        "PyTorch version": torch.__version__,
-        "Transformers version": transformers.__version__,
-        "Datasets version": datasets.__version__,
-        "Accelerate version": accelerate.__version__,
-        "PEFT version": peft.__version__,
-        "TRL version": trl.__version__,
-    }
+    info = OrderedDict(
+        {
+            "`llamafactory` version": VERSION,
+            "Platform": platform.platform(),
+            "Python version": platform.python_version(),
+            "PyTorch version": torch.__version__,
+            "Transformers version": transformers.__version__,
+            "Datasets version": datasets.__version__,
+            "Accelerate version": accelerate.__version__,
+            "PEFT version": peft.__version__,
+        }
+    )
 
     if is_torch_cuda_available():
         info["PyTorch version"] += " (GPU)"
@@ -54,6 +57,13 @@ def print_env() -> None:
         info["NPU type"] = torch.npu.get_device_name()
         info["CANN version"] = torch.version.cann
 
+    try:
+        import trl  # type: ignore
+
+        info["TRL version"] = trl.__version__
+    except Exception:
+        pass
+
     try:
         import deepspeed  # type: ignore
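Wrapping the trl import in try/except turns it into an optional dependency of the report. A hypothetical helper generalizing the pattern used for trl and deepspeed above (not part of this commit):

def report_optional(info: dict, name: str, module_name: str) -> None:
    try:
        module = __import__(module_name)
        info[f"{name} version"] = module.__version__
    except Exception:
        pass  # optional dependency missing or broken: omit it from the report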

View File: src/llamafactory/launcher.py

@@ -12,46 +12,169 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-def run_api():
-    from llamafactory.api.app import run_api as _run_api
-
-    _run_api()
-
-
-def run_chat():
-    from llamafactory.chat.chat_model import run_chat as _run_chat
-
-    return _run_chat()
-
-
-def run_eval():
-    raise NotImplementedError("Evaluation will be deprecated in the future.")
-
-
-def export_model():
-    from llamafactory.train.tuner import export_model as _export_model
-
-    return _export_model()
-
-
-def run_exp():
-    from llamafactory.train.tuner import run_exp as _run_exp
-
-    return _run_exp()  # use absolute import
-
-
-def run_web_demo():
-    from llamafactory.webui.interface import run_web_demo as _run_web_demo
-
-    return _run_web_demo()
-
-
-def run_web_ui():
-    from llamafactory.webui.interface import run_web_ui as _run_web_ui
-
-    return _run_web_ui()
+import os
+import subprocess
+import sys
+from copy import deepcopy
+
+
+USAGE = (
+    "-" * 70
+    + "\n"
+    + "| Usage: |\n"
+    + "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
+    + "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
+    + "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+    + "| llamafactory-cli train -h: train models |\n"
+    + "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
+    + "| llamafactory-cli webui: launch LlamaBoard |\n"
+    + "| llamafactory-cli env: show environment info |\n"
+    + "| llamafactory-cli version: show version info |\n"
+    + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
+    + "-" * 70
+)
+
+
+def launch():
+    from .extras import logging
+    from .extras.env import VERSION, print_env
+    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
+
+    logger = logging.get_logger(__name__)
+
+    WELCOME = (
+        "-" * 58
+        + "\n"
+        + f"| Welcome to LLaMA Factory, version {VERSION}"
+        + " " * (21 - len(VERSION))
+        + "|\n|"
+        + " " * 56
+        + "|\n"
+        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+        + "-" * 58
+    )
+
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
+    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+        # launch distributed training
+        nnodes = os.getenv("NNODES", "1")
+        node_rank = os.getenv("NODE_RANK", "0")
+        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
+        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
+        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
+        if int(nnodes) > 1:
+            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
+
+        # elastic launch support
+        max_restarts = os.getenv("MAX_RESTARTS", "0")
+        rdzv_id = os.getenv("RDZV_ID")
+        min_nnodes = os.getenv("MIN_NNODES")
+        max_nnodes = os.getenv("MAX_NNODES")
+
+        env = deepcopy(os.environ)
+        if is_env_enabled("OPTIM_TORCH", "1"):
+            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
+            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        if rdzv_id is not None:
+            # launch elastic job with fault tolerant support when possible
+            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
+            rdzv_nnodes = nnodes
+            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
+            if min_nnodes is not None and max_nnodes is not None:
+                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
+
+            process = subprocess.run(
+                (
+                    "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
+                    "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
+                    "--max-restarts {max_restarts} {file_name} {args}"
+                )
+                .format(
+                    rdzv_nnodes=rdzv_nnodes,
+                    nproc_per_node=nproc_per_node,
+                    rdzv_id=rdzv_id,
+                    master_addr=master_addr,
+                    master_port=master_port,
+                    max_restarts=max_restarts,
+                    file_name=__file__,
+                    args=" ".join(sys.argv[1:]),
+                )
+                .split(),
+                env=env,
+                check=True,
+            )
+        else:
+            # NOTE: DO NOT USE shell=True to avoid security risk
+            process = subprocess.run(
+                (
+                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
+                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
+                )
+                .format(
+                    nnodes=nnodes,
+                    node_rank=node_rank,
+                    nproc_per_node=nproc_per_node,
+                    master_addr=master_addr,
+                    master_port=master_port,
+                    file_name=__file__,
+                    args=" ".join(sys.argv[1:]),
+                )
+                .split(),
+                env=env,
+                check=True,
+            )
+
+        sys.exit(process.returncode)
+    elif command == "api":
+        from .api.app import run_api
+
+        run_api()
+    elif command == "chat":
+        from .chat.chat_model import run_chat
+
+        run_chat()
+    elif command == "eval":
+        raise NotImplementedError("Evaluation will be deprecated in the future.")
+    elif command == "export":
+        from .train.tuner import export_model
+
+        export_model()
+    elif command == "train":
+        from .train.tuner import run_exp
+
+        run_exp()
+    elif command == "webchat":
+        from .webui.interface import run_web_demo
+
+        run_web_demo()
+    elif command == "webui":
+        from .webui.interface import run_web_ui
+
+        run_web_ui()
+    elif command == "env":
+        print_env()
+    elif command == "version":
+        print(WELCOME)
+    elif command == "help":
+        print(USAGE)
+    else:
+        print(f"Unknown command: {command}.\n{USAGE}")
 
 
 if __name__ == "__main__":
+    from llamafactory.train.tuner import run_exp  # use absolute import
+
     run_exp()
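For reference, a sketch of the torchrun command launch() assembles on a single node with two GPUs and default settings (the port and config path are illustrative; MASTER_PORT normally comes from find_available_port):

cmd = (
    "torchrun --nnodes 1 --node_rank 0 --nproc_per_node 2 "
    "--master_addr 127.0.0.1 --master_port 29500 "
    "/path/to/llamafactory/launcher.py examples/train_lora/llama3_lora_sft.yaml"
).split()
# subprocess.run(cmd, env=env, check=True) re-enters this file under torchrun,
# and the __main__ block above then calls run_exp() on every rank.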

View File: src/llamafactory/v1/config/data_args.py

@@ -0,0 +1,33 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataArguments:
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the dataset."},
    )
    dataset_dir: str = field(
        default="data",
        metadata={"help": "Path to the folder containing the datasets."},
    )
    cutoff_len: int = field(
        default=2048,
        metadata={"help": "Cutoff length for the dataset."},
    )

View File: src/llamafactory/v1/config/model_args.py

@@ -0,0 +1,27 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field


@dataclass
class ModelArguments:
    model: str = field(
        metadata={"help": "Path to the model or model identifier from Hugging Face."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Trust remote code from Hugging Face."},
    )

View File: src/llamafactory/v1/config/parser.py

@@ -0,0 +1,63 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from pathlib import Path
from typing import Any, Optional, Union

from omegaconf import OmegaConf
from transformers import HfArgumentParser

from ...extras.misc import is_env_enabled
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments


def get_args(
    args: Optional[Union[dict[str, Any], list[str]]] = None,
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
    """Parse arguments from command line or config file."""
    parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
    if args is None:
        if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
            override_config = OmegaConf.from_cli(sys.argv[2:])
            dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
        elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
            override_config = OmegaConf.from_cli(sys.argv[2:])
            # json.load needs a file object, so read the file's text first
            dict_config = OmegaConf.create(json.loads(Path(sys.argv[1]).absolute().read_text()))
            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
        else:  # list of strings
            args = sys.argv[1:]

    if isinstance(args, dict):
        (*parsed_args,) = parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
    else:
        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args, return_remaining_strings=True)
        if unknown_args and not allow_extra_keys:
            print(parser.format_help())
            print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

    return tuple(parsed_args)


if __name__ == "__main__":
    print(get_args())
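A hypothetical programmatic call (module path inferred from the relative imports above; model and dataset names are placeholders), showing the tuple order fixed by the parser list, namely Data, Model, Training, Sample:

from llamafactory.v1.config.parser import get_args

data_args, model_args, training_args, sample_args = get_args(
    {"model": "org/model-name", "dataset": "my_dataset", "bf16": True}
)
assert training_args.learning_rate == 1e-4  # dataclass default
# From a shell the same parser handles `llamafactory-cli sft config.yaml key=value`,
# where trailing dotlist arguments override keys loaded from the config file.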

View File: src/llamafactory/v1/config/sample_args.py

@@ -0,0 +1,24 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field


@dataclass
class SampleArguments:
    max_new_tokens: int = field(
        default=128,
        metadata={"help": "Maximum number of new tokens to generate."},
    )

View File: src/llamafactory/v1/config/training_args.py

@@ -0,0 +1,40 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field


@dataclass
class TrainingArguments:
    output_dir: str = field(
        default="",
        metadata={"help": "Path to the output directory."},
    )
    micro_batch_size: int = field(
        default=1,
        metadata={"help": "Micro batch size for training."},
    )
    global_batch_size: int = field(
        default=1,
        metadata={"help": "Global batch size for training."},
    )
    learning_rate: float = field(
        default=1e-4,
        metadata={"help": "Learning rate for training."},
    )
    bf16: bool = field(
        default=False,
        metadata={"help": "Use bf16 for training."},
    )
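No trainer consumes these fields yet; the usual relation between the two batch sizes, which a later trainer would presumably enforce, is illustrated below (the accumulation arithmetic is an assumption, not code from this commit):

micro_batch_size, global_batch_size, world_size = 1, 8, 2
grad_accum_steps = global_batch_size // (micro_batch_size * world_size)
assert grad_accum_steps == 4  # each rank accumulates 4 micro-batches per optimizer step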

View File: src/llamafactory/v1/core/base_trainer.py

@@ -0,0 +1,35 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..config.training_args import TrainingArguments
from ..extras.types import DataLoader, Model, Processor


class BaseTrainer:
    def __init__(
        self,
        args: TrainingArguments,
        model: Model,
        processor: Processor,
        data_loader: DataLoader,
    ) -> None:
        self.args = args
        self.model = model
        self.processor = processor
        self.data_loader = data_loader
        self.optimizer = None
        self.lr_scheduler = None

    def fit(self) -> None:
        pass

View File: src/llamafactory/v1/core/chat_sampler.py

@@ -0,0 +1,20 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..config.sample_args import SampleArguments


class ChatSampler:
    def __init__(self, sample_args: SampleArguments) -> None:
        self.args = sample_args

View File: src/llamafactory/v1/core/data_engine.py

@@ -0,0 +1,75 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from ..config.data_args import DataArguments
from ..extras.types import DataLoader, Dataset, Processor


class DataCollator:
    def __init__(self, processor: Processor) -> None:
        self.processor = processor


class DatasetPathMixin:
    args: DataArguments

    def _abspath(self, path: str) -> str:
        return os.path.abspath(os.path.expanduser(os.path.join(self.args.dataset_dir, path)))

    def _exists(self, path: str) -> bool:
        return os.path.exists(self._abspath(path))

    def _isfile(self, path: str) -> bool:
        return os.path.isfile(self._abspath(path))


class DataEngine(DatasetPathMixin):
    def __init__(self, data_args: DataArguments) -> None:
        self.args = data_args
        self.datasets: dict[str, Dataset] = {}
        dataset_info = self.get_dataset_info()
        self.load_dataset(dataset_info)

    def get_dataset_info(self) -> dict:
        """Get dataset info from dataset path.

        Returns:
            dict: Dataset info.
        """
        if self.args.dataset.endswith(".yaml") and self._isfile(self.args.dataset):  # local file
            return OmegaConf.load(self._abspath(self.args.dataset))
        elif self.args.dataset.endswith(".yaml"):  # hf hub uri
            repo_id, filename = os.path.split(self.args.dataset)
            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
            return OmegaConf.load(filepath)
        elif self._exists(self.args.dataset):  # local file(s)
            return {"default": {"file_name": self.args.dataset}}
        else:  # hf hub dataset
            return {"default": {"hf_hub_url": self.args.dataset}}

    def load_dataset(self, dataset_info: dict) -> None:
        # store the loaded datasets on the engine rather than mutating the input dict
        for key, value in dataset_info.items():
            if "hf_hub_url" in value:
                self.datasets[key] = load_dataset(value["hf_hub_url"])
            elif "file_name" in value:
                self.datasets[key] = load_dataset(value["file_name"])

    def get_data_loader(self, processor: Processor) -> DataLoader:
        pass
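A side-effect-free sketch of the four branches in get_dataset_info (the function name and sample input are illustrative, not part of the commit):

import os


def classify_dataset_arg(dataset: str, dataset_dir: str = "data") -> str:
    path = os.path.abspath(os.path.expanduser(os.path.join(dataset_dir, dataset)))
    if dataset.endswith(".yaml") and os.path.isfile(path):
        return "local dataset_info yaml"
    elif dataset.endswith(".yaml"):
        return "dataset_info yaml fetched from the HF Hub"
    elif os.path.exists(path):
        return "local data file(s)"
    else:
        return "HF Hub dataset id"


print(classify_dataset_arg("tatsu-lab/alpaca"))  # -> "HF Hub dataset id"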

View File: src/llamafactory/v1/core/model_engine.py

@@ -0,0 +1,27 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..config.model_args import ModelArguments
from ..extras.types import Model, Processor


class ModelEngine:
    def __init__(self, model_args: ModelArguments) -> None:
        self.args = model_args

    def get_model(self) -> Model:
        pass

    def get_processor(self) -> Processor:
        pass

View File: src/llamafactory/v1/extras/types.py

@@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Union


if TYPE_CHECKING:
    from datasets import Dataset as HFDataset
    from datasets import IterableDataset
    from torch.utils.data import DataLoader as TorchDataLoader
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

    Dataset = Union[HFDataset, IterableDataset]
    DataLoader = TorchDataLoader
    Model = PreTrainedModel
    Processor = Union[PreTrainedTokenizer, ProcessorMixin]
else:
    Dataset = None
    DataLoader = None
    Model = None
    Processor = None
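The TYPE_CHECKING guard keeps torch, transformers, and datasets out of import time; the aliases exist only for annotations. A minimal illustration of the same pattern (names hypothetical):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers import PreTrainedModel as Model
else:
    Model = None  # never dereferenced at runtime


def load_model(path: str) -> "Model":  # quoted annotation, so no runtime lookup
    ...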

View File: src/llamafactory/v1/launcher.py

@@ -12,22 +12,55 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-def run_train():
-    raise NotImplementedError("Please use `llamafactory-cli sft` or `llamafactory-cli rm`.")
-
-
-def run_chat():
-    from llamafactory.v1.core.chat_sampler import Sampler
-
-    Sampler().cli()
-
-
-def run_sft():
-    from llamafactory.v1.train.sft import SFTTrainer
-
-    SFTTrainer().run()
+import sys
+
+from ..extras.env import VERSION, print_env
+
+
+USAGE = (
+    "-" * 70
+    + "\n"
+    + "| Usage: |\n"
+    + "| llamafactory-cli sft -h: train models |\n"
+    + "| llamafactory-cli version: show version info |\n"
+    + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
+    + "-" * 70
+)
+
+WELCOME = (
+    "-" * 58
+    + "\n"
+    + f"| Welcome to LLaMA Factory, version {VERSION}"
+    + " " * (21 - len(VERSION))
+    + "|\n|"
+    + " " * 56
+    + "|\n"
+    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+    + "-" * 58
+)
+
+
+def launch():
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
+    if command == "sft":
+        from .trainers.sft_trainer import run_sft
+
+        run_sft()
+    elif command == "env":
+        print_env()
+    elif command == "version":
+        print(WELCOME)
+    elif command == "help":
+        print(USAGE)
+    else:
+        print(f"Unknown command: {command}.\n{USAGE}")
 
 
 if __name__ == "__main__":
-    run_train()
+    pass
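The pop-then-dispatch idiom matters because the downstream parser reads sys.argv directly; popping the subcommand leaves only the config path and overrides. A minimal reproduction (argv values illustrative):

import sys

sys.argv = ["llamafactory-cli", "sft", "examples/sft.yaml", "bf16=true"]
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
assert command == "sft"
assert sys.argv == ["llamafactory-cli", "examples/sft.yaml", "bf16=true"]
# get_args() can now treat sys.argv[1] as the config file.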

View File: src/llamafactory/v1/plugins/data_plugins/template.py

@@ -0,0 +1,26 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass


@dataclass
class Template:
    user_template: str
    assistant_template: str
    system_template: str

    def render_message(self, message: "dict[str, str]") -> str:
        return self.user_template.format(**message)
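A usage sketch for the Template dataclass above (the {content} placeholder and chat markers are assumptions; note that render_message currently formats only the user template):

template = Template(
    user_template="<|user|>\n{content}\n",
    assistant_template="<|assistant|>\n{content}\n",
    system_template="<|system|>\n{content}\n",
)
print(template.render_message({"content": "Hello"}))  # -> "<|user|>\nHello\n"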

View File: src/llamafactory/v1/trainers/sft_trainer.py

@@ -0,0 +1,34 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..config.parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine


class SFTTrainer(BaseTrainer):
    pass


def run_sft():
    # get_args returns (data_args, model_args, training_args, sample_args) in that order
    data_args, model_args, training_args, _ = get_args()
    model_engine = ModelEngine(model_args)
    data_engine = DataEngine(data_args)
    model = model_engine.get_model()
    processor = model_engine.get_processor()
    data_loader = data_engine.get_data_loader(processor)
    trainer = SFTTrainer(training_args, model, processor, data_loader)
    trainer.fit()

View File: src/llamafactory/webui/common.py

@@ -36,8 +36,8 @@ from ..extras.misc import use_modelscope, use_openmind
 logger = logging.get_logger(__name__)
 
-DEFAULT_CACHE_DIR = "cache"
-DEFAULT_CONFIG_DIR = "config"
+DEFAULT_CACHE_DIR = "llamaboard_cache"
+DEFAULT_CONFIG_DIR = "llamaboard_config"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user_config.yaml"