Mirror of https://github.com/hiyouga/LlamaFactory.git, synced 2026-02-01 08:13:38 +00:00
[model] update kt code (#9406)
@@ -47,6 +47,7 @@ def run_sft(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     from ktransformers.util.globals import GLOBAL_CONFIG

+    GLOBAL_CONFIG._config["mod"] = "sft"

     if getattr(model, "is_quantized", False) and not training_args.do_train:
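For context, the added assignment flips KTransformers into SFT mode through a process-wide flag before training starts. A minimal, hedged sketch of the dict-backed global-config pattern this relies on; the _GlobalConfig class and its default value below are hypothetical stand-ins, and only the _config["mod"] = "sft" mutation comes from the diff:

    # Hypothetical stand-in for ktransformers.util.globals.GLOBAL_CONFIG,
    # modeling the dict-backed singleton that the diff mutates.
    class _GlobalConfig:
        def __init__(self) -> None:
            self._config: dict[str, str] = {"mod": "infer"}  # assumed default mode

    GLOBAL_CONFIG = _GlobalConfig()
    GLOBAL_CONFIG._config["mod"] = "sft"  # the same mutation run_sft performs above
    assert GLOBAL_CONFIG._config["mod"] == "sft"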
@@ -66,12 +67,13 @@ def run_sft(
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
-        raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
+        raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
     elif finetuning_args.compute_accuracy:
-        raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
+        raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")

     # Initialize our Trainer
+    from ktransformers.sft.lora import KTrainer

     trainer = KTrainer(
         model=model,
         args=training_args,
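Importing KTrainer inside run_sft, rather than at module top level, keeps ktransformers an optional dependency: the import only executes when the KT SFT workflow actually runs. A hedged sketch of the same lazy-import guard; the helper name and error message are illustrative, only the import path comes from the diff:

    def _load_ktrainer():
        # Defer the optional dependency until KT SFT is actually requested.
        try:
            from ktransformers.sft.lora import KTrainer  # import path as in the diff
        except ImportError as err:
            # Illustrative handling; the diff itself lets the ImportError propagate.
            raise ImportError("KTransformers is required for KT SFT.") from err
        return KTrainer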
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import infer_optim_dtype
-from ..extras.packages import is_kt_available, is_mcore_adapter_available, is_ray_available
+from ..extras.packages import is_mcore_adapter_available, is_ray_available
 from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -86,12 +86,12 @@ def _training_function(config: dict[str, Any]) -> None:
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
         if model_args.use_kt:
-            if not is_kt_available():
-                raise ImportError("KTransformers is not installed. Please install it with `pip install ktransformers`.")
             from .ksft.workflow import run_sft as run_sft_kt

             run_sft_kt(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)

-        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+        else:
+            run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)

     elif finetuning_args.stage == "rm":
         run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "ppo":
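The tuner change does two things: it stops calling run_sft unconditionally after the KT path (the new else: makes the two workflows mutually exclusive), and it drops the explicit is_kt_available() gate, so a missing ktransformers package now surfaces as a plain ImportError when the KT workflow module is imported. A standalone sketch of the resulting dispatch shape, where dispatch_sft and its string returns are illustrative stand-ins for the real workflow calls:

    import importlib.util

    def dispatch_sft(use_kt: bool) -> str:
        """Illustrative stand-in: returns the name of the workflow that would run."""
        if use_kt:
            # Post-commit there is no explicit availability gate; loading the KT
            # workflow module raises ImportError itself if ktransformers is absent.
            if importlib.util.find_spec("ktransformers") is None:
                raise ImportError("No module named 'ktransformers'")  # what the import would raise
            return "run_sft_kt"
        return "run_sft"

    print(dispatch_sft(use_kt=False))  # -> run_sft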