format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
This commit is contained in:
hiyouga
2024-01-20 20:15:56 +08:00
parent 1750218057
commit 66e0e651b9
73 changed files with 1492 additions and 2325 deletions

View File

@@ -1,10 +1,11 @@
import os
import json
import os
import time
from typing import TYPE_CHECKING
from datetime import timedelta
from typing import TYPE_CHECKING
from transformers import TrainerCallback
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import LOG_FILE_NAME
from .logging import get_logger
@@ -12,14 +13,13 @@ from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
from transformers import TrainerControl, TrainerState, TrainingArguments
logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
@@ -28,12 +28,11 @@ class FixValueHeadModelCallback(TrainerCallback):
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors
safe_serialization=args.save_safetensors,
)
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.in_training = False
@@ -99,7 +98,9 @@ class LogCallback(TrainerCallback):
self.cur_steps = 0
self.max_steps = 0
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
def on_predict(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
):
r"""
Event called after a successful prediction.
"""
@@ -125,18 +126,22 @@ class LogCallback(TrainerCallback):
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time
remaining_time=self.remaining_time,
)
if self.runner is not None:
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
))
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
)
)
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""