bf16 by default, gemma2 attns

Gemma2 finetuning cannot work until https://github.com/huggingface/transformers/pull/31674 is merged


Former-commit-id: da66c32c7be0adc28d2185b23e9f62d56acb961c
hiyouga
2024-06-28 06:00:26 +08:00
parent cfdf5a5a78
commit fda2cf677b
3 changed files with 9 additions and 3 deletions


@@ -28,7 +28,13 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+def configure_attn_implementation(
+    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
+        logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+        model_args.flash_attn = "disabled"
+
     if model_args.flash_attn == "auto":
         return
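
For context, a minimal standalone sketch of how the patched check behaves. DummyConfig and DummyModelArgs below are hypothetical stand-ins for transformers' PretrainedConfig and LLaMA-Factory's ModelArguments; only the branching logic mirrors the diff above.

from dataclasses import dataclass
import logging

logger = logging.getLogger(__name__)


@dataclass
class DummyConfig:
    # hypothetical stand-in for transformers' PretrainedConfig
    model_type: str = "gemma2"


@dataclass
class DummyModelArgs:
    # hypothetical stand-in for ModelArguments; only the relevant field is kept
    flash_attn: str = "fa2"


def configure_attn_implementation(config, model_args, is_trainable):
    # Gemma-2 uses attention logit soft-capping, which FlashAttention-2 did not
    # support at the time, so training is forced back to eager attention.
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
        logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
        model_args.flash_attn = "disabled"

    if model_args.flash_attn == "auto":
        return


if __name__ == "__main__":
    args = DummyModelArgs(flash_attn="fa2")
    configure_attn_implementation(DummyConfig(), args, is_trainable=True)
    assert args.flash_attn == "disabled"  # the FlashAttention request is overridden for Gemma-2 training

With the new is_trainable flag, the override only applies during training; inference-only loads keep whatever flash_attn value the user requested.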