support unsloth
Former-commit-id: b857f00234b90b785d82ca7cdb29af3d948b1a7b
This commit is contained in:
@@ -93,16 +93,31 @@ def init_adapter(
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=finetuning_args.lora_rank,
|
||||
lora_alpha=finetuning_args.lora_alpha,
|
||||
lora_dropout=finetuning_args.lora_dropout,
|
||||
target_modules=target_modules,
|
||||
modules_to_save=finetuning_args.additional_target
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
"target_modules": target_modules,
|
||||
"lora_alpha": finetuning_args.lora_alpha,
|
||||
"lora_dropout": finetuning_args.lora_dropout
|
||||
}
|
||||
|
||||
if model_args.use_unsloth:
|
||||
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
||||
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
|
||||
if getattr(model.config, "model_type", None) == "llama":
|
||||
model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
elif getattr(model.config, "model_type", None) == "mistral":
|
||||
model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
modules_to_save=finetuning_args.additional_target,
|
||||
**peft_kwargs
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user