Mirror of https://github.com/hiyouga/LlamaFactory.git, synced 2026-02-02 20:43:38 +00:00.
[train] KTransformers SFT as backend engine for LLaMA-Factory (#9400)
Co-authored-by: jimmy128 <jimmy128@noreply.gitcode.com> Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -439,7 +439,6 @@ class SwanLabArguments:
|
||||
metadata={"help": "The Lark(飞书) secret for SwanLab."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(
|
||||
SwanLabArguments,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
||||
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||
@@ -475,9 +475,51 @@ class SGLangArguments:
|
||||
self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
|
||||
|
||||
|
||||
@dataclass
class KTransformersArguments:
    r"""Arguments pertaining to the KTransformers (KT) training backend."""

    use_kt: bool = field(
        default=False,
        metadata={"help": "Whether or not to use KTransformers optimizations for LoRA training."},
    )
    kt_optimize_rule: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the KTransformers optimize rule; see https://github.com/kvcache-ai/ktransformers/."},
    )
    # Number of CPU cores the KT CPU backend may use for expert computation.
    cpu_infer: Optional[int] = field(
        default=32,
        metadata={"help": "Number of CPU cores used for computation."},
    )
    # Work-unit size for the CPU compute path.
    chunk_size: Optional[int] = field(
        default=8192,
        metadata={"help": "Chunk size used for CPU compute in KTransformers."},
    )
    # NOTE(review): `mode` and `kt_mode` carry overlapping normal/long_context
    # semantics — consider consolidating; both kept for backward compatibility.
    mode: Optional[str] = field(
        default="normal",
        metadata={"help": "Use `normal` or `long_context` mode for llama models."},
    )

    kt_maxlen: int = field(
        default=4096,
        metadata={"help": "Maximum sequence (prompt + response) length of the KT engine."},
    )
    kt_use_cuda_graph: bool = field(
        default=True,
        metadata={"help": "Whether or not to use CUDA graphs for the KT engine."},
    )
    kt_mode: str = field(
        default="normal",
        metadata={"help": "Use `normal` or `long_context` mode for the KT engine."},
    )
    kt_force_think: bool = field(
        default=False,
        metadata={"help": "Force-think toggle for the KT engine."},
    )
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(
|
||||
SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
|
||||
SGLangArguments, VllmArguments, KTransformersArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
|
||||
):
|
||||
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
|
||||
|
||||
|
||||
@@ -156,6 +156,9 @@ def _check_extra_dependencies(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
training_args: Optional["TrainingArguments"] = None,
|
||||
) -> None:
|
||||
if model_args.use_kt:
|
||||
check_version("ktransformers", mandatory=True)
|
||||
|
||||
if model_args.use_unsloth:
|
||||
check_version("unsloth", mandatory=True)
|
||||
|
||||
@@ -282,13 +285,16 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
if model_args.shift_attn:
|
||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||
|
||||
if finetuning_args.reward_model_type == "lora" and model_args.use_kt:
|
||||
raise ValueError("KTransformers does not support lora reward model.")
|
||||
|
||||
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support lora reward model.")
|
||||
|
||||
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
|
||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
||||
|
||||
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||
if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
||||
|
||||
if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
|
||||
@@ -350,6 +356,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if model_args.use_kt and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
|
||||
raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
|
||||
|
||||
|
||||
@@ -90,7 +90,6 @@ class RayArguments:
|
||||
elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
|
||||
self.ray_storage_filesystem = fs.GcsFileSystem()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(RayArguments, BaseTrainingArguments):
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
Reference in New Issue
Block a user