[model] update kt code (#9406)

Yaowei Zheng
2025-11-05 15:27:22 +08:00
committed by GitHub
parent 56f45e826f
commit eaf963f67f
28 changed files with 108 additions and 68 deletions

View File

@@ -332,6 +332,8 @@ Choose your path:
> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
>
+> If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish between them. For example, `qwen3` and `qwen3_nothink`.
+>
> Remember to use the **SAME** template in training and inference.
>
> \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check.
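As an illustration (hypothetical file name and fields other than `template`; only the `qwen3_nothink` value comes from this change), a minimal SFT config pinning the non-reasoning variant could look like:

```yaml
### hypothetical minimal SFT config; only `template` is prescribed by the note above
model_name_or_path: Qwen/Qwen3-8B
stage: sft
finetuning_type: lora
template: qwen3_nothink  # non-reasoning variant; reuse the SAME value at inference
```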

View File

@@ -334,6 +334,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
> [!NOTE]
> For all "base" models, the `template` argument can be any of `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
>
+> If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish the templates. For example, `qwen3` and `qwen3_nothink`.
+>
> Remember to use the **SAME** template in training and inference.
>
> \*: You need to install `transformers` from the main branch and use `DISABLE_VERSION_CHECK=1` to skip the version check.

View File

@@ -7,4 +7,4 @@ trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
-chunk_size: 8192
+chunk_size: 8192

View File

@@ -6,4 +6,4 @@ trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
-chunk_size: 8192
+chunk_size: 8192

View File

@@ -7,4 +7,4 @@ trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
-chunk_size: 8192
+chunk_size: 8192

View File

@@ -6,7 +6,7 @@
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -66,4 +66,4 @@
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
prefill_device: "cpu"

View File

@@ -6,7 +6,7 @@
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -65,4 +65,4 @@
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
prefill_device: "cpu"

View File

@@ -24,7 +24,7 @@
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -35,7 +35,7 @@
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -44,7 +44,7 @@
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
@@ -108,7 +108,7 @@
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
-transfer_map:
+transfer_map:
10: "cuda:1"
- match:
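Here `(0|[1-9])` covers layers 0-9 on `cuda:0` and `[12][0-9]` covers layers 10-29 on `cuda:1`, which lines up with `transfer_map: 10: "cuda:1"` handing activations over at layer 10. A quick sanity check of that split (illustrative sketch):

```python
import re

# Patterns from the rule above; the first range lives on cuda:0, the second on cuda:1.
GPU0 = re.compile(r"^model\.layers\.(0|[1-9])\.")
GPU1 = re.compile(r"^model\.layers\.([12][0-9])\.")

split = {i: ("cuda:0" if GPU0.match(f"model.layers.{i}.mlp") else "cuda:1")
         for i in range(30)}
assert all(dev == "cuda:0" for i, dev in split.items() if i < 10)
assert all(GPU1.match(f"model.layers.{i}.mlp") for i in range(10, 30))
# Layer 10 is the first layer on cuda:1, matching `transfer_map: 10: "cuda:1"`.
```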

View File

@@ -6,7 +6,7 @@
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -66,4 +66,4 @@
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
prefill_device: "cpu"

View File

@@ -6,7 +6,7 @@
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -65,4 +65,4 @@
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
prefill_device: "cpu"

View File

@@ -6,7 +6,7 @@
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -65,4 +65,4 @@
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
prefill_device: "cpu"

View File

@@ -7,7 +7,7 @@
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -18,7 +18,7 @@
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -74,4 +74,4 @@
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
prefill_device: "cpu"

View File

@@ -24,7 +24,7 @@
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -35,7 +35,7 @@
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -44,7 +44,7 @@
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
@@ -125,7 +125,7 @@
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
-transfer_map:
+transfer_map:
30: "cuda:1"
- match:

View File

@@ -7,7 +7,7 @@
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -18,7 +18,7 @@
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -74,4 +74,4 @@
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
prefill_device: "cpu"

View File

@@ -65,8 +65,7 @@ class KTransformersEngine(BaseEngine):
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.model = load_model(
-self.tokenizer, model_args, finetuning_args,
-is_trainable=False, add_valuehead=(not self.can_generate)
+self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.generating_args = generating_args.to_dict()
@@ -143,14 +142,14 @@ class KTransformersEngine(BaseEngine):
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
if self.force_think:
think = torch.tensor(
-[self.tokenizer.encode("<think>\n", add_special_tokens=False)],
-dtype=torch.long, device=device
+[self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
)
input_tensor = torch.cat([input_tensor, think], dim=1)
use_flashinfer = (
platform.system() != "Windows"
-and getattr(self.model.config, "architectures", [""])[0] in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
+and getattr(self.model.config, "architectures", [""])[0]
+in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
and flashinfer_enabled
and get_compute_capability() >= 8
and device_manager.gpu_vendor == GPUVendor.NVIDIA
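Paraphrased as a standalone predicate (an illustrative restatement; the names are taken from the diff), the FlashInfer MLA path requires a non-Windows host, a DeepSeek V2/V3 architecture, FlashInfer available, compute capability >= 8, and an NVIDIA GPU:

```python
# Illustrative paraphrase of the `use_flashinfer` gate above.
def can_use_flashinfer(system: str, arch: str, flashinfer_enabled: bool,
                       compute_capability: int, vendor: str) -> bool:
    return (
        system != "Windows"
        and arch in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
        and flashinfer_enabled
        and compute_capability >= 8
        and vendor == "NVIDIA"
    )
```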
@@ -159,19 +158,32 @@ class KTransformersEngine(BaseEngine):
def make_gen():
if use_flashinfer:
return prefill_and_generate_capture(
-self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
-mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
+self.model,
+self.tokenizer,
+input_tensor,
+max_tokens,
+self.use_cuda_graph,
+mode=self.mode,
+force_think=self.force_think,
+chunk_size=self.chunk_size,
use_flashinfer_mla=True,
num_heads=self.model.config.num_attention_heads,
head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
-q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0) + getattr(self.model.config, "qk_nope_head_dim", 0),
+q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
++ getattr(self.model.config, "qk_nope_head_dim", 0),
echo_stream=False,
)
else:
return prefill_and_generate_capture(
-self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
-mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
+self.model,
+self.tokenizer,
+input_tensor,
+max_tokens,
+self.use_cuda_graph,
+mode=self.mode,
+force_think=self.force_think,
+chunk_size=self.chunk_size,
echo_stream=False,
)
@@ -182,9 +194,11 @@ class KTransformersEngine(BaseEngine):
try:
gen = make_gen()
if hasattr(gen, "__aiter__"):
async def drain_async():
async for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
asyncio.run(drain_async())
elif hasattr(gen, "__iter__"):
for t in gen:
@@ -252,7 +266,7 @@ class KTransformersEngine(BaseEngine):
async with self.semaphore:
produced = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
-delta = t[len(produced):] if t.startswith(produced) else t
+delta = t[len(produced) :] if t.startswith(produced) else t
produced = t
if delta:
yield delta
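The delta computation exists because the generator yields cumulative snapshots of the output text; the engine forwards only the new suffix, falling back to the whole string if the prefix no longer matches. A self-contained sketch of that behavior (illustrative):

```python
def stream_deltas(snapshots):
    """Turn cumulative text snapshots into incremental deltas (illustrative)."""
    produced = ""
    for t in snapshots:
        delta = t[len(produced):] if t.startswith(produced) else t
        produced = t
        if delta:
            yield delta


assert list(stream_deltas(["He", "Hello", "Hello!"])) == ["He", "llo", "!"]
```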

View File

@@ -616,7 +616,14 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
logger.info_rank0(f"Using default system message: {data_args.default_system}.")
template.default_system = data_args.default_system
-template.enable_thinking = data_args.enable_thinking
+if isinstance(template, ReasoningTemplate):
+logger.warning_rank0(
+"You are using reasoning template, "
+"please add `_nothink` suffix if the model is not a reasoning model. "
+"e.g., qwen3_vl_nothink"
+)
+template.enable_thinking = data_args.enable_thinking
template.fix_special_tokens(tokenizer)
template.fix_jinja_template(tokenizer)
return template

View File

@@ -312,9 +312,11 @@ def use_openmind() -> bool:
def use_ray() -> bool:
return is_env_enabled("USE_RAY")
def use_kt() -> bool:
return is_env_enabled("USE_KT")
def find_available_port() -> int:
r"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

View File

@@ -110,6 +110,7 @@ def is_starlette_available():
def is_transformers_version_greater_than(content: str):
return _get_package_version("transformers") >= version.parse(content)
@lru_cache
def is_torch_version_greater_than(content: str):
return _get_package_version("torch") >= version.parse(content)

View File

@@ -439,6 +439,7 @@ class SwanLabArguments:
metadata={"help": "The Lark(飞书) secret for SwanLab."},
)
@dataclass
class FinetuningArguments(
SwanLabArguments,

View File

@@ -485,7 +485,9 @@ class KTransformersArguments:
)
kt_optimize_rule: Optional[str] = field(
default=None,
metadata={"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."},
metadata={
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
},
)
cpu_infer: Optional[int] = field(
default=32,
@@ -517,9 +519,16 @@ class KTransformersArguments:
metadata={"help": "Force-Think Toggle For The KT Engine."},
)
@dataclass
class ModelArguments(
-SGLangArguments, VllmArguments, KTransformersArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
+SGLangArguments,
+VllmArguments,
+KTransformersArguments,
+ExportArguments,
+ProcessorArguments,
+QuantizationArguments,
+BaseModelArguments,
):
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.

View File

@@ -90,6 +90,7 @@ class RayArguments:
elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
self.ray_storage_filesystem = fs.GcsFileSystem()
@dataclass
class TrainingArguments(RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer."""

View File

@@ -57,7 +57,9 @@ def launch():
if is_env_enabled("USE_MCA"): # force use torchrun
os.environ["FORCE_TORCHRUN"] = "1"
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())):
if command == "train" and (
is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
):
# launch distributed training
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
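With this condition, KTransformers training opts out of torchrun even on multi-GPU hosts, same as Ray. The decision reduces to a small predicate (illustrative helper; the logic is the condition shown above):

```python
def should_launch_torchrun(force_torchrun: bool, device_count: int,
                           use_ray: bool, use_kt: bool) -> bool:
    # Mirrors the launcher condition: Ray and KT manage their own workers.
    return force_torchrun or (device_count > 1 and not use_ray and not use_kt)


assert should_launch_torchrun(False, 8, False, True) is False  # KT: single process
assert should_launch_torchrun(False, 8, False, False) is True  # plain multi-GPU DDP
```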

View File

@@ -167,7 +167,7 @@ def _setup_lora_tuning(
is_mergeable = False
if model_args.use_kt:
assert len(model_args.adapter_name_or_path) == 1, "Up to now, KTransformers model only accepts a single adapter, for more features, you can contact with us."
assert len(model_args.adapter_name_or_path) == 1, "KTransformers model only accepts a single adapter"
is_mergeable = False
if model_args.use_unsloth:
@@ -190,7 +190,9 @@ def _setup_lora_tuning(
if model_args.use_kt:
if model_args.infer_backend != EngineName.KT:
raise ValueError("We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers.")
raise ValueError(
"We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
)
for adapter in adapter_to_merge:
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
@@ -218,9 +220,9 @@ def _setup_lora_tuning(
if model_args.use_kt:
new_list = []
for m in target_modules:
-if m in ('down_proj', 'up_proj', 'gate_proj'):
+if m in ("down_proj", "up_proj", "gate_proj"):
new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
-elif m not in ('generate_linear', 'orig_module', 'prefill_linear'):
+elif m not in ("generate_linear", "orig_module", "prefill_linear"):
new_list.append(m)
target_modules[:] = new_list
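Concretely, the rewrite expands each MLP projection into its dense and shared-expert paths and drops KTransformers-internal wrapper names. A worked example (illustrative input list):

```python
target_modules = ["q_proj", "down_proj", "orig_module"]
new_list = []
for m in target_modules:
    if m in ("down_proj", "up_proj", "gate_proj"):
        new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])  # dense + shared experts
    elif m not in ("generate_linear", "orig_module", "prefill_linear"):
        new_list.append(m)  # keep ordinary targets like q_proj
assert new_list == ["q_proj", "mlp.down_proj", "shared_experts.down_proj"]
```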

View File

@@ -146,6 +146,7 @@ def load_model(
lazy_load = False
if model_args.use_kt:
from ktransformers.sft.monkey_patch_torch_module import install_patch
install_patch()
model = load_kt_pretrained_model(config, model_args)
elif model_args.use_unsloth:

View File

@@ -59,6 +59,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == AttentionFunction.FA2:
from transformers import is_torch_npu_available
if not (is_flash_attn_2_available() or is_torch_npu_available()):
logger.warning_rank0("FlashAttention-2 is not installed.")
return

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import importlib.util as _u
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
import torch
@@ -43,6 +43,7 @@ if KT_AVAILABLE:
logger = logging.get_logger(__name__)
def _get_kt_kwargs(
config: "PretrainedConfig",
model_name_or_path: str,
@@ -64,9 +65,7 @@ def _get_kt_kwargs(
}
-def load_kt_pretrained_model(
-config: "PretrainedConfig", model_args: "ModelArguments"
-) -> Optional["PreTrainedModel"]:
+def load_kt_pretrained_model(config: "PretrainedConfig", model_args: "ModelArguments") -> "PreTrainedModel":
r"""Optionally load pretrained model with KTransformers. Used in training."""
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -79,7 +78,7 @@ def load_kt_pretrained_model(
Config().chunk_size = model_args.chunk_size
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
-if model_args.mode == 'long_context':
+if model_args.mode == "long_context":
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
else:
@@ -88,9 +87,7 @@ def load_kt_pretrained_model(
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
-if (
-"Qwen2Moe" in config.architectures[0]
-): # Qwen2Moe must use flash_attention_2 to avoid overflow.
+if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
@@ -115,21 +112,17 @@ def load_kt_pretrained_model(
return model
-def get_kt_peft_model(
-model: "PreTrainedModel", peft_kwargs: dict[str, Any]
-) -> "PreTrainedModel":
+def get_kt_peft_model(model: "PreTrainedModel", peft_kwargs: dict[str, Any]) -> "PreTrainedModel":
r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
from ktransformers.sft.peft_utils.mapping import get_peft_model
return get_peft_model(model, peft_kwargs)
-def load_kt_peft_model(
-model_args: "ModelArguments", model: "PreTrainedModel",
-) -> "PreTrainedModel":
+def load_kt_peft_model(model_args: "ModelArguments", model: "PreTrainedModel") -> "PreTrainedModel":
r"""Load peft model with KTransformers. Used in both training and inference."""
load_adapter_name_or_path = model_args.adapter_name_or_path[0]
-if load_adapter_name_or_path.endswith('.gguf'):
+if load_adapter_name_or_path.endswith(".gguf"):
inject_lora_layer(model, load_adapter_name_or_path)
adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
load_weights(model, adapter_gguf_loader, adapter_gguf=True)

View File

@@ -47,6 +47,7 @@ def run_sft(
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
from ktransformers.util.globals import GLOBAL_CONFIG
GLOBAL_CONFIG._config["mod"] = "sft"
if getattr(model, "is_quantized", False) and not training_args.do_train:
@@ -66,12 +67,13 @@ def run_sft(
# Metric utils
metric_module = {}
if training_args.predict_with_generate:
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
elif finetuning_args.compute_accuracy:
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
# Initialize our Trainer
from ktransformers.sft.lora import KTrainer
trainer = KTrainer(
model=model,
args=training_args,

View File

@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
-from ..extras.packages import is_kt_available, is_mcore_adapter_available, is_ray_available
+from ..extras.packages import is_mcore_adapter_available, is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -86,12 +86,12 @@ def _training_function(config: dict[str, Any]) -> None:
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
if model_args.use_kt:
-if not is_kt_available():
-raise ImportError("KTransformers is not installed. Please install it with `pip install ktransformers`.")
from .ksft.workflow import run_sft as run_sft_kt
run_sft_kt(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
-run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
else:
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "ppo":