[model] update kt code (#9406)
This commit is contained in:
@@ -332,6 +332,8 @@ Choose your path:
|
||||
> [!NOTE]
|
||||
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||
>
|
||||
> If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish between them. For example, `qwen3` and `qwen3_nothink`.
|
||||
>
|
||||
> Remember to use the **SAME** template in training and inference.
|
||||
>
|
||||
> \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check.
|
||||
|
||||
@@ -334,6 +334,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
> [!NOTE]
|
||||
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||
>
|
||||
> 如果模型有推理 / 非推理两个版本,请使用 `_nothink` 后缀来区分不同的模板。例如 `qwen3` 和 `qwen3_nothink`。
|
||||
>
|
||||
> 请务必在训练和推理时采用**完全一致**的模板。
|
||||
>
|
||||
> \*:您需要从 main 分支安装 `transformers` 并使用 `DISABLE_VERSION_CHECK=1` 来跳过版本检查。
|
||||
|
||||
@@ -7,4 +7,4 @@ trust_remote_code: true
|
||||
use_kt: true # use KTransformers as LoRA sft backend to inference
|
||||
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
chunk_size: 8192
|
||||
|
||||
@@ -6,4 +6,4 @@ trust_remote_code: true
|
||||
use_kt: true # use KTransformers as LoRA sft backend to inference
|
||||
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
chunk_size: 8192
|
||||
|
||||
@@ -7,4 +7,4 @@ trust_remote_code: true
|
||||
use_kt: true # use KTransformers as LoRA sft backend to inference
|
||||
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
chunk_size: 8192
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -66,4 +66,4 @@
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -65,4 +65,4 @@
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -35,7 +35,7 @@
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -44,7 +44,7 @@
|
||||
prefill_device: "cuda:1"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
@@ -108,7 +108,7 @@
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||
transfer_map:
|
||||
transfer_map:
|
||||
10: "cuda:1"
|
||||
|
||||
- match:
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -66,4 +66,4 @@
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -65,4 +65,4 @@
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -65,4 +65,4 @@
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^lm_head$" # regular expression
|
||||
name: "^lm_head$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -18,7 +18,7 @@
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -74,4 +74,4 @@
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -35,7 +35,7 @@
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -44,7 +44,7 @@
|
||||
prefill_device: "cuda:1"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
@@ -125,7 +125,7 @@
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||
transfer_map:
|
||||
transfer_map:
|
||||
30: "cuda:1"
|
||||
|
||||
- match:
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^lm_head$" # regular expression
|
||||
name: "^lm_head$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -18,7 +18,7 @@
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
@@ -74,4 +74,4 @@
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
@@ -65,8 +65,7 @@ class KTransformersEngine(BaseEngine):
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
||||
|
||||
self.model = load_model(
|
||||
self.tokenizer, model_args, finetuning_args,
|
||||
is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
)
|
||||
|
||||
self.generating_args = generating_args.to_dict()
|
||||
@@ -143,14 +142,14 @@ class KTransformersEngine(BaseEngine):
|
||||
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
|
||||
if self.force_think:
|
||||
think = torch.tensor(
|
||||
[self.tokenizer.encode("<think>\n", add_special_tokens=False)],
|
||||
dtype=torch.long, device=device
|
||||
[self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
|
||||
)
|
||||
input_tensor = torch.cat([input_tensor, think], dim=1)
|
||||
|
||||
use_flashinfer = (
|
||||
platform.system() != "Windows"
|
||||
and getattr(self.model.config, "architectures", [""])[0] in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
|
||||
and getattr(self.model.config, "architectures", [""])[0]
|
||||
in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
|
||||
and flashinfer_enabled
|
||||
and get_compute_capability() >= 8
|
||||
and device_manager.gpu_vendor == GPUVendor.NVIDIA
|
||||
@@ -159,19 +158,32 @@ class KTransformersEngine(BaseEngine):
|
||||
def make_gen():
|
||||
if use_flashinfer:
|
||||
return prefill_and_generate_capture(
|
||||
self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
|
||||
mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
input_tensor,
|
||||
max_tokens,
|
||||
self.use_cuda_graph,
|
||||
mode=self.mode,
|
||||
force_think=self.force_think,
|
||||
chunk_size=self.chunk_size,
|
||||
use_flashinfer_mla=True,
|
||||
num_heads=self.model.config.num_attention_heads,
|
||||
head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
|
||||
head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
|
||||
q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0) + getattr(self.model.config, "qk_nope_head_dim", 0),
|
||||
q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
|
||||
+ getattr(self.model.config, "qk_nope_head_dim", 0),
|
||||
echo_stream=False,
|
||||
)
|
||||
else:
|
||||
return prefill_and_generate_capture(
|
||||
self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
|
||||
mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
input_tensor,
|
||||
max_tokens,
|
||||
self.use_cuda_graph,
|
||||
mode=self.mode,
|
||||
force_think=self.force_think,
|
||||
chunk_size=self.chunk_size,
|
||||
echo_stream=False,
|
||||
)
|
||||
|
||||
@@ -182,9 +194,11 @@ class KTransformersEngine(BaseEngine):
|
||||
try:
|
||||
gen = make_gen()
|
||||
if hasattr(gen, "__aiter__"):
|
||||
|
||||
async def drain_async():
|
||||
async for t in gen:
|
||||
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
|
||||
|
||||
asyncio.run(drain_async())
|
||||
elif hasattr(gen, "__iter__"):
|
||||
for t in gen:
|
||||
@@ -252,7 +266,7 @@ class KTransformersEngine(BaseEngine):
|
||||
async with self.semaphore:
|
||||
produced = ""
|
||||
async for t in self._generate(messages, system, tools, **input_kwargs):
|
||||
delta = t[len(produced):] if t.startswith(produced) else t
|
||||
delta = t[len(produced) :] if t.startswith(produced) else t
|
||||
produced = t
|
||||
if delta:
|
||||
yield delta
|
||||
|
||||
@@ -616,7 +616,14 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
logger.info_rank0(f"Using default system message: {data_args.default_system}.")
|
||||
template.default_system = data_args.default_system
|
||||
|
||||
template.enable_thinking = data_args.enable_thinking
|
||||
if isinstance(template, ReasoningTemplate):
|
||||
logger.warning_rank0(
|
||||
"You are using reasoning template, "
|
||||
"please add `_nothink` suffix if the model is not a reasoning model. "
|
||||
"e.g., qwen3_vl_nothink"
|
||||
)
|
||||
template.enable_thinking = data_args.enable_thinking
|
||||
|
||||
template.fix_special_tokens(tokenizer)
|
||||
template.fix_jinja_template(tokenizer)
|
||||
return template
|
||||
|
||||
@@ -312,9 +312,11 @@ def use_openmind() -> bool:
|
||||
def use_ray() -> bool:
|
||||
return is_env_enabled("USE_RAY")
|
||||
|
||||
|
||||
def use_kt() -> bool:
|
||||
return is_env_enabled("USE_KT")
|
||||
|
||||
|
||||
def find_available_port() -> int:
|
||||
r"""Find an available port on the local machine."""
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
|
||||
@@ -110,6 +110,7 @@ def is_starlette_available():
|
||||
def is_transformers_version_greater_than(content: str):
|
||||
return _get_package_version("transformers") >= version.parse(content)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_torch_version_greater_than(content: str):
|
||||
return _get_package_version("torch") >= version.parse(content)
|
||||
|
||||
@@ -439,6 +439,7 @@ class SwanLabArguments:
|
||||
metadata={"help": "The Lark(飞书) secret for SwanLab."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(
|
||||
SwanLabArguments,
|
||||
|
||||
@@ -485,7 +485,9 @@ class KTransformersArguments:
|
||||
)
|
||||
kt_optimize_rule: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."},
|
||||
metadata={
|
||||
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
|
||||
},
|
||||
)
|
||||
cpu_infer: Optional[int] = field(
|
||||
default=32,
|
||||
@@ -517,9 +519,16 @@ class KTransformersArguments:
|
||||
metadata={"help": "Force-Think Toggle For The KT Engine."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(
|
||||
SGLangArguments, VllmArguments, KTransformersArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
|
||||
SGLangArguments,
|
||||
VllmArguments,
|
||||
KTransformersArguments,
|
||||
ExportArguments,
|
||||
ProcessorArguments,
|
||||
QuantizationArguments,
|
||||
BaseModelArguments,
|
||||
):
|
||||
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
|
||||
|
||||
|
||||
@@ -90,6 +90,7 @@ class RayArguments:
|
||||
elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
|
||||
self.ray_storage_filesystem = fs.GcsFileSystem()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(RayArguments, BaseTrainingArguments):
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
@@ -57,7 +57,9 @@ def launch():
|
||||
if is_env_enabled("USE_MCA"): # force use torchrun
|
||||
os.environ["FORCE_TORCHRUN"] = "1"
|
||||
|
||||
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())):
|
||||
if command == "train" and (
|
||||
is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
|
||||
):
|
||||
# launch distributed training
|
||||
nnodes = os.getenv("NNODES", "1")
|
||||
node_rank = os.getenv("NODE_RANK", "0")
|
||||
|
||||
@@ -167,7 +167,7 @@ def _setup_lora_tuning(
|
||||
is_mergeable = False
|
||||
|
||||
if model_args.use_kt:
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Up to now, KTransformers model only accepts a single adapter, for more features, you can contact with us."
|
||||
assert len(model_args.adapter_name_or_path) == 1, "KTransformers model only accepts a single adapter"
|
||||
is_mergeable = False
|
||||
|
||||
if model_args.use_unsloth:
|
||||
@@ -190,7 +190,9 @@ def _setup_lora_tuning(
|
||||
|
||||
if model_args.use_kt:
|
||||
if model_args.infer_backend != EngineName.KT:
|
||||
raise ValueError("We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers.")
|
||||
raise ValueError(
|
||||
"We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
|
||||
)
|
||||
|
||||
for adapter in adapter_to_merge:
|
||||
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
|
||||
@@ -218,9 +220,9 @@ def _setup_lora_tuning(
|
||||
if model_args.use_kt:
|
||||
new_list = []
|
||||
for m in target_modules:
|
||||
if m in ('down_proj', 'up_proj', 'gate_proj'):
|
||||
if m in ("down_proj", "up_proj", "gate_proj"):
|
||||
new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
|
||||
elif m not in ('generate_linear', 'orig_module', 'prefill_linear'):
|
||||
elif m not in ("generate_linear", "orig_module", "prefill_linear"):
|
||||
new_list.append(m)
|
||||
|
||||
target_modules[:] = new_list
|
||||
|
||||
@@ -146,6 +146,7 @@ def load_model(
|
||||
lazy_load = False
|
||||
if model_args.use_kt:
|
||||
from ktransformers.sft.monkey_patch_torch_module import install_patch
|
||||
|
||||
install_patch()
|
||||
model = load_kt_pretrained_model(config, model_args)
|
||||
elif model_args.use_unsloth:
|
||||
|
||||
@@ -59,6 +59,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
|
||||
requested_attn_implementation = "sdpa"
|
||||
elif model_args.flash_attn == AttentionFunction.FA2:
|
||||
from transformers import is_torch_npu_available
|
||||
|
||||
if not (is_flash_attn_2_available() or is_torch_npu_available()):
|
||||
logger.warning_rank0("FlashAttention-2 is not installed.")
|
||||
return
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import importlib.util as _u
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -43,6 +43,7 @@ if KT_AVAILABLE:
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _get_kt_kwargs(
|
||||
config: "PretrainedConfig",
|
||||
model_name_or_path: str,
|
||||
@@ -64,9 +65,7 @@ def _get_kt_kwargs(
|
||||
}
|
||||
|
||||
|
||||
def load_kt_pretrained_model(
|
||||
config: "PretrainedConfig", model_args: "ModelArguments"
|
||||
) -> Optional["PreTrainedModel"]:
|
||||
def load_kt_pretrained_model(config: "PretrainedConfig", model_args: "ModelArguments") -> "PreTrainedModel":
|
||||
r"""Optionally load pretrained model with KTransformers. Used in training."""
|
||||
custom_models = {
|
||||
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
|
||||
@@ -79,7 +78,7 @@ def load_kt_pretrained_model(
|
||||
Config().chunk_size = model_args.chunk_size
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
|
||||
|
||||
if model_args.mode == 'long_context':
|
||||
if model_args.mode == "long_context":
|
||||
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
|
||||
torch.set_default_dtype(torch.float16)
|
||||
else:
|
||||
@@ -88,9 +87,7 @@ def load_kt_pretrained_model(
|
||||
with torch.device("meta"):
|
||||
if config.architectures[0] in custom_models:
|
||||
print("using custom modeling_xxx.py.")
|
||||
if (
|
||||
"Qwen2Moe" in config.architectures[0]
|
||||
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
|
||||
if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
if "Llama" in config.architectures[0]:
|
||||
config._attn_implementation = "eager"
|
||||
@@ -115,21 +112,17 @@ def load_kt_pretrained_model(
|
||||
return model
|
||||
|
||||
|
||||
def get_kt_peft_model(
|
||||
model: "PreTrainedModel", peft_kwargs: dict[str, Any]
|
||||
) -> "PreTrainedModel":
|
||||
def get_kt_peft_model(model: "PreTrainedModel", peft_kwargs: dict[str, Any]) -> "PreTrainedModel":
|
||||
r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
|
||||
from ktransformers.sft.peft_utils.mapping import get_peft_model
|
||||
|
||||
return get_peft_model(model, peft_kwargs)
|
||||
|
||||
|
||||
def load_kt_peft_model(
|
||||
model_args: "ModelArguments", model: "PreTrainedModel",
|
||||
) -> "PreTrainedModel":
|
||||
def load_kt_peft_model(model_args: "ModelArguments", model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
r"""Load peft model with KTransformers. Used in both training and inference."""
|
||||
load_adapter_name_or_path = model_args.adapter_name_or_path[0]
|
||||
if load_adapter_name_or_path.endswith('.gguf'):
|
||||
if load_adapter_name_or_path.endswith(".gguf"):
|
||||
inject_lora_layer(model, load_adapter_name_or_path)
|
||||
adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
|
||||
load_weights(model, adapter_gguf_loader, adapter_gguf=True)
|
||||
|
||||
@@ -47,6 +47,7 @@ def run_sft(
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
from ktransformers.util.globals import GLOBAL_CONFIG
|
||||
|
||||
GLOBAL_CONFIG._config["mod"] = "sft"
|
||||
|
||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||
@@ -66,12 +67,13 @@ def run_sft(
|
||||
# Metric utils
|
||||
metric_module = {}
|
||||
if training_args.predict_with_generate:
|
||||
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
|
||||
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
|
||||
elif finetuning_args.compute_accuracy:
|
||||
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
|
||||
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
|
||||
|
||||
# Initialize our Trainer
|
||||
from ktransformers.sft.lora import KTrainer
|
||||
|
||||
trainer = KTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
|
||||
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import infer_optim_dtype
|
||||
from ..extras.packages import is_kt_available, is_mcore_adapter_available, is_ray_available
|
||||
from ..extras.packages import is_mcore_adapter_available, is_ray_available
|
||||
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
@@ -86,12 +86,12 @@ def _training_function(config: dict[str, Any]) -> None:
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "sft":
|
||||
if model_args.use_kt:
|
||||
if not is_kt_available():
|
||||
raise ImportError("KTransformers is not installed. Please install it with `pip install ktransformers`.")
|
||||
from .ksft.workflow import run_sft as run_sft_kt
|
||||
|
||||
run_sft_kt(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
else:
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
|
||||
elif finetuning_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "ppo":
|
||||
|
||||
Reference in New Issue
Block a user