Feature: BAdam
Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
@@ -37,7 +37,7 @@ def init_adapter(
 
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             model = model.float()
 
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
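All three hunks in this commit apply the same gate: parameters are upcast to float32 only when neither pure bf16 training nor the BAdam optimizer is requested. A minimal standalone sketch of that condition (the helper name should_upcast is hypothetical, not part of the repo):

def should_upcast(pure_bf16: bool, use_badam: bool) -> bool:
    # Upcast to float32 only on the default path: skip it when the user
    # requested pure bf16 training or the BAdam optimizer.
    return (not pure_bf16) and (not use_badam)

assert should_upcast(False, False) is True   # default: upcast to fp32
assert should_upcast(True, False) is False   # pure bf16: keep bf16
assert should_upcast(False, True) is False   # BAdam: keep original dtype
assert should_upcast(True, True) is False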
@@ -82,7 +82,7 @@ def init_adapter(
 
         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
-                if not finetuning_args.pure_bf16:
+                if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
                     param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)
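In the freeze branch, the same gate sits inside a substring match over parameter names. A self-contained sketch of that matching logic (the two-layer model, the trainable_layers value, and the literal flags are illustrative assumptions, not from the diff):

import torch

# Illustrative setup: a two-layer bf16 model where only the last layer
# ("1.weight"/"1.bias") should remain trainable.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)).to(torch.bfloat16)
trainable_layers = ["1."]
pure_bf16, use_badam = False, False

for name, param in model.named_parameters():
    if any(trainable_layer in name for trainable_layer in trainable_layers):
        if (not pure_bf16) and (not use_badam):
            param.data = param.data.to(torch.float32)  # trainable params go to fp32
    else:
        param.requires_grad_(False)  # all other params are frozen

print({name: (p.dtype, p.requires_grad) for name, p in model.named_parameters()})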
@@ -162,7 +162,7 @@ def init_adapter(
             )
             model = get_peft_model(model, lora_config)
 
-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
            for param in filter(lambda p: p.requires_grad, model.parameters()):
                param.data = param.data.to(torch.float32)
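As for why the gate exists: BAdam performs block coordinate descent over the parameters, and the float32 upcast would materialize a full-precision copy of every trainable weight before the optimizer ever runs, which presumably defeats its memory savings. A runnable sketch of the LoRA-branch behavior (the maybe_upcast_trainable helper is illustrative, not the repo's API):

import torch

def maybe_upcast_trainable(model, pure_bf16: bool, use_badam: bool):
    # Mirrors the LoRA hunk above: upcast only the trainable parameters,
    # and only when neither pure_bf16 nor use_badam is requested.
    if (not pure_bf16) and (not use_badam):
        for param in filter(lambda p: p.requires_grad, model.parameters()):
            param.data = param.data.to(torch.float32)

model = torch.nn.Linear(4, 4).to(torch.bfloat16)
maybe_upcast_trainable(model, pure_bf16=False, use_badam=True)
print(model.weight.dtype)  # torch.bfloat16: the upcast was skipped for BAdam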