add export_device in webui #3333

Former-commit-id: 30ebd3652809d73941e0a5e4a8be11d989faf98d
This commit is contained in:
hiyouga
2024-04-25 19:02:32 +08:00
parent 36be12a3b7
commit 9aeb88c426
7 changed files with 57 additions and 22 deletions

View File

@@ -5,7 +5,9 @@ from transformers.utils.versions import require_version
if TYPE_CHECKING:
from transformers import PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
@@ -37,3 +39,15 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
set_z3_leaf_modules(model, [DbrxFFN])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)