[train] KTransformers SFT as backend engine for LLaMA-Factory (#9400)

Co-authored-by: jimmy128 <jimmy128@noreply.gitcode.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Peilin Li
2025-11-04 15:54:12 +08:00
committed by GitHub
parent 3ae15da9c0
commit 934b3084ee
37 changed files with 2006 additions and 16 deletions


@@ -20,6 +20,8 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging
from ..extras.constants import EngineName
from .model_utils.ktransformers import get_kt_peft_model, load_kt_peft_model
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
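
The two ktransformers helpers imported above are new in this commit (defined in the new model_utils/ktransformers.py module, whose body is not shown in these hunks). As a rough sketch inferred only from their call sites below, not from the actual implementation:

# Shapes inferred from the call sites in this diff; illustrative only.
def get_kt_peft_model(model, peft_config):
    """Wrap a KTransformers-loaded model with a PEFT (LoRA) adapter for training."""
    ...

def load_kt_peft_model(model_args, model):
    """Attach a previously trained LoRA adapter when resuming KTransformers training."""
    ...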
@@ -164,6 +166,10 @@ def _setup_lora_tuning(
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_kt:
assert len(model_args.adapter_name_or_path) == 1, "Up to now, KTransformers model only accepts a single adapter, for more features, you can contact with us."
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
@@ -182,6 +188,10 @@ def _setup_lora_tuning(
"token": model_args.hf_hub_token,
}
if model_args.use_kt:
if model_args.infer_backend != EngineName.KT:
raise ValueError("We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers.")
for adapter in adapter_to_merge:
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
@@ -190,7 +200,9 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
if model_args.use_kt:
model = load_kt_peft_model(model_args, model)
elif model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
@@ -203,6 +215,16 @@ def _setup_lora_tuning(
else:
target_modules = finetuning_args.lora_target
if model_args.use_kt:
# Remap MLP projections so LoRA covers both the dense MLP path and the
# shared-experts path, and drop KTransformers-internal wrapper modules
# that must not receive adapters.
new_list = []
for m in target_modules:
if m in ("down_proj", "up_proj", "gate_proj"):
new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
elif m not in ("generate_linear", "orig_module", "prefill_linear"):
new_list.append(m)
target_modules[:] = new_list
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
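
To make the remapping concrete, here is a standalone run of the loop above on an illustrative target list (the module names are examples, not taken from a real model):

# Standalone illustration of the target-module rewrite in the hunk above.
target_modules = ["q_proj", "down_proj", "up_proj", "generate_linear"]
new_list = []
for m in target_modules:
    if m in ("down_proj", "up_proj", "gate_proj"):
        # duplicate each MLP projection for the dense and shared-experts paths
        new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
    elif m not in ("generate_linear", "orig_module", "prefill_linear"):
        new_list.append(m)
print(new_list)
# ['q_proj', 'mlp.down_proj', 'shared_experts.down_proj', 'mlp.up_proj', 'shared_experts.up_proj']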
@@ -245,7 +267,21 @@ def _setup_lora_tuning(
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
if model_args.use_kt:
if finetuning_args.finetuning_type == "oft":
raise ValueError("KTransformers is currently not supported for OFT.")
if finetuning_args.finetuning_type == "lora":
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
else:
raise ValueError("KTransformers is currently only supported for LoRA.")
model = get_kt_peft_model(model, peft_config)
print(f"KT_model:{model}")
elif model_args.use_unsloth:
if finetuning_args.finetuning_type == "oft":
raise ValueError("Unsloth is currently not supported for OFT.")