Merge branch 'hiyouga:main' into main

Former-commit-id: 131d0bcd554dedd794add7eb3d7b1201cac80e7c
Authored by BUAADreamer on 2024-04-25 20:02:50 +08:00, committed by GitHub
8 changed files with 58 additions and 23 deletions

View File

@@ -12,7 +12,7 @@ from .utils.attention import configure_attn_implementation, print_attn_implement
 from .utils.checkpointing import prepare_model_for_training
 from .utils.embedding import resize_embedding_layer
 from .utils.longlora import configure_longlora
-from .utils.moe import add_z3_leaf_module
+from .utils.moe import add_z3_leaf_module, configure_moe
 from .utils.quantization import configure_quantization
 from .utils.rope import configure_rope
@@ -46,17 +46,12 @@ def patch_config(
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
+    configure_moe(config, model_args, is_trainable)
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
         logger.info("Using KV cache for faster generation.")
-    if model_args.moe_aux_loss_coef is not None:
-        if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"]:
-            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
-        elif getattr(config, "model_type", None) == "deepseek":
-            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn)
     for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
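The hunk stops at the dtype loop header, so its body is not shown here. As a rough sketch of what such a loop presumably does (an assumption, not taken from this diff), each per-precision flag on the config is set to whether it matches the chosen compute dtype:

```python
from types import SimpleNamespace

import torch

# Stand-ins for the real config and ModelArguments (illustrative only).
config = SimpleNamespace()
model_args = SimpleNamespace(compute_dtype=torch.bfloat16)

# Assumed loop body, not shown in this diff: mirror the chosen compute dtype
# onto the config's fp16/bf16/fp32 flags.
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
    setattr(config, dtype_name, model_args.compute_dtype == dtype)

print(config.fp16, config.bf16, config.fp32)  # False True False
```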
@@ -65,9 +60,6 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
-    if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"] and is_trainable:
-        setattr(config, "output_router_logits", True)
     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
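For context, here is a minimal sketch (not part of this commit) of how the init_kwargs assembled in patch_config is typically consumed; the checkpoint name is illustrative. Under DeepSpeed ZeRO-3 parameters are sharded at load time, which is why low_cpu_mem_usage is only forwarded when ZeRO-3 is disabled:

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled

# Illustrative values; in the real code they come from ModelArguments.
init_kwargs = {"torch_dtype": torch.bfloat16}  # model_args.compute_dtype
if not is_deepspeed_zero3_enabled():
    # ZeRO-3 handles parameter placement itself, so the lazy CPU-loading
    # shortcut is requested only outside of ZeRO-3.
    init_kwargs["low_cpu_mem_usage"] = True  # model_args.low_cpu_mem_usage

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B", **init_kwargs)
```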

View File

@@ -5,7 +5,9 @@ from transformers.utils.versions import require_version
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import PretrainedConfig, PreTrainedModel
+    from ...hparams import ModelArguments
 def add_z3_leaf_module(model: "PreTrainedModel") -> None:
@@ -37,3 +39,15 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN
         set_z3_leaf_modules(model, [DbrxFFN])
+def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if model_args.moe_aux_loss_coef is not None:
+        if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
+            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
+        elif getattr(config, "model_type", None) == "deepseek":
+            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
+    if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
+        setattr(config, "output_router_logits", is_trainable)
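A quick usage sketch (not part of the diff) of what the new configure_moe does to a Mixtral-style config; ModelArguments is mocked with just the field the function reads, and the import path for configure_moe depends on the repository layout:

```python
from types import SimpleNamespace

from transformers import MixtralConfig

# configure_moe is the function added in the hunk above; its import path
# depends on the repository layout, so it is assumed to be in scope here.
model_args = SimpleNamespace(moe_aux_loss_coef=0.01)  # mocked ModelArguments
config = MixtralConfig()  # model_type == "mixtral"

configure_moe(config, model_args, is_trainable=True)

print(config.router_aux_loss_coef)  # 0.01
print(config.output_router_logits)  # True
```

Tying output_router_logits to is_trainable means the router logits (and hence the auxiliary load-balancing loss) are only produced during training, leaving inference unaffected.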