improve lora+ impl.
Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
@@ -1,3 +1,4 @@
+from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List
 
 import torch
@@ -17,6 +18,18 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+@unique
+class QuantizationMethod(str, Enum):
+    r"""
+    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
+    """
+
+    BITS_AND_BYTES = "bitsandbytes"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    AQLM = "aqlm"
+
+
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
     Finds all available modules to apply lora.
@@ -24,7 +37,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     quantization_method = getattr(model, "quantization_method", None)
     if quantization_method is None:
         linear_cls = torch.nn.Linear
-    elif quantization_method == "bitsandbytes":
+    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
         import bitsandbytes as bnb
 
         linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
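
The enum swap is a drop-in change: because QuantizationMethod subclasses str, its members compare equal to the plain strings that transformers stores on model.quantization_method, so the elif branch accepts the same models as before while the known method names now live in one place. A minimal sketch of that behaviour (DummyModel is hypothetical, used only to stand in for a quantized PreTrainedModel):

from enum import Enum, unique


@unique
class QuantizationMethod(str, Enum):
    """Same string-valued enum as added in this commit."""

    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"


class DummyModel:
    """Hypothetical stand-in: transformers stores the method as a plain string."""

    quantization_method = "bitsandbytes"


model = DummyModel()

# A str-mixin enum member compares equal to its raw string value, so the
# new comparison behaves exactly like the old string literal did.
assert model.quantization_method == QuantizationMethod.BITS_AND_BYTES
assert QuantizationMethod.BITS_AND_BYTES == "bitsandbytes"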