release v0.9.0 (real)
Former-commit-id: 8ff781c8ae5654680f738f69a6db9d7b95d76baf
This commit is contained in:
@@ -96,38 +96,45 @@ def fix_valuehead_checkpoint(
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for fixing the checkpoint for valuehead models.
|
||||
"""
|
||||
|
||||
@override
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
fix_valuehead_checkpoint(
|
||||
model=kwargs.pop("model"),
|
||||
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
|
||||
safe_serialization=args.save_safetensors,
|
||||
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
|
||||
)
|
||||
|
||||
|
||||
class SaveProcessorCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for saving the processor.
|
||||
"""
|
||||
|
||||
def __init__(self, processor: "ProcessorMixin") -> None:
|
||||
r"""
|
||||
Initializes a callback for saving the processor.
|
||||
"""
|
||||
self.processor = processor
|
||||
|
||||
@override
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
@override
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
class PissaConvertCallback(TrainerCallback):
|
||||
r"""
|
||||
Initializes a callback for converting the PiSSA adapter to a normal one.
|
||||
A callback for converting the PiSSA adapter to a normal one.
|
||||
"""
|
||||
|
||||
@override
|
||||
@@ -147,9 +154,6 @@ class PissaConvertCallback(TrainerCallback):
|
||||
|
||||
@override
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
model = kwargs.pop("model")
|
||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||
@@ -177,21 +181,22 @@ class PissaConvertCallback(TrainerCallback):
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for logging training and evaluation status.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
r"""
|
||||
Initializes a callback for logging training and evaluation status.
|
||||
"""
|
||||
""" Progress """
|
||||
# Progress
|
||||
self.start_time = 0
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
self.thread_pool: Optional["ThreadPoolExecutor"] = None
|
||||
""" Status """
|
||||
# Status
|
||||
self.aborted = False
|
||||
self.do_train = False
|
||||
""" Web UI """
|
||||
# Web UI
|
||||
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
||||
if self.webui_mode:
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
@@ -233,9 +238,6 @@ class LogCallback(TrainerCallback):
|
||||
|
||||
@override
|
||||
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of the initialization of the `Trainer`.
|
||||
"""
|
||||
if (
|
||||
args.should_save
|
||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
@@ -246,9 +248,6 @@ class LogCallback(TrainerCallback):
|
||||
|
||||
@override
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
self.do_train = True
|
||||
self._reset(max_steps=state.max_steps)
|
||||
@@ -256,50 +255,32 @@ class LogCallback(TrainerCallback):
|
||||
|
||||
@override
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
self._close_thread_pool()
|
||||
|
||||
@override
|
||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
@override
|
||||
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of a training step.
|
||||
"""
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
@override
|
||||
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after an evaluation phase.
|
||||
"""
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
@override
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a successful prediction.
|
||||
"""
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
@override
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if not args.should_save:
|
||||
return
|
||||
|
||||
@@ -342,9 +323,6 @@ class LogCallback(TrainerCallback):
|
||||
def on_prediction_step(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
):
|
||||
r"""
|
||||
Event called after a prediction step.
|
||||
"""
|
||||
if self.do_train:
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user