Better dtype handling in loading
Former-commit-id: 663f0577dd61a1a31191db2c6fbb0c7cea533b21
@@ -44,7 +44,7 @@ def init_adapter(
         raise ValueError("You can only use lora for quantized models.")
 
     if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params in half precision.")
+        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
         logger.info("Upcasting trainable params to float32.")
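The hunk above only flips the cast_trainable_params_to_fp32 flag; the actual cast happens elsewhere during adapter setup. A minimal sketch of how such a flag is typically consumed (the helper name and body below are hypothetical, not part of this commit):

import torch
import torch.nn as nn

def upcast_trainable_params(model: nn.Module, cast_trainable_params_to_fp32: bool) -> None:
    # Hypothetical helper: upcast only the trainable parameters; frozen weights
    # keep their original (possibly half-precision or quantized) dtype.
    if not cast_trainable_params_to_fp32:
        return  # DeepSpeed/FSDP/pure bf16/BAdam manage precision themselves
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)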
@@ -122,6 +122,9 @@ def init_adapter(
             else:
                 param.requires_grad_(False)
 
+        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
+            model.vision_tower.requires_grad_(False)
+
         logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
 
     if finetuning_args.finetuning_type == "lora":
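The new vision_tower branch relies on requires_grad_(False) on a submodule freezing every parameter beneath it. A small self-contained check with a toy model (the class below is illustrative only and not from the repository):

import torch.nn as nn

class ToyMultimodalModel(nn.Module):
    # Illustrative stand-in for a multimodal model with an image encoder.
    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8 * 30 * 30, 16))
        self.language_model = nn.Linear(16, 16)

model = ToyMultimodalModel()
model.vision_tower.requires_grad_(False)  # same call as in the diff: freeze the vision model
assert all(not p.requires_grad for p in model.vision_tower.parameters())
assert all(p.requires_grad for p in model.language_model.parameters())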