[feat] support megatron-LM training by mcore_adapter (#9237)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
Kingsley
2025-10-26 16:21:30 +08:00
committed by GitHub
parent 129e918106
commit 13170577b2
14 changed files with 671 additions and 8 deletions

View File

@@ -19,7 +19,20 @@ from typing import Literal, Optional, Union
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
from ..extras.misc import use_ray
from ..extras.misc import is_env_enabled, use_ray
# Select the training-arguments base class at import time: when the USE_MCA
# environment flag is enabled, the Megatron-Core adapter's arguments class
# replaces the vanilla HF Seq2SeqTrainingArguments.
if is_env_enabled("USE_MCA"):
    try:
        from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments

        BaseTrainingArguments = McaSeq2SeqTrainingArguments
    except ImportError as err:
        # Single message string: ImportError("a", "b") would leave e.msg unset
        # and print a tuple instead of a readable sentence. Chain the original
        # error so the real import failure stays visible in the traceback.
        raise ImportError(
            "mcore_adapter is required when USE_MCA=1. "
            "Please install `mcore_adapter` and its dependencies."
        ) from err
else:
    BaseTrainingArguments = Seq2SeqTrainingArguments
@dataclass
@@ -78,7 +91,7 @@ class RayArguments:
@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
class TrainingArguments(RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(
@@ -87,5 +100,5 @@ class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
)
def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)
RayArguments.__post_init__(self)
BaseTrainingArguments.__post_init__(self)