loose gemma2 attention

Former-commit-id: a0b645017a2de3d58b6cbc71bd91ec96fc7a818b
This commit is contained in:
hiyouga
2024-06-29 01:42:14 +08:00
parent 6a75d57060
commit 3c4f8eaa55
2 changed files with 9 additions and 6 deletions

View File

@@ -32,8 +32,14 @@ def configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
if getattr(config, "model_type", None) == "gemma2" and is_trainable: # gemma2 adopts soft-cap attention
logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
model_args.flash_attn = "disabled"
if model_args.flash_attn == "auto":
logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
model_args.flash_attn = "disabled"
else:
logger.warning(
"Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
"Will proceed at your own risk.".format(model_args.flash_attn)
)
if model_args.flash_attn == "auto":
return