add export_device in webui #3333
Former-commit-id: 30ebd3652809d73941e0a5e4a8be11d989faf98d
This commit is contained in:
@@ -5,7 +5,9 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
@@ -37,3 +39,15 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
|
||||
|
||||
set_z3_leaf_modules(model, [DbrxFFN])
|
||||
|
||||
|
||||
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if model_args.moe_aux_loss_coef is not None:
|
||||
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
|
||||
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
|
||||
|
||||
elif getattr(config, "model_type", None) == "deepseek":
|
||||
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
|
||||
|
||||
if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
|
||||
setattr(config, "output_router_logits", is_trainable)
|
||||
|
||||
Reference in New Issue
Block a user