[model] update kt code (#9406)
@@ -332,6 +332,8 @@ Choose your path:
 > [!NOTE]
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
 >
+> If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish between them. For example, `qwen3` and `qwen3_nothink`.
+>
 > Remember to use the **SAME** template in training and inference.
 >
 > \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check.
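
The naming rule in the note above can be stated in one line. The helper below is hypothetical (it is not part of LLaMA-Factory) and only illustrates how the `_nothink` suffix is meant to be used:

    def pick_template(base: str, reasoning: bool) -> str:
        # reasoning checkpoints use the plain template name, e.g. "qwen3";
        # non-reasoning checkpoints use the `_nothink` variant, e.g. "qwen3_nothink"
        return base if reasoning else f"{base}_nothink"

    assert pick_template("qwen3", reasoning=True) == "qwen3"
    assert pick_template("qwen3", reasoning=False) == "qwen3_nothink"
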
@@ -334,6 +334,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 > [!NOTE]
 > For all "base" models, the `template` argument can be any of `default`, `alpaca`, `vicuna`, etc. But for the "instruct/chat" models, be sure to use the **corresponding template**.
 >
+> If the model has both reasoning and non-reasoning versions, use the `_nothink` suffix to distinguish the templates, e.g. `qwen3` and `qwen3_nothink`.
+>
 > Be sure to use **exactly the same** template for training and inference.
 >
 > \*: You need to install `transformers` from the main branch and set `DISABLE_VERSION_CHECK=1` to skip the version check.
@@ -7,4 +7,4 @@ trust_remote_code: true
 use_kt: true # use KTransformers as LoRA sft backend to inference
 kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
 cpu_infer: 32
 chunk_size: 8192
@@ -6,4 +6,4 @@ trust_remote_code: true
 use_kt: true # use KTransformers as LoRA sft backend to inference
 kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
 cpu_infer: 32
 chunk_size: 8192
@@ -7,4 +7,4 @@ trust_remote_code: true
 use_kt: true # use KTransformers as LoRA sft backend to inference
 kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
 cpu_infer: 32
 chunk_size: 8192
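
The three example configs above share the same four KTransformers keys. A minimal sketch of loading and sanity-checking one of them, assuming PyYAML is available; the file path is illustrative only:

    import yaml

    with open("examples/train_lora/deepseek_lora_sft_kt.yaml") as f:  # illustrative path
        cfg = yaml.safe_load(f)

    assert cfg.get("use_kt") is True                          # use KTransformers as the LoRA SFT backend
    assert cfg.get("kt_optimize_rule", "").endswith(".yaml")  # module-to-kernel placement rules
    print(cfg["cpu_infer"], cfg["chunk_size"])                # CPU workers and prefill chunk length for the KT runtime
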
@@ -6,7 +6,7 @@
       generate_device: "cuda"
       prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -66,4 +66,4 @@
     class: "default"
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
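
The `name` patterns in these rules are ordinary Python regular expressions; the negative lookahead is what keeps `kv_b_proj` on the stock `torch.nn.Linear`. A quick check of which module names the pattern selects:

    import re

    pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

    print(bool(pattern.match("model.layers.3.mlp.gate_proj")))        # True  -> replaced by KTransformersLinear
    print(bool(pattern.match("model.layers.3.self_attn.q_proj")))     # True  -> replaced by KTransformersLinear
    print(bool(pattern.match("model.layers.3.self_attn.kv_b_proj")))  # False -> left as torch.nn.Linear
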
@@ -6,7 +6,7 @@
       generate_device: "cuda"
       prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -65,4 +65,4 @@
     class: "default"
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
@@ -24,7 +24,7 @@
       prefill_device: "cuda:1"

 - match:
     name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -35,7 +35,7 @@
       prefill_op: "KLinearTorch"

 - match:
     name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -44,7 +44,7 @@
       prefill_device: "cuda:1"
       generate_op: "KLinearTorch"
       prefill_op: "KLinearTorch"

 - match:
     name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
@@ -108,7 +108,7 @@
     class: "ktransformers.operators.models.KDeepseekV2Model"
     kwargs:
       per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
       transfer_map:
         10: "cuda:1"

 - match:
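
In the multi-GPU rules above, the split between devices is encoded in the layer-index regexes, and `transfer_map` hands the hidden states over at layer 10. A sketch of how the two patterns partition module names; the `cuda:0` placement of the first group is assumed (it is not shown in these hunks), while `cuda:1` follows the `transfer_map` entry:

    import re

    first_group = re.compile(r"^model\.layers\.(0|[1-9])\.")     # layers 0-9
    second_group = re.compile(r"^model\.layers\.([12][0-9])\.")  # layers 10-29

    for layer in (0, 9, 10, 26):
        name = f"model.layers.{layer}.mlp"
        if first_group.match(name):
            print(layer, "cuda:0")  # assumed placement of the first group
        elif second_group.match(name):
            print(layer, "cuda:1")  # consistent with transfer_map: 10: "cuda:1"
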
@@ -6,7 +6,7 @@
       generate_device: "cuda"
       prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -66,4 +66,4 @@
     class: "default"
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
@@ -6,7 +6,7 @@
       generate_device: "cuda"
       prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -65,4 +65,4 @@
     class: "default"
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
@@ -6,7 +6,7 @@
       generate_device: "cuda"
       prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -65,4 +65,4 @@
     class: "default"
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
@@ -7,7 +7,7 @@
       prefill_device: "cuda"

 - match:
     name: "^lm_head$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -18,7 +18,7 @@
       prefill_op: "KLinearTorch"

 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -74,4 +74,4 @@
     class: "default"
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
@@ -24,7 +24,7 @@
       prefill_device: "cuda:1"

 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -35,7 +35,7 @@
       prefill_op: "KLinearTorch"

 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -44,7 +44,7 @@
       prefill_device: "cuda:1"
       generate_op: "KLinearTorch"
       prefill_op: "KLinearTorch"

 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
@@ -125,7 +125,7 @@
     class: "ktransformers.operators.models.KDeepseekV2Model"
     kwargs:
       per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
       transfer_map:
         30: "cuda:1"

 - match:
@@ -7,7 +7,7 @@
       prefill_device: "cuda"

 - match:
     name: "^lm_head$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -18,7 +18,7 @@
       prefill_op: "KLinearTorch"

 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -74,4 +74,4 @@
     class: "default"
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
@@ -65,8 +65,7 @@ class KTransformersEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)

         self.model = load_model(
-            self.tokenizer, model_args, finetuning_args,
-            is_trainable=False, add_valuehead=(not self.can_generate)
+            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )

         self.generating_args = generating_args.to_dict()
@@ -143,14 +142,14 @@ class KTransformersEngine(BaseEngine):
         input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
         if self.force_think:
             think = torch.tensor(
-                [self.tokenizer.encode("<think>\n", add_special_tokens=False)],
-                dtype=torch.long, device=device
+                [self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
             )
             input_tensor = torch.cat([input_tensor, think], dim=1)

         use_flashinfer = (
             platform.system() != "Windows"
-            and getattr(self.model.config, "architectures", [""])[0] in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
+            and getattr(self.model.config, "architectures", [""])[0]
+            in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
             and flashinfer_enabled
             and get_compute_capability() >= 8
             and device_manager.gpu_vendor == GPUVendor.NVIDIA
@@ -159,19 +158,32 @@ class KTransformersEngine(BaseEngine):
         def make_gen():
             if use_flashinfer:
                 return prefill_and_generate_capture(
-                    self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
-                    mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
+                    self.model,
+                    self.tokenizer,
+                    input_tensor,
+                    max_tokens,
+                    self.use_cuda_graph,
+                    mode=self.mode,
+                    force_think=self.force_think,
+                    chunk_size=self.chunk_size,
                     use_flashinfer_mla=True,
                     num_heads=self.model.config.num_attention_heads,
                     head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
                     head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
-                    q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0) + getattr(self.model.config, "qk_nope_head_dim", 0),
+                    q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
+                    + getattr(self.model.config, "qk_nope_head_dim", 0),
                     echo_stream=False,
                 )
             else:
                 return prefill_and_generate_capture(
-                    self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
-                    mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
+                    self.model,
+                    self.tokenizer,
+                    input_tensor,
+                    max_tokens,
+                    self.use_cuda_graph,
+                    mode=self.mode,
+                    force_think=self.force_think,
+                    chunk_size=self.chunk_size,
                     echo_stream=False,
                 )

@@ -182,9 +194,11 @@ class KTransformersEngine(BaseEngine):
             try:
                 gen = make_gen()
                 if hasattr(gen, "__aiter__"):
+
                     async def drain_async():
                         async for t in gen:
                             loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
+
                     asyncio.run(drain_async())
                 elif hasattr(gen, "__iter__"):
                     for t in gen:
@@ -252,7 +266,7 @@ class KTransformersEngine(BaseEngine):
         async with self.semaphore:
             produced = ""
             async for t in self._generate(messages, system, tools, **input_kwargs):
-                delta = t[len(produced):] if t.startswith(produced) else t
+                delta = t[len(produced) :] if t.startswith(produced) else t
                 produced = t
                 if delta:
                     yield delta
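
The reformatted `delta` line is the heart of the engine's streaming contract: the underlying generator yields the cumulative text so far, and the engine converts it into incremental chunks. A self-contained sketch of that logic:

    def to_deltas(cumulative_texts):
        produced = ""
        for t in cumulative_texts:
            delta = t[len(produced) :] if t.startswith(produced) else t
            produced = t
            if delta:
                yield delta

    assert list(to_deltas(["He", "Hello", "Hello!"])) == ["He", "llo", "!"]
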
@@ -616,7 +616,14 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         logger.info_rank0(f"Using default system message: {data_args.default_system}.")
         template.default_system = data_args.default_system

-    template.enable_thinking = data_args.enable_thinking
+    if isinstance(template, ReasoningTemplate):
+        logger.warning_rank0(
+            "You are using reasoning template, "
+            "please add `_nothink` suffix if the model is not a reasoning model. "
+            "e.g., qwen3_vl_nothink"
+        )
+    template.enable_thinking = data_args.enable_thinking
+
     template.fix_special_tokens(tokenizer)
     template.fix_jinja_template(tokenizer)
     return template
@@ -312,9 +312,11 @@ def use_openmind() -> bool:
 def use_ray() -> bool:
     return is_env_enabled("USE_RAY")

+
 def use_kt() -> bool:
     return is_env_enabled("USE_KT")

+
 def find_available_port() -> int:
     r"""Find an available port on the local machine."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
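
`use_kt()` mirrors `use_ray()`: a plain environment toggle read at launch time. A sketch with assumed semantics for `is_env_enabled` (accepting values such as `1`, `true`, or `y`):

    import os

    def is_env_enabled(name: str, default: str = "0") -> bool:
        # assumed semantics of the helper used above
        return os.getenv(name, default).lower() in ("true", "y", "1")

    def use_kt() -> bool:
        return is_env_enabled("USE_KT")

    os.environ["USE_KT"] = "1"
    assert use_kt()  # an environment switch consulted by the launcher
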
@@ -110,6 +110,7 @@ def is_starlette_available():
 def is_transformers_version_greater_than(content: str):
     return _get_package_version("transformers") >= version.parse(content)

+
 @lru_cache
 def is_torch_version_greater_than(content: str):
     return _get_package_version("torch") >= version.parse(content)
@@ -439,6 +439,7 @@ class SwanLabArguments:
         metadata={"help": "The Lark(飞书) secret for SwanLab."},
     )

+
 @dataclass
 class FinetuningArguments(
     SwanLabArguments,
@@ -485,7 +485,9 @@ class KTransformersArguments:
     )
     kt_optimize_rule: Optional[str] = field(
         default=None,
-        metadata={"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."},
+        metadata={
+            "help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
+        },
     )
     cpu_infer: Optional[int] = field(
         default=32,
@@ -517,9 +519,16 @@ class KTransformersArguments:
         metadata={"help": "Force-Think Toggle For The KT Engine."},
     )

+
 @dataclass
 class ModelArguments(
-    SGLangArguments, VllmArguments, KTransformersArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
+    SGLangArguments,
+    VllmArguments,
+    KTransformersArguments,
+    ExportArguments,
+    ProcessorArguments,
+    QuantizationArguments,
+    BaseModelArguments,
 ):
     r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.

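
For reference, the multiple-inheritance pattern used by `ModelArguments` composes independent argument groups into one dataclass. A trimmed sketch (the field set is heavily reduced and the defaults are for illustration only):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class KTransformersArguments:
        use_kt: bool = False
        kt_optimize_rule: Optional[str] = None
        cpu_infer: Optional[int] = 32

    @dataclass
    class BaseModelArguments:
        model_name_or_path: Optional[str] = None

    @dataclass
    class ModelArguments(KTransformersArguments, BaseModelArguments):
        """Simplified stand-in; the real class mixes in several more argument groups."""

    args = ModelArguments(model_name_or_path="deepseek-ai/DeepSeek-V2-Lite", use_kt=True)
    assert args.cpu_infer == 32
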
@@ -90,6 +90,7 @@ class RayArguments:
         elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
             self.ray_storage_filesystem = fs.GcsFileSystem()

+
 @dataclass
 class TrainingArguments(RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""
@@ -57,7 +57,9 @@ def launch():
     if is_env_enabled("USE_MCA"):  # force use torchrun
         os.environ["FORCE_TORCHRUN"] = "1"

-    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())):
+    if command == "train" and (
+        is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
+    ):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
         node_rank = os.getenv("NODE_RANK", "0")
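
The wrapped condition decides when the CLI escalates to a torchrun launch. Restated as a stand-alone predicate:

    def should_torchrun(command, force_torchrun, device_count, use_ray, use_kt):
        return command == "train" and (force_torchrun or (device_count > 1 and not use_ray and not use_kt))

    assert should_torchrun("train", False, 8, False, False)     # multi-device training -> torchrun
    assert not should_torchrun("train", False, 8, False, True)  # KTransformers run: no torchrun wrapper
    assert not should_torchrun("chat", True, 1, False, False)   # only the train command is wrapped
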
@@ -167,7 +167,7 @@ def _setup_lora_tuning(
         is_mergeable = False

     if model_args.use_kt:
-        assert len(model_args.adapter_name_or_path) == 1, "Up to now, KTransformers model only accepts a single adapter, for more features, you can contact with us."
+        assert len(model_args.adapter_name_or_path) == 1, "KTransformers model only accepts a single adapter"
         is_mergeable = False

     if model_args.use_unsloth:
@@ -190,7 +190,9 @@ def _setup_lora_tuning(

         if model_args.use_kt:
             if model_args.infer_backend != EngineName.KT:
-                raise ValueError("We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers.")
+                raise ValueError(
+                    "We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
+                )

         for adapter in adapter_to_merge:
             model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
@@ -218,9 +220,9 @@ def _setup_lora_tuning(
         if model_args.use_kt:
             new_list = []
             for m in target_modules:
-                if m in ('down_proj', 'up_proj', 'gate_proj'):
+                if m in ("down_proj", "up_proj", "gate_proj"):
                     new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
-                elif m not in ('generate_linear', 'orig_module', 'prefill_linear'):
+                elif m not in ("generate_linear", "orig_module", "prefill_linear"):
                     new_list.append(m)

             target_modules[:] = new_list
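
The quote normalization above does not change behaviour; for clarity, the KT-specific rewrite of `target_modules` can be read as a pure function:

    def remap_target_modules(target_modules):
        new_list = []
        for m in target_modules:
            if m in ("down_proj", "up_proj", "gate_proj"):
                # MLP projections are expanded to both the dense path and the shared experts
                new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
            elif m not in ("generate_linear", "orig_module", "prefill_linear"):
                # KTransformers wrapper internals are excluded from the LoRA targets
                new_list.append(m)
        return new_list

    assert remap_target_modules(["q_proj", "gate_proj", "orig_module"]) == [
        "q_proj",
        "mlp.gate_proj",
        "shared_experts.gate_proj",
    ]
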
@@ -146,6 +146,7 @@ def load_model(
     lazy_load = False
     if model_args.use_kt:
         from ktransformers.sft.monkey_patch_torch_module import install_patch
+
         install_patch()
         model = load_kt_pretrained_model(config, model_args)
     elif model_args.use_unsloth:
@@ -59,6 +59,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == AttentionFunction.FA2:
         from transformers import is_torch_npu_available
+
         if not (is_flash_attn_2_available() or is_torch_npu_available()):
             logger.warning_rank0("FlashAttention-2 is not installed.")
             return
@@ -13,7 +13,7 @@
 # limitations under the License.

 import importlib.util as _u
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 import torch

@@ -43,6 +43,7 @@ if KT_AVAILABLE:

 logger = logging.get_logger(__name__)

+
 def _get_kt_kwargs(
     config: "PretrainedConfig",
     model_name_or_path: str,
@@ -64,9 +65,7 @@ def _get_kt_kwargs(
     }


-def load_kt_pretrained_model(
-    config: "PretrainedConfig", model_args: "ModelArguments"
-) -> Optional["PreTrainedModel"]:
+def load_kt_pretrained_model(config: "PretrainedConfig", model_args: "ModelArguments") -> "PreTrainedModel":
     r"""Optionally load pretrained model with KTransformers. Used in training."""
     custom_models = {
         "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -79,7 +78,7 @@ def load_kt_pretrained_model(
     Config().chunk_size = model_args.chunk_size
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)

-    if model_args.mode == 'long_context':
+    if model_args.mode == "long_context":
         assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
         torch.set_default_dtype(torch.float16)
     else:
@@ -88,9 +87,7 @@ def load_kt_pretrained_model(
     with torch.device("meta"):
         if config.architectures[0] in custom_models:
             print("using custom modeling_xxx.py.")
-            if (
-                "Qwen2Moe" in config.architectures[0]
-            ): # Qwen2Moe must use flash_attention_2 to avoid overflow.
+            if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
                 config._attn_implementation = "flash_attention_2"
             if "Llama" in config.architectures[0]:
                 config._attn_implementation = "eager"
@@ -115,21 +112,17 @@ def load_kt_pretrained_model(
     return model


-def get_kt_peft_model(
-    model: "PreTrainedModel", peft_kwargs: dict[str, Any]
-) -> "PreTrainedModel":
+def get_kt_peft_model(model: "PreTrainedModel", peft_kwargs: dict[str, Any]) -> "PreTrainedModel":
     r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
     from ktransformers.sft.peft_utils.mapping import get_peft_model

     return get_peft_model(model, peft_kwargs)


-def load_kt_peft_model(
-    model_args: "ModelArguments", model: "PreTrainedModel",
-) -> "PreTrainedModel":
+def load_kt_peft_model(model_args: "ModelArguments", model: "PreTrainedModel") -> "PreTrainedModel":
     r"""Load peft model with KTransformers. Used in both training and inference."""
     load_adapter_name_or_path = model_args.adapter_name_or_path[0]
-    if load_adapter_name_or_path.endswith('.gguf'):
+    if load_adapter_name_or_path.endswith(".gguf"):
         inject_lora_layer(model, load_adapter_name_or_path)
         adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
         load_weights(model, adapter_gguf_loader, adapter_gguf=True)
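
The collapsed `if` above keeps the same architecture checks. Read in isolation, the selection is plain string matching on `config.architectures[0]`; a stand-alone restatement (the real code mutates the config in place):

    from typing import Optional

    def pick_attn_implementation(architecture: str) -> Optional[str]:
        if "Qwen2Moe" in architecture:
            return "flash_attention_2"  # Qwen2Moe must use flash_attention_2 to avoid overflow
        if "Llama" in architecture:
            return "eager"
        return None

    assert pick_attn_implementation("Qwen2MoeForCausalLM") == "flash_attention_2"
    assert pick_attn_implementation("LlamaForCausalLM") == "eager"
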
@@ -47,6 +47,7 @@ def run_sft(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     from ktransformers.util.globals import GLOBAL_CONFIG
+
     GLOBAL_CONFIG._config["mod"] = "sft"

     if getattr(model, "is_quantized", False) and not training_args.do_train:
@@ -66,12 +67,13 @@ def run_sft(
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
-        raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
+        raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
     elif finetuning_args.compute_accuracy:
-        raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
+        raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")

     # Initialize our Trainer
     from ktransformers.sft.lora import KTrainer
+
     trainer = KTrainer(
         model=model,
         args=training_args,
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import infer_optim_dtype
-from ..extras.packages import is_kt_available, is_mcore_adapter_available, is_ray_available
+from ..extras.packages import is_mcore_adapter_available, is_ray_available
 from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -86,12 +86,12 @@ def _training_function(config: dict[str, Any]) -> None:
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
         if model_args.use_kt:
-            if not is_kt_available():
-                raise ImportError("KTransformers is not installed. Please install it with `pip install ktransformers`.")
             from .ksft.workflow import run_sft as run_sft_kt

             run_sft_kt(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+        else:
             run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+
     elif finetuning_args.stage == "rm":
         run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "ppo":