modify code structure

Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
This commit is contained in:
hiyouga
2023-08-02 23:17:36 +08:00
parent 8bd1da7144
commit 28a51b622b
25 changed files with 188 additions and 145 deletions

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import TrainerCallback
from transformers.trainer_utils import has_length
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
@@ -14,58 +15,105 @@ class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.in_training = False
self.start_time = time.time()
self.tracker = {}
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
def timing(self):
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
self.start_time = time.time()
if state.is_local_process_zero:
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step
might take several inputs.
Event called at the end of training.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
if state.is_local_process_zero:
self.in_training = False
self.cur_steps = 0
self.max_steps = 0
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if self.runner is not None and self.runner.aborted:
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if state.is_local_process_zero:
self.cur_steps = state.global_step
self.timing()
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
r"""
Event called after a successful prediction.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r"""
Event called after logging the last logs.
"""
if not state.is_world_process_zero:
if not state.is_local_process_zero:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_steps = state.max_steps - cur_steps
remaining_time = remaining_steps * avg_time_per_step
self.tracker = {
"current_steps": cur_steps,
"total_steps": state.max_steps,
"loss": state.log_history[-1].get("loss", None),
"eval_loss": state.log_history[-1].get("eval_loss", None),
"predict_loss": state.log_history[-1].get("predict_loss", None),
"reward": state.log_history[-1].get("reward", None),
"learning_rate": state.log_history[-1].get("learning_rate", None),
"epoch": state.log_history[-1].get("epoch", None),
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
"remaining_time": str(timedelta(seconds=int(remaining_time)))
}
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None),
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time
)
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(self.tracker) + "\n")
f.write(json.dumps(logs) + "\n")
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a prediction step.
"""
eval_dataloader = kwargs.pop("eval_dataloader", None)
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self.cur_steps += 1
self.timing()