fix fsdp model loading
Former-commit-id: fc6fe23cc9ae4a920a17e8268a85c1aa4ad16d3b
@@ -7,6 +7,7 @@ import torch
 from datasets import load_dataset
 from transformers import BitsAndBytesConfig, GPTQConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version

 from ...extras.constants import FILEEXT2TYPE
@@ -133,7 +134,7 @@ def configure_quantization(
             bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
         )

-    if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
+    if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
         if model_args.quantization_bit != 4:
             raise ValueError("Only 4-bit quantized model can use auto device map.")
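In short, the fix treats an FSDP run the same way as a DeepSpeed ZeRO-3 run when loading a bitsandbytes-quantized model: is_fsdp_enabled() (from transformers.modeling_utils) reports whether the process was launched under FSDP, and in that case only 4-bit quantization is allowed, with the packed weights stored in the compute dtype (bnb_4bit_quant_storage) so FSDP can shard them. Below is a minimal, self-contained sketch of that guard; build_bnb_config and its arguments are illustrative stand-ins for the fields LlamaFactory reads from its ModelArguments, not the project's actual API.

    import torch
    from transformers import BitsAndBytesConfig
    from transformers.integrations import is_deepspeed_zero3_enabled
    from transformers.modeling_utils import is_fsdp_enabled


    def build_bnb_config(quantization_bit: int, compute_dtype: torch.dtype) -> BitsAndBytesConfig:
        """Illustrative stand-in for the guard in configure_quantization above."""
        # Under DeepSpeed ZeRO-3 or FSDP the model is loaded without a manual
        # device map, which bitsandbytes supports only for 4-bit weights --
        # the same check the diff extends with is_fsdp_enabled().
        if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
            if quantization_bit != 4:
                raise ValueError("Only 4-bit quantized model can use auto device map.")

        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            # Storing the packed 4-bit weights in the compute dtype ("crucial
            # for fsdp qlora" per the comment in the diff) gives FSDP uniformly
            # typed parameters to flatten and shard.
            bnb_4bit_quant_storage=compute_dtype,
        )


    # Example: the kind of configuration built for QLoRA in bf16.
    if __name__ == "__main__":
        print(build_bnb_config(quantization_bit=4, compute_dtype=torch.bfloat16))

Note that bnb_4bit_quant_storage requires a recent transformers release (the parameter was introduced in 4.39), which is why FSDP QLoRA support hinges on both the library version and this guard.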