Mirror of https://github.com/hiyouga/LlamaFactory.git, synced 2026-02-02 08:33:38 +00:00
[train] KTransformers SFT as backend engine for LLaMA-Factory (#9400)
Co-authored-by: jimmy128 <jimmy128@noreply.gitcode.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
@@ -20,6 +20,8 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft
 from transformers.integrations import is_deepspeed_zero3_enabled
 
 from ..extras import logging
+from ..extras.constants import EngineName
+from .model_utils.ktransformers import get_kt_peft_model, load_kt_peft_model
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -164,6 +166,10 @@ def _setup_lora_tuning(
             assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
             is_mergeable = False
 
+        if model_args.use_kt:
+            assert len(model_args.adapter_name_or_path) == 1, "KTransformers currently only accepts a single adapter; please contact us if you need more."
+            is_mergeable = False
+
         if model_args.use_unsloth:
             assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
             is_mergeable = False
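For illustration, the single-adapter guard can be exercised on its own. A minimal sketch, faking ModelArguments with a SimpleNamespace and a hypothetical adapter path:

# Toy illustration of the guard above; ModelArguments is faked with SimpleNamespace.
from types import SimpleNamespace

model_args = SimpleNamespace(use_kt=True, adapter_name_or_path=["saves/lora-sft"])
if model_args.use_kt:
    assert len(model_args.adapter_name_or_path) == 1, "KTransformers currently only accepts a single adapter."
    is_mergeable = False  # the single adapter is resumed later instead of being merged
print(is_mergeable)  # False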
@@ -182,6 +188,10 @@ def _setup_lora_tuning(
             "token": model_args.hf_hub_token,
         }
 
+        if model_args.use_kt:
+            if model_args.infer_backend != EngineName.KT:
+                raise ValueError("The ktransformers backend must be used to infer an adapter fine-tuned with ktransformers.")
+
         for adapter in adapter_to_merge:
             model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
             model = model.merge_and_unload()
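For context, the merge loop above is the standard PEFT merge-and-unload flow. A minimal standalone sketch outside LLaMA-Factory (model and adapter paths are hypothetical):

# Standalone sketch of merging LoRA adapters into base weights with PEFT.
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")
for adapter in ["path/to/lora-a", "path/to/lora-b"]:
    model = PeftModel.from_pretrained(model, adapter)  # attach the adapter weights
    model = model.merge_and_unload()                   # fold them into the base model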
@@ -190,7 +200,9 @@ def _setup_lora_tuning(
             logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
 
         if adapter_to_resume is not None:  # resume lora training
-            if model_args.use_unsloth:
+            if model_args.use_kt:
+                model = load_kt_peft_model(model_args, model)
+            elif model_args.use_unsloth:
                 model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
             else:
                 model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
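On the non-KT path, resuming uses PEFT directly. A minimal sketch of reloading the last adapter checkpoint for further training (paths are hypothetical):

# Resuming LoRA training by reloading the last adapter checkpoint as trainable.
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")
model = PeftModel.from_pretrained(model, "path/to/last-lora-checkpoint", is_trainable=True)
model.print_trainable_parameters()  # only the LoRA parameters should be trainable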
@@ -203,6 +215,16 @@ def _setup_lora_tuning(
         else:
             target_modules = finetuning_args.lora_target
 
+        if model_args.use_kt:
+            new_list = []
+            for m in target_modules:
+                if m in ('down_proj', 'up_proj', 'gate_proj'):
+                    new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
+                elif m not in ('generate_linear', 'orig_module', 'prefill_linear'):
+                    new_list.append(m)
+
+            target_modules[:] = new_list
+
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
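The remapping above can be read as a small pure function. Pulled out for illustration (the helper name is mine); it qualifies the MLP projections and drops names that only appear inside KTransformers' wrapped linear modules:

# Standalone version of the KTransformers target-module remapping shown above.
def remap_kt_target_modules(target_modules):
    new_list = []
    for m in target_modules:
        if m in ("down_proj", "up_proj", "gate_proj"):
            # Qualify MLP projections so both the dense MLP and shared-expert copies are matched.
            new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
        elif m not in ("generate_linear", "orig_module", "prefill_linear"):
            # Skip KTransformers-internal wrapper names; keep everything else (e.g. attention projections).
            new_list.append(m)
    return new_list

print(remap_kt_target_modules(["q_proj", "down_proj", "orig_module"]))
# ['q_proj', 'mlp.down_proj', 'shared_experts.down_proj']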
@@ -245,7 +267,21 @@ def _setup_lora_tuning(
             "modules_to_save": finetuning_args.additional_target,
         }
 
-        if model_args.use_unsloth:
+        if model_args.use_kt:
+            if finetuning_args.finetuning_type == "oft":
+                raise ValueError("KTransformers is currently not supported for OFT.")
+            if finetuning_args.finetuning_type == "lora":
+                peft_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    **peft_kwargs,
+                )
+            else:
+                raise ValueError("KTransformers is currently only supported for LoRA.")
+
+            model = get_kt_peft_model(model, peft_config)
+            print(f"KT_model:{model}")
+        elif model_args.use_unsloth:
             if finetuning_args.finetuning_type == "oft":
                 raise ValueError("Unsloth is currently not supported for OFT.")
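In the KT branch, peft_kwargs feeds a plain peft.LoraConfig. A sketch with hypothetical hyperparameter values standing in for what peft_kwargs carries (rank, alpha, dropout, target modules, ...):

# Sketch of the LoraConfig built in the KT branch; the values here are hypothetical stand-ins.
from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules=["q_proj", "v_proj", "mlp.down_proj", "shared_experts.down_proj"],
)

The resulting peft_config is then handed to get_kt_peft_model, which wraps the KTransformers model with the LoRA adapters, mirroring what get_unsloth_peft_model does on the Unsloth path.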