code refactor
Former-commit-id: ee3f85aa9677d0aeecb3bc396530d2cd7c50dce5
This commit is contained in:
@@ -17,11 +17,9 @@
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.misc import get_logits_processor
|
||||
from ...extras.misc import cal_effective_tokens, get_logits_processor
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
@@ -68,8 +66,9 @@ def run_sft(
|
||||
training_args.remove_unused_columns = False # important for multimodal dataset
|
||||
|
||||
effective_token_num = 0.0
|
||||
for data in dataset_module["train_dataset"]:
|
||||
effective_token_num += len(data["input_ids"])
|
||||
if finetuning_args.include_effective_tokens_per_second:
|
||||
for data in dataset_module["train_dataset"]:
|
||||
effective_token_num += len(data["input_ids"])
|
||||
|
||||
# Metric utils
|
||||
metric_module = {}
|
||||
@@ -100,12 +99,9 @@ def run_sft(
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
train_result.metrics["effective_tokens_per_sec"] = (
|
||||
effective_token_num * train_result.metrics["epoch"] / train_result.metrics["train_runtime"]
|
||||
)
|
||||
if dist.is_initialized():
|
||||
train_result.metrics["effective_tokens_per_sec"] = (
|
||||
train_result.metrics["effective_tokens_per_sec"] / dist.get_world_size()
|
||||
if finetuning_args.include_effective_tokens_per_second:
|
||||
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
|
||||
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
|
||||
)
|
||||
|
||||
trainer.save_model()
|
||||
|
||||
Reference in New Issue
Block a user