fix gemma2 attention
Former-commit-id: aeafc68e169ae0ea5939cc81cb0cf89f0ca044b6
@@ -21,6 +21,7 @@ from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version

from ..extras.logging import get_logger
from ..extras.misc import infer_optim_dtype
@@ -88,6 +89,9 @@ def patch_config(
    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

    if getattr(config, "model_type", None) == "chatglm":
        require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")

    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
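The last context line above is the ZeRO-3 workaround: DeepSpeed ZeRO-3 partitions and initializes weights itself, so low_cpu_mem_usage has to be forced off whenever it is enabled. A minimal sketch of how that flag typically flows into model loading follows; the forwarding of init_kwargs to from_pretrained and the checkpoint name are assumptions for illustration, not part of this diff.

# Minimal sketch, assuming patch_config's init_kwargs are later forwarded to
# from_pretrained; the checkpoint name is only an example.
from transformers import AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled

low_cpu_mem_usage = True  # stand-in for model_args.low_cpu_mem_usage
init_kwargs = {}
# ZeRO-3 shards parameters at init time, which conflicts with low_cpu_mem_usage,
# so the flag is disabled whenever ZeRO-3 is active.
init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

model = AutoModelForCausalLM.from_pretrained("gpt2", **init_kwargs)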
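The commit title says "fix gemma2 attention", but only context lines survived in the hunks above; the lines actually added are not visible here. As a purely hypothetical sketch that mirrors the qwen2 workaround in the second hunk, a gemma2-specific patch could look like the following; the function name, arguments, and the use_cache choice are assumptions, not the actual change.

# Hypothetical sketch only: the real lines added by this commit are not shown above.
# Mirrors the qwen2 workaround, disabling the KV cache when training with FlashAttention-2.
from transformers import Gemma2Config  # requires transformers >= 4.42


def patch_gemma2_attention(config, is_trainable: bool, flash_attn: str) -> None:
    # names and the exact condition are illustrative, not taken from the diff
    if getattr(config, "model_type", None) == "gemma2" and is_trainable and flash_attn == "fa2":
        setattr(config, "use_cache", False)


config = Gemma2Config()  # model_type == "gemma2"
patch_gemma2_attention(config, is_trainable=True, flash_attn="fa2")
assert config.use_cache is False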