fix zero2 high ram usage
Former-commit-id: 01797126eb173250250e31f8e76b69ae0047745d
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
|
||||
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
@@ -72,7 +72,7 @@ def patch_config(
|
||||
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
||||
|
||||
if deepspeed_config() is None and not is_fsdp_enabled(): # set dtype and device map if not use deepspeed or fsdp
|
||||
if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled(): # cast dtype and device if not use zero3 or fsdp
|
||||
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||
|
||||
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True
|
||||
|
||||
Reference in New Issue
Block a user