[model] fix kv cache (#7564)
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
@@ -277,14 +276,14 @@ def init_adapter(
 
     # cast trainable parameters to float32 if:
     # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
-    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
     cast_trainable_params_to_fp32 = False
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
         logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
-    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+    elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
+        logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
     else:
         logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
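To make the effect of the new branch concrete, here is a minimal, self-contained sketch of the same decision plus the upcast it guards. The helper names (should_cast_trainable_params_to_fp32, upcast_trainable_params) are illustrative assumptions, not functions from the repository; the real code computes the flag inline inside init_adapter and consumes it elsewhere.

# Minimal sketch of the precision decision shown in the diff above.
# Helper names are hypothetical and not part of LLaMA-Factory.
from typing import Optional

import torch


def should_cast_trainable_params_to_fp32(
    is_trainable: bool,
    pure_bf16: bool,
    use_badam: bool,
    quantization_bit: Optional[int],  # e.g. 4, 8, or None
    zero3_enabled: bool,
) -> bool:
    """Mirror the branch logic after this commit: only DeepSpeed ZeRO-3 skips the upcast."""
    if not is_trainable:
        return False
    if pure_bf16 or use_badam:
        return False  # keep trainable params in half precision
    if quantization_bit is None and zero3_enabled:
        return False  # ZeRO-3 already holds master params in float32
    return True  # QLoRA, FSDP, or plain half-precision training -> upcast


def upcast_trainable_params(model: torch.nn.Module) -> None:
    """Cast only the parameters that require gradients to float32."""
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)

Under this logic an FSDP run now falls into the final branch and has its trainable parameters upcast to float32, whereas before this commit it was treated like ZeRO-3 and left in half precision.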