update webui and add CLIs

Former-commit-id: 1368dda22ab875914c9dd86ee5146a4f6a4736ad
This commit is contained in:
hiyouga
2024-05-03 02:58:23 +08:00
parent 2cedb59bee
commit ce8200ad98
65 changed files with 363 additions and 372 deletions

View File

@@ -1,14 +1,18 @@
import json
import logging
import os
import signal
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict
import transformers
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import LOG_FILE_NAME
from .logging import get_logger
from .constants import TRAINER_LOG
from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint
@@ -33,20 +37,32 @@ class FixValueHeadModelCallback(TrainerCallback):
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.in_training = False
def __init__(self, output_dir: str) -> None:
self.aborted = False
self.do_train = False
self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(output_dir)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def _set_abort(self, signum, frame) -> None:
self.aborted = True
def _reset(self, max_steps: int = 0) -> None:
self.start_time = time.time()
self.cur_steps = 0
self.max_steps = 0
self.max_steps = max_steps
self.elapsed_time = ""
self.remaining_time = ""
def timing(self):
def _timing(self, cur_steps: int) -> None:
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
self.cur_steps = cur_steps
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
@@ -54,36 +70,27 @@ class LogCallback(TrainerCallback):
r"""
Event called at the beginning of training.
"""
if state.is_local_process_zero:
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
if args.should_log:
self.do_train = True
self._reset(max_steps=state.max_steps)
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
if args.should_save:
os.makedirs(args.output_dir, exist_ok=True)
self.thread_pool = ThreadPoolExecutor(max_workers=1)
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if state.is_local_process_zero:
self.in_training = False
self.cur_steps = 0
self.max_steps = 0
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
@@ -91,42 +98,41 @@ class LogCallback(TrainerCallback):
r"""
Event called at the end of a training step.
"""
if state.is_local_process_zero:
self.cur_steps = state.global_step
self.timing()
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
if args.should_log:
self._timing(cur_steps=state.global_step)
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
Event called at the end of training.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
self.thread_pool.shutdown(wait=True)
self.thread_pool = None
def on_predict(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a successful prediction.
Event called after a prediction step.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
eval_dataloader = kwargs.pop("eval_dataloader", None)
if args.should_log and has_length(eval_dataloader) and not self.do_train:
if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self._timing(cur_steps=self.cur_steps + 1)
def _write_log(self, output_dir: str, logs: Dict[str, Any]):
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r"""
Event called after logging the last logs.
Event called after logging the last logs, `args.should_log` has been applied.
"""
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
@@ -141,26 +147,13 @@ class LogCallback(TrainerCallback):
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
)
if self.runner is not None:
logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and "loss" in logs and "learning_rate" in logs and "epoch" in logs:
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
logs["loss"], logs["learning_rate"], logs["epoch"]
)
)
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
eval_dataloader = kwargs.pop("eval_dataloader", None)
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self.cur_steps += 1
self.timing()
if args.should_save and self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)