use pre-commit
Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb
This commit is contained in:
@@ -92,7 +92,7 @@ def fix_valuehead_checkpoint(
|
||||
else:
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||
|
||||
logger.info("Value head model saved at: {}".format(output_dir))
|
||||
logger.info(f"Value head model saved at: {output_dir}")
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
@@ -106,7 +106,7 @@ class FixValueHeadModelCallback(TrainerCallback):
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||
fix_valuehead_checkpoint(
|
||||
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
|
||||
)
|
||||
@@ -123,7 +123,7 @@ class SaveProcessorCallback(TrainerCallback):
|
||||
@override
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
@override
|
||||
@@ -145,7 +145,7 @@ class PissaConvertCallback(TrainerCallback):
|
||||
if args.should_save:
|
||||
model = kwargs.pop("model")
|
||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
|
||||
logger.info(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
|
||||
if isinstance(model, PeftModel):
|
||||
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
|
||||
setattr(model.peft_config["default"], "init_lora_weights", True)
|
||||
@@ -159,7 +159,7 @@ class PissaConvertCallback(TrainerCallback):
|
||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
|
||||
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
|
||||
logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
|
||||
logger.info(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
|
||||
# 1. save a pissa backup with init_lora_weights: True
|
||||
# 2. save a converted lora with init_lora_weights: pissa
|
||||
# 3. load the pissa backup with init_lora_weights: True
|
||||
|
||||
Reference in New Issue
Block a user