release v0.9.0 (real)

Former-commit-id: 8ff781c8ae5654680f738f69a6db9d7b95d76baf
This commit is contained in:
hiyouga
2024-09-09 01:00:25 +08:00
parent e20baa4218
commit dfff411e1a
7 changed files with 45 additions and 53 deletions

View File

@@ -96,38 +96,45 @@ def fix_valuehead_checkpoint(
class FixValueHeadModelCallback(TrainerCallback):
r"""
A callback for fixing the checkpoint for valuehead models.
"""
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors,
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
)
class SaveProcessorCallback(TrainerCallback):
r"""
A callback for saving the processor.
"""
def __init__(self, processor: "ProcessorMixin") -> None:
r"""
Initializes a callback for saving the processor.
"""
self.processor = processor
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(self.processor, "image_processor").save_pretrained(output_dir)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
class PissaConvertCallback(TrainerCallback):
r"""
Initializes a callback for converting the PiSSA adapter to a normal one.
A callback for converting the PiSSA adapter to a normal one.
"""
@override
@@ -147,9 +154,6 @@ class PissaConvertCallback(TrainerCallback):
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -177,21 +181,22 @@ class PissaConvertCallback(TrainerCallback):
class LogCallback(TrainerCallback):
r"""
A callback for logging training and evaluation status.
"""
def __init__(self) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
""" Progress """
# Progress
self.start_time = 0
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
""" Status """
# Status
self.aborted = False
self.do_train = False
""" Web UI """
# Web UI
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
@@ -233,9 +238,6 @@ class LogCallback(TrainerCallback):
@override
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
"""
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
@@ -246,9 +248,6 @@ class LogCallback(TrainerCallback):
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
self.do_train = True
self._reset(max_steps=state.max_steps)
@@ -256,50 +255,32 @@ class LogCallback(TrainerCallback):
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
self._close_thread_pool()
@override
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
@override
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
@override
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if not self.do_train:
self._close_thread_pool()
@override
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
"""
if not self.do_train:
self._close_thread_pool()
@override
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after logging the last logs.
"""
if not args.should_save:
return
@@ -342,9 +323,6 @@ class LogCallback(TrainerCallback):
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
if self.do_train:
return