bf16 by default, gemma2 attns
Gemma2 fine-tuning will not work until https://github.com/huggingface/transformers/pull/31674 is merged.

Former-commit-id: da66c32c7be0adc28d2185b23e9f62d56acb961c
@@ -28,7 +28,13 @@ if TYPE_CHECKING:
 
 logger = get_logger(__name__)
 
 
-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+def configure_attn_implementation(
+    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
+        logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+        model_args.flash_attn = "disabled"
+
     if model_args.flash_attn == "auto":
         return
@@ -67,7 +67,7 @@ def patch_config(
         use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
         torch.npu.set_compile_mode(jit_compile=use_jit_compile)
 
-    configure_attn_implementation(config, model_args)
+    configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
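To illustrate the effect of the new `is_trainable` flag, here is a minimal standalone sketch of the Gemma-2 branch added above. The `SimpleNamespace` stand-ins are illustrative only, replacing transformers' `PretrainedConfig` and LLaMA-Factory's `ModelArguments`, and only the branch shown in the first hunk is reproduced.

# Minimal standalone sketch of the Gemma-2 branch (illustrative stand-ins,
# not the project's real classes).
from types import SimpleNamespace


def configure_attn_implementation(config, model_args, is_trainable: bool) -> None:
    # Gemma-2 uses soft-capped attention, so training falls back to eager attention;
    # setting flash_attn to "disabled" lets the downstream dispatch pick eager.
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
        model_args.flash_attn = "disabled"

    if model_args.flash_attn == "auto":
        return
    # ... the rest of the dispatch (eager / sdpa / fa2) is unchanged by this commit ...


config = SimpleNamespace(model_type="gemma2")
model_args = SimpleNamespace(flash_attn="fa2")
configure_attn_implementation(config, model_args, is_trainable=True)
print(model_args.flash_attn)  # prints "disabled"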