better dtype handle in loading

Former-commit-id: 663f0577dd61a1a31191db2c6fbb0c7cea533b21
This commit is contained in:
hiyouga
2024-05-17 02:14:56 +08:00
parent a3320f26cf
commit 969e605c7e
3 changed files with 15 additions and 8 deletions

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict
import torch
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
@@ -66,13 +66,16 @@ def patch_config(
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype)
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if init_kwargs["low_cpu_mem_usage"]:
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
if deepspeed_config() is None and not is_fsdp_enabled(): # set dtype and device map if not use deepspeed or fsdp
init_kwargs["torch_dtype"] = model_args.compute_dtype
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map