Fix ZeRO-2 high RAM usage
Former-commit-id: 01797126eb173250250e31f8e76b69ae0047745d
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
|
||||
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
@@ -43,8 +43,8 @@ def init_adapter(
|
||||
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
|
||||
raise ValueError("You can only use lora for quantized models.")
|
||||
|
||||
if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
||||
logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
|
||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
||||
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
|
||||
cast_trainable_params_to_fp32 = False
|
||||
else:
|
||||
logger.info("Upcasting trainable params to float32.")
|
||||
|
||||
Reference in New Issue
Block a user