bf16 by default, gemma2 attns

Gemma2 finetuning will not work until https://github.com/huggingface/transformers/pull/31674 is merged


Former-commit-id: da66c32c7be0adc28d2185b23e9f62d56acb961c
Author: hiyouga
Date: 2024-06-28 06:00:26 +08:00
Parent: cfdf5a5a78
Commit: fda2cf677b
3 changed files with 9 additions and 3 deletions
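The "bf16 by default" half of the commit is not visible in the hunk below. As a rough illustration of what preferring bf16 usually means in practice, the default compute dtype is picked from hardware support rather than hard-coded to fp16; the helper name infer_default_dtype below is hypothetical, a minimal sketch rather than this commit's actual code:

import torch

# Hypothetical helper: prefer bfloat16 whenever the accelerator supports it,
# otherwise fall back to float16.
def infer_default_dtype() -> torch.dtype:
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16

    return torch.float16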


@@ -67,7 +67,7 @@ def patch_config(
         use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
         torch.npu.set_compile_mode(jit_compile=use_jit_compile)
 
-    configure_attn_implementation(config, model_args)
+    configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
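
The diff only shows the new is_trainable argument being threaded through. A minimal sketch of how the updated configure_attn_implementation might use it, assuming the intent is to fall back to eager attention when training Gemma2 (whose attention logit soft-capping was not yet supported by FlashAttention at the time), could look like the following. The body, the model_args.flash_attn attribute, and the "auto"/"disabled" values are assumptions for illustration, not the commit's confirmed implementation:

import logging

logger = logging.getLogger(__name__)

def configure_attn_implementation(config, model_args, is_trainable: bool) -> None:
    # Gemma2 uses attention logit soft-capping, which FlashAttention did not
    # support at the time, so training must fall back to eager attention.
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
        if getattr(model_args, "flash_attn", "auto") == "auto":
            logger.warning("Gemma2 training requires eager attention; disabling FlashAttention-2.")
            model_args.flash_attn = "disabled"  # assumed attribute and value names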