simplify readme
Former-commit-id: 0da6ec2d516326fe9c7583ba71cd1778eb838178
This commit is contained in:
@@ -3,7 +3,9 @@ from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import cached_file
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.logging import get_logger
|
||||
@@ -28,11 +30,23 @@ class QuantizationMethod(str, Enum):
|
||||
GPTQ = "gptq"
|
||||
AWQ = "awq"
|
||||
AQLM = "aqlm"
|
||||
QUANTO = "quanto"
|
||||
|
||||
|
||||
def add_z3_leaf_module(model: "PreTrainedModel", module: "torch.nn.Module") -> None:
|
||||
r"""
|
||||
Sets module as a leaf module to skip partitioning in deepspeed zero3.
|
||||
"""
|
||||
if is_deepspeed_zero3_enabled():
|
||||
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
|
||||
from deepspeed.utils import set_z3_leaf_modules # type: ignore
|
||||
|
||||
set_z3_leaf_modules(model, [module])
|
||||
|
||||
|
||||
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
||||
r"""
|
||||
Finds all available modules to apply lora.
|
||||
Finds all available modules to apply lora or galore.
|
||||
"""
|
||||
quantization_method = getattr(model, "quantization_method", None)
|
||||
if quantization_method is None:
|
||||
|
||||
Reference in New Issue
Block a user