diff --git a/examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml b/examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
index b8eceb27d..3fd78d13f 100644
--- a/examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
+++ b/examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
@@ -7,7 +7,7 @@
     prefill_device: "cuda"
 - match:
-    name: "^lm_head$" # regular expression
+    name: "^lm_head$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -18,7 +18,7 @@
     prefill_op: "KLinearTorch"
 # - match:
-#     name: "^model\\.layers\\..*$" # regular expression
+#     name: "^model\\.layers\\..*$" # regular expression
 #     class: torch.nn.Linear # only match modules matching name and class simultaneously
 #   replace:
 #     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -28,7 +28,7 @@
 #     generate_op: "KLinearTorch"
 #     prefill_op: "KLinearTorch"
 - match:
-    name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression
+    name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -77,4 +77,4 @@
   replace:
     class: "ktransformers.operators.models.KQwen3MoeModel"
     kwargs:
-      per_layer_prefill_intput_threshold: 0
\ No newline at end of file
+      per_layer_prefill_intput_threshold: 0
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 9ee13a002..02c100ec8 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -174,6 +174,10 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
+    use_v1_kernels: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use high-performance kernels in training."},
+    )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
         metadata={"help": "Data type for model weights and activations at inference."},
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 10c257a5a..25710c31d 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -213,6 +213,17 @@ def load_model(
     else:
         model.train()

+    # Borrow the v1 kernel plugin mechanism to temporarily apply NPU fusion operators in v0 training.
+    # It is disabled by default and can be removed once the transition period ends.
+    if model_args.use_v1_kernels and is_trainable:
+        logger.warning_rank0(
+            "You are using an experimental kernel feature that is not supported for all models. "
+            "If you encounter any errors, please disable this feature or report the issue."
+        )
+        from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+
+        model = apply_available_kernels(model)
+
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
         param_stats = (
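For reviewers, a minimal sketch of how the new flag could be exercised programmatically. It assumes the load_tokenizer/load_model helpers exported from llamafactory.model and the hparams dataclasses keep their current signatures; the model id and finetuning settings are illustrative only, not part of this change set.

from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model import load_model, load_tokenizer

model_args = ModelArguments(
    model_name_or_path="Qwen/Qwen3-30B-A3B",  # illustrative model id
    use_v1_kernels=True,  # opt in to the experimental v1 kernel plugins
)
finetuning_args = FinetuningArguments(finetuning_type="full")  # illustrative settings

tokenizer = load_tokenizer(model_args)["tokenizer"]
# Kernels are only applied when is_trainable=True, mirroring the guard added in load_model.
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=True)

Gating on both use_v1_kernels and is_trainable leaves inference paths untouched, and the rank-0 warning makes the experimental status of the v1 bridge explicit.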