update flashattn, fix ppo save model
Former-commit-id: 0b08bc3dac246d4aa3f89afb7172529dcad9c39f
This commit is contained in:
@@ -25,16 +25,16 @@ class SavePeftModelCallback(TrainerCallback):
|
||||
r"""
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
|
||||
return control
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
|
||||
|
||||
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
|
||||
return control
|
||||
if args.should_save:
|
||||
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
Reference in New Issue
Block a user