fix galore

Former-commit-id: 62a3ceeef8f60caef43ccc7f971a0c9184e21296
This commit is contained in:
hiyouga
2024-03-08 00:44:51 +08:00
parent 81fcb80466
commit e416cecf62
11 changed files with 129 additions and 25 deletions

View File

@@ -34,7 +34,8 @@ def init_adapter(
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
model = model.float()
if not finetuning_args.pure_bf16:
model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze")
@@ -78,7 +79,8 @@ def init_adapter(
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
param.data = param.data.to(torch.float32)
if not finetuning_args.pure_bf16:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
@@ -150,8 +152,9 @@ def init_adapter(
)
model = get_peft_model(model, lora_config)
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
if not finetuning_args.pure_bf16:
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
if model_args.adapter_name_or_path is not None:
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))