support DoRA, AWQ, AQLM #2512

Former-commit-id: 6614cc1f08aa944db083e27e451bbdd733f7dd97
hiyouga
2024-02-28 19:53:28 +08:00
parent 1e7962dfc4
commit b392e6cfb9
9 changed files with 40 additions and 9 deletions


@@ -85,7 +85,7 @@ def init_adapter(
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
@@ -123,6 +123,10 @@ def init_adapter(
             if finetuning_args.use_llama_pro:
                 target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
 
+            if finetuning_args.use_dora:
+                if getattr(model, "quantization_method", None):
+                    raise ValueError("DoRA is currently not compatible with quantized models.")
+
             peft_kwargs = {
                 "r": finetuning_args.lora_rank,
                 "target_modules": target_modules,
@@ -141,6 +145,7 @@ def init_adapter(
                 task_type=TaskType.CAUSAL_LM,
                 inference_mode=False,
                 modules_to_save=finetuning_args.additional_target,
+                use_dora=finetuning_args.use_dora,
                 **peft_kwargs,
             )
             model = get_peft_model(model, lora_config)
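
The hunks above ultimately flip the use_dora switch on the PEFT LoraConfig. A minimal standalone sketch of the same idea, assuming peft>=0.9.0 (the release that introduced use_dora); the base checkpoint and the rank/alpha/target-module values are placeholders, not the repository's defaults:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any causal LM works here.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    use_dora=True,  # DoRA: decompose the weight update into magnitude and direction parts
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA/DoRA parameters remain trainable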


@@ -51,7 +51,7 @@ def load_model_and_tokenizer(
     patch_tokenizer(tokenizer)
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
-    patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, finetuning_args, config_kwargs, is_trainable)
 
     model = None
     if is_trainable and model_args.use_unsloth:

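The loader change above only threads finetuning_args through to patch_config so the DoRA logic further down can see it. Since the commit title also covers AWQ and AQLM, here is a hedged sketch of how a pre-quantized AWQ checkpoint is typically loaded through transformers (the checkpoint name is a placeholder and autoawq must be installed); the quantization_method attribute printed at the end is the same one the DoRA guard in init_adapter() checks:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder AWQ checkpoint; its quantization config ships inside the checkpoint,
# so no extra arguments are needed beyond having the autoawq backend installed.
model_id = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# transformers records how the weights were quantized; a non-None value here is
# what makes init_adapter() refuse to combine DoRA with a quantized model.
print(getattr(model, "quantization_method", None))  # e.g. "awq"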

@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import ModelArguments
+    from ..hparams import FinetuningArguments, ModelArguments
 
 
 logger = get_logger(__name__)
@@ -253,6 +253,7 @@ def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
config_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
@@ -273,6 +274,9 @@ def patch_config(
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
 
+    if finetuning_args.use_dora:
+        config_kwargs["device_map"] = {"": get_current_device()}
+
 
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool