Mirror of https://github.com/hiyouga/LlamaFactory.git (synced 2026-01-31 06:42:05 +00:00)
[feat] support megatron-LM training by mcore_adapter (#9237)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
@@ -56,6 +56,8 @@ LAYERNORM_NAMES = {"norm", "ln"}
 LLAMABOARD_CONFIG = "llamaboard_config.yaml"
+MCA_SUPPORTED_MODELS = {"deepseek_v3", "llama", "mistral", "mixtral", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "qwen3_next"}
 METHODS = ["full", "freeze", "lora", "oft"]
 MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
@@ -70,6 +70,10 @@ def is_matplotlib_available():
     return _is_package_available("matplotlib")
 
 
+def is_mcore_adapter_available():
+    return _is_package_available("mcore_adapter")
+
+
 def is_pillow_available():
     return _is_package_available("PIL")
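For context (an editor's note, not part of the diff): `_is_package_available` is an existing helper in extras/packages.py. A minimal sketch of how such an availability check is commonly implemented, assuming it only tests importability:

import importlib.util


def _is_package_available(name: str) -> bool:
    # Illustrative only: a package counts as available if an import spec can be found.
    # The real helper may also check a minimum version.
    return importlib.util.find_spec(name) is not None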
@@ -461,7 +461,7 @@ class FinetuningArguments(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )
-    finetuning_type: Literal["lora", "freeze", "full"] = field(
+    finetuning_type: Literal["lora", "oft", "freeze", "full"] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."},
     )
@@ -473,6 +473,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to use the Adam-mini optimizer."},
     )
+    use_mca: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use MCA (Megatron Core Adapter) training. Controlled by USE_MCA environment variable."},
+    )
     use_muon: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the Muon optimizer."},
@@ -32,7 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
-from ..extras.packages import is_transformers_version_greater_than
+from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -53,6 +53,13 @@ _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, Generatin
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 
+if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
+    from mcore_adapter import TrainingArguments as McaTrainingArguments
+
+    _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+    _TRAIN_MCA_CLS = tuple[ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+else:
+    _TRAIN_MCA_ARGS = []
+    _TRAIN_MCA_CLS = tuple()
 
 def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
     r"""Get arguments from the command line or a config file."""
@@ -197,6 +204,27 @@ def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
 
 
+def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
+    parser = HfArgumentParser(_TRAIN_MCA_ARGS)
+    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+    model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
+        parser, args, allow_extra_keys=allow_extra_keys
+    )
+
+    _configure_mca_training_args(training_args, data_args, finetuning_args)
+
+    return model_args, data_args, training_args, finetuning_args, generating_args
+
+
+def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
+    """Patch training args to avoid args checking errors and sync MCA settings."""
+    training_args.predict_with_generate = False
+    training_args.generation_max_length = data_args.cutoff_len
+    training_args.generation_num_beams = 1
+    training_args.use_mca = True
+    finetuning_args.use_mca = True
+
+
 def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
@@ -216,7 +244,11 @@ def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Ray
 
 
 def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+    if is_env_enabled("USE_MCA"):
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
+    else:
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+        finetuning_args.use_mca = False
 
     # Setup logging
     if training_args.should_log:
@@ -19,7 +19,20 @@ from typing import Literal, Optional, Union
 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
 
-from ..extras.misc import use_ray
+from ..extras.misc import is_env_enabled, use_ray
+
+
+if is_env_enabled("USE_MCA"):
+    try:
+        from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+
+        BaseTrainingArguments = McaSeq2SeqTrainingArguments
+    except ImportError:
+        raise ImportError(
+            "mcore_adapter is required when USE_MCA=1.",
+            "Please install `mcore_adapter` and its dependencies."
+        )
+else:
+    BaseTrainingArguments = Seq2SeqTrainingArguments
 
 
 @dataclass
@@ -78,7 +91,7 @@ class RayArguments:
 
 
 @dataclass
-class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
+class TrainingArguments(RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
     overwrite_output_dir: bool = field(
@@ -87,5 +100,5 @@ class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     )
 
     def __post_init__(self):
-        Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
+        BaseTrainingArguments.__post_init__(self)
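An editor's sketch (not part of the diff) of the env-gated base-class pattern used above: because the base class is fixed once at import time, the same TrainingArguments dataclass transparently gains the Megatron-style fields when USE_MCA is set. All names below except the selection mechanism itself are illustrative.

import os
from dataclasses import dataclass


@dataclass
class _HfStyleArgs:
    output_dir: str = "output"


@dataclass
class _McaStyleArgs(_HfStyleArgs):
    tensor_model_parallel_size: int = 1  # assumed Megatron-style extra field


# chosen once, at import time, analogous to BaseTrainingArguments above
_Base = _McaStyleArgs if os.getenv("USE_MCA", "0").lower() in ("1", "true", "y") else _HfStyleArgs


@dataclass
class MyTrainingArguments(_Base):
    # the subclass inherits whichever field set the environment selected
    pass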
@@ -54,6 +54,10 @@ def launch():
     )
 
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
+    if is_env_enabled("USE_MCA"):
+        # force use torchrun
+        os.environ["FORCE_TORCHRUN"] = "1"
+
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
src/llamafactory/train/mca/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_dpo, run_pt, run_sft


__all__ = ["run_dpo", "run_pt", "run_sft"]
src/llamafactory/train/mca/trainer.py (new file, 15 lines)
@@ -0,0 +1,15 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO override the original trainer
src/llamafactory/train/mca/workflow.py (new file, 292 lines)
@@ -0,0 +1,292 @@
# Copyright 2025 the ROLL team and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""

from __future__ import annotations

import functools
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any

from ...data import (
    SFTDataCollatorWith4DAttentionMask,
    get_dataset,
    get_template_and_fix_tokenizer,
)
from ...data.collator import (
    PairwiseDataCollatorWithPadding,
)
from ...extras.constants import IGNORE_INDEX, MCA_SUPPORTED_MODELS
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_mcore_adapter_available
from ...extras.ploting import plot_loss
from ...model import load_tokenizer
from ..callbacks import SaveProcessorCallback


if not is_mcore_adapter_available():
    raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")

from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
from mcore_adapter.trainer import McaTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments


if TYPE_CHECKING:
    from transformers import DataCollatorForSeq2Seq, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, ModelArguments


logger = get_logger(__name__)
def _data_collator_wrapper(data_collator: Any):
    @functools.wraps(data_collator)
    def wrapper(features: Sequence[dict[str, Any]]):
        labels_key = [k for k in features[0].keys() if k.endswith("labels")]
        input_ids_key = [k for k in features[0].keys() if k.endswith("input_ids")]
        for feature in features:
            if len(labels_key) == 0:  # pt
                feature["labels"] = deepcopy(feature["input_ids"])[1:]
            for k in labels_key:
                feature[k] = feature[k][1:]
            for k in input_ids_key:
                feature[k] = feature[k][:-1]
            for k in ["attention_mask", "position_ids"]:
                if k in feature:
                    feature[k] = feature[k][:-1]
        return data_collator(features)

    return wrapper
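An editor's illustration (not part of the diff) of what _data_collator_wrapper does to a single feature before padding: labels drop their first token and input_ids drop their last, so inputs predict the next token. This is also why the workflows below bump cutoff_len by 1 while building the dataset and restore it afterwards. The feature values here are made up.

from copy import deepcopy


def shift(feature: dict) -> dict:
    # mirrors the wrapper's per-feature shift (sketch only)
    feature = deepcopy(feature)
    if "labels" not in feature:  # pt: derive labels from input_ids
        feature["labels"] = deepcopy(feature["input_ids"])[1:]
    else:
        feature["labels"] = feature["labels"][1:]
    feature["input_ids"] = feature["input_ids"][:-1]
    return feature


print(shift({"input_ids": [1, 2, 3, 4], "labels": [-100, 20, 30, 40]}))
# -> {'input_ids': [1, 2, 3], 'labels': [20, 30, 40]}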
def _check_model_support(model_args: ModelArguments):
    from transformers import AutoConfig as HfAutoConfig

    config = HfAutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
    if config.model_type not in MCA_SUPPORTED_MODELS:
        raise ValueError(f"Model {config.model_type} is not supported by MCA.")
def run_pt(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: McaSeq2SeqTrainingArguments,
    finetuning_args: FinetuningArguments,
    callbacks: list[TrainerCallback] | None = None,
):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    # dataset needs +1 then cut back due to MCA shift logic
    data_args.cutoff_len += 1
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
    data_args.cutoff_len -= 1

    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

    from transformers import DataCollatorForSeq2Seq

    data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX,
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = McaTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
    )

    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))

    if training_args.do_train:
        train_result = trainer.train(training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
            else:
                keys += ["eval_loss"]
            plot_loss(training_args.output_dir, keys=keys)
def run_sft(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: McaSeq2SeqTrainingArguments,
    finetuning_args: FinetuningArguments,
    callbacks: list[TrainerCallback] | None = None,
):
    # align packing flags
    # TODO: FIX SequencePacking
    data_args.neat_packing = training_args.sequence_packing = data_args.neat_packing or training_args.sequence_packing
    data_args.packing = data_args.neat_packing or data_args.packing

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    # dataset needs +1 then cut back due to MCA shift logic
    data_args.cutoff_len += 1
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    data_args.cutoff_len -= 1

    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

    # optional freezing for qwen2_vl, qwen2_5_vl
    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_vision_tower:
        for name, p in model.named_parameters():
            if any(name.startswith(k) for k in ["vision_model.blocks", "vision_model.patch_embed"]):
                p.requires_grad_(False)
    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_multi_modal_projector:
        for name, p in model.named_parameters():
            if any(name.startswith(k) for k in ["multi_modal_projector"]):
                p.requires_grad_(False)
    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_language_model:
        for name, p in model.named_parameters():
            if any(name.startswith(k) for k in ["embedding", "decoder", "output_layer"]):
                p.requires_grad_(False)

    pad_to_max = (
        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
    )
    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        padding="max_length" if pad_to_max else "longest",
        max_length=data_args.cutoff_len if pad_to_max else None,
        pad_to_multiple_of=64,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = McaTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
    )

    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))

    train_result = trainer.train(training_args.resume_from_checkpoint)
    trainer.save_model()
    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()
    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
        keys = ["loss"]
        if isinstance(dataset_module.get("eval_dataset"), dict):
            keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
        else:
            keys += ["eval_loss"]
        plot_loss(training_args.output_dir, keys=keys)
def run_dpo(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: McaSeq2SeqTrainingArguments,
    finetuning_args: FinetuningArguments,
    callbacks: list[TrainerCallback] | None = None,
):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

    if finetuning_args.use_ref_model:
        ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
        ref_model = AutoModel.from_config(ref_config)
        ref_model.load_state_dict(model.state_dict())
    else:
        ref_model = None

    # dataset needs +1 then cut back due to MCA shift logic
    data_args.cutoff_len += 1
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
    data_args.cutoff_len -= 1

    pad_to_max = (
        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
    )
    dpo_config = DPOConfig(
        beta=finetuning_args.pref_beta,
        pref_loss=finetuning_args.pref_loss,
        label_smoothing=finetuning_args.dpo_label_smoothing,
    )
    data_collator = PairwiseDataCollatorWithPadding(
        template=template,
        pad_to_multiple_of=64,
        padding="max_length" if pad_to_max else "longest",
        max_length=data_args.cutoff_len if pad_to_max else None,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = McaDPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_config=dpo_config,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
    )

    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))

    train_result = trainer.train(training_args.resume_from_checkpoint)
    trainer.save_model()
    if finetuning_args.include_effective_tokens_per_second:
        train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
            dataset_module["train_dataset"], train_result.metrics, stage="rm"
        )

    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()
    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
        keys = ["loss", "rewards/accuracies"]
        if isinstance(dataset_module.get("eval_dataset"), dict):
            keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
        else:
            keys += ["eval_loss"]

        plot_loss(training_args.output_dir, keys=keys)
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import infer_optim_dtype
-from ..extras.packages import is_ray_available
+from ..extras.packages import is_mcore_adapter_available, is_ray_available
 from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -66,7 +66,19 @@ def _training_function(config: dict[str, Any]) -> None:
 
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
 
-    if finetuning_args.stage == "pt":
+    if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
+        if not is_mcore_adapter_available():
+            raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
+        if finetuning_args.stage == "pt":
+            from .mca import run_pt as run_pt_mca
+            run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+        elif finetuning_args.stage == "sft":
+            from .mca import run_sft as run_sft_mca
+            run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+        else:  # dpo
+            from .mca import run_dpo as run_dpo_mca
+            run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
         run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
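For orientation, an editor's usage sketch (not part of the diff): with the pieces above, MCA training is opted into purely via the USE_MCA environment variable, which must be set before any llamafactory modules are imported because the training-arguments base class is chosen at import time. The sketch assumes LLaMA-Factory's run_exp entry point accepts an argument dict; the model, dataset, and output paths are placeholders, and in practice the script would be launched through `llamafactory-cli train` under torchrun.

import os

os.environ["USE_MCA"] = "1"         # route get_train_args() through the MCA argument parser
os.environ["FORCE_TORCHRUN"] = "1"  # the CLI launcher sets this automatically when USE_MCA is on

from llamafactory.train.tuner import run_exp  # imported only after the env vars are set

run_exp({
    "stage": "sft",                                     # MCA path supports pt / sft / dpo
    "do_train": True,
    "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",   # placeholder; model type must be in MCA_SUPPORTED_MODELS
    "dataset": "identity",                              # placeholder dataset name
    "template": "qwen",
    "finetuning_type": "full",
    "output_dir": "saves/qwen2.5-mca-sft",              # placeholder output path
})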