[feat] support megatron-LM training by mcore_adapter (#9237)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -19,7 +19,20 @@ from typing import Literal, Optional, Union
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
|
||||
from ..extras.misc import use_ray
|
||||
from ..extras.misc import is_env_enabled, use_ray
|
||||
|
||||
|
||||
# Select the training-arguments base class: when the USE_MCA env flag is set,
# use mcore_adapter's Megatron-aware Seq2SeqTrainingArguments; otherwise fall
# back to the stock HF transformers Seq2SeqTrainingArguments.
if is_env_enabled("USE_MCA"):
    try:
        from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments

        BaseTrainingArguments = McaSeq2SeqTrainingArguments
    except ImportError as e:
        # Fix: the original passed two comma-separated strings to ImportError,
        # which made args a 2-tuple instead of one readable message (and the
        # literals lacked a joining space). Use a single concatenated string
        # and chain the original import failure explicitly.
        raise ImportError(
            "mcore_adapter is required when USE_MCA=1. "
            "Please install `mcore_adapter` and its dependencies."
        ) from e
else:
    BaseTrainingArguments = Seq2SeqTrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -78,7 +91,7 @@ class RayArguments:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
|
||||
class TrainingArguments(RayArguments, BaseTrainingArguments):
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
overwrite_output_dir: bool = field(
|
||||
@@ -87,5 +100,5 @@ class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
Seq2SeqTrainingArguments.__post_init__(self)
|
||||
RayArguments.__post_init__(self)
|
||||
BaseTrainingArguments.__post_init__(self)
|
||||
|
||||
Reference in New Issue
Block a user