format style
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
This commit is contained in:
@@ -4,12 +4,13 @@
|
||||
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -17,7 +18,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Shell(nn.Module):
|
||||
|
||||
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
||||
@@ -26,7 +26,7 @@ class Shell(nn.Module):
|
||||
|
||||
|
||||
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
||||
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]):
|
||||
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]): # noqa: C403
|
||||
parent_name = ".".join(name.split(".")[:-1])
|
||||
child_name = name.split(".")[-1]
|
||||
parent_module = model.get_submodule(parent_name)
|
||||
@@ -35,7 +35,7 @@ def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
||||
weight = getattr(base_layer, "weight", None)
|
||||
bias = getattr(base_layer, "bias", None)
|
||||
setattr(parent_module, child_name, Shell(weight, bias))
|
||||
|
||||
|
||||
print("Model unwrapped.")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def quantize_loftq(
|
||||
lora_dropout=0.1,
|
||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||
init_lora_weights="loftq",
|
||||
loftq_config=loftq_config
|
||||
loftq_config=loftq_config,
|
||||
)
|
||||
|
||||
# Init LoftQ model
|
||||
|
||||
Reference in New Issue
Block a user