Former-commit-id: f7aa06c9c0b18c28419ea5792410915d3f322cbf
This commit is contained in:
hiyouga
2024-09-02 23:56:21 +08:00
parent 7367c6ec21
commit b5e9df5df8
5 changed files with 57 additions and 38 deletions

View File

@@ -39,42 +39,44 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
if not is_deepspeed_zero3_enabled():
return
if getattr(model.config, "model_type", None) == "dbrx":
model_type = getattr(model.config, "model_type", None)
if model_type == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
_set_z3_leaf_modules(model, [DbrxFFN])
if getattr(model.config, "model_type", None) == "jamba":
if model_type == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
_set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jetmoe":
if model_type == "jetmoe":
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
if getattr(model.config, "model_type", None) == "mixtral":
if model_type == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
_set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
if model_type == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
model_type = getattr(config, "model_type", None)
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
if model_type in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "jetmoe":
elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)