mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-30 02:53:09 +00:00
[feat] support LlamaFactory SFT training by HyperParallel FSDP2 backend (#10289)
This commit is contained in:
@@ -70,6 +70,10 @@ def is_matplotlib_available():
|
||||
return _is_package_available("matplotlib")
|
||||
|
||||
|
||||
def is_hyper_parallel_available():
|
||||
return _is_package_available("hyper_parallel")
|
||||
|
||||
|
||||
def is_mcore_adapter_available():
|
||||
return _is_package_available("mcore_adapter")
|
||||
|
||||
|
||||
@@ -482,6 +482,24 @@ class FinetuningArguments(
|
||||
)
|
||||
},
|
||||
)
|
||||
use_hyper_parallel: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
|
||||
"Only supported for the 'sft' stage with full fine-tuning."
|
||||
)
|
||||
},
|
||||
)
|
||||
hyper_parallel_args: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Path to a JSON file containing HyperParallel strategy arguments "
|
||||
"(e.g., tp_size, param_dtype). Used when use_hyper_parallel=True."
|
||||
)
|
||||
},
|
||||
)
|
||||
use_muon: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the Muon optimizer."},
|
||||
|
||||
18
src/llamafactory/train/hyper_parallel/__init__.py
Normal file
18
src/llamafactory/train/hyper_parallel/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .workflow import run_sft
|
||||
|
||||
|
||||
__all__ = ["run_sft"]
|
||||
179
src/llamafactory/train/hyper_parallel/workflow.py
Normal file
179
src/llamafactory/train/hyper_parallel/workflow.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import calculate_tps
|
||||
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_sft(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
if not is_hyper_parallel_available():
|
||||
raise ImportError(
|
||||
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
|
||||
)
|
||||
|
||||
from hyper_parallel.integration.llamafactory import HyperParallelArguments, HyperParallelTrainer # pylint: disable=C0415
|
||||
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
ref_model = None
|
||||
if finetuning_args.use_asft_loss:
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
|
||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||
template=template,
|
||||
model=model if not training_args.predict_with_generate else None,
|
||||
pad_to_multiple_of=8 if training_args.do_train else None,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
block_diag_attn=model_args.block_diag_attn,
|
||||
attn_implementation=getattr(model.config, "_attn_implementation", None),
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
**tokenizer_module,
|
||||
)
|
||||
|
||||
# Metric utils
|
||||
metric_module = {}
|
||||
if training_args.predict_with_generate:
|
||||
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
|
||||
elif finetuning_args.compute_accuracy:
|
||||
metric_module["compute_metrics"] = ComputeAccuracy()
|
||||
metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
|
||||
if is_transformers_version_greater_than("4.58.0"):
|
||||
extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
|
||||
if not isinstance(extra_ids, list):
|
||||
extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
|
||||
string_tokens = [str(t) for t in extra_special_tokens]
|
||||
extra_ids = tokenizer.convert_tokens_to_ids(string_tokens)
|
||||
all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
|
||||
gen_kwargs["eos_token_id"] = list(dict.fromkeys(all_eos_ids))
|
||||
else:
|
||||
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
|
||||
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
|
||||
|
||||
callbacks = list(callbacks or [])
|
||||
processor = tokenizer_module.get("processor")
|
||||
if processor is not None:
|
||||
callbacks.append(SaveProcessorCallback(processor))
|
||||
|
||||
compute_loss_func = None
|
||||
if finetuning_args.use_dft_loss:
|
||||
compute_loss_func = dft_loss_func
|
||||
elif finetuning_args.use_eaft_loss:
|
||||
compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
|
||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||
)
|
||||
elif finetuning_args.use_asft_loss:
|
||||
from functools import partial
|
||||
|
||||
compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
|
||||
|
||||
trainer = HyperParallelTrainer(
|
||||
hp_args=hp_args,
|
||||
model=model,
|
||||
args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
gen_kwargs=gen_kwargs,
|
||||
ref_model=ref_model,
|
||||
compute_loss_func=compute_loss_func,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**metric_module,
|
||||
)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
|
||||
from types import MethodType
|
||||
|
||||
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
|
||||
trainer.add_callback(BAdamCallback)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
if finetuning_args.include_effective_tokens_per_second:
|
||||
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
|
||||
dataset_module["train_dataset"], train_result.metrics, stage="sft"
|
||||
)
|
||||
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
keys = ["loss"]
|
||||
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||
keys += sum(
|
||||
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()],
|
||||
[],
|
||||
)
|
||||
else:
|
||||
keys += ["eval_loss", "eval_accuracy"]
|
||||
|
||||
plot_loss(training_args.output_dir, keys=keys)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
|
||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
|
||||
|
||||
# Create model card
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||
@@ -24,7 +24,12 @@ from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
|
||||
from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
|
||||
from ..extras.packages import (
|
||||
is_hyper_parallel_available,
|
||||
is_mcore_adapter_available,
|
||||
is_ray_available,
|
||||
is_transformers_version_greater_than,
|
||||
)
|
||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
@@ -71,7 +76,16 @@ def _training_function(config: dict[str, Any]) -> None:
|
||||
|
||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||
|
||||
if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
|
||||
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
||||
if not is_hyper_parallel_available():
|
||||
raise ImportError(
|
||||
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
|
||||
)
|
||||
from .hyper_parallel import run_sft as run_sft_hp
|
||||
|
||||
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
|
||||
elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
|
||||
if not is_mcore_adapter_available():
|
||||
raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
|
||||
if finetuning_args.stage == "pt":
|
||||
|
||||
Reference in New Issue
Block a user