support batch infer in vllm

Former-commit-id: 3ef5ed3b9a44eed2f7e3ff221dfc343d0a97c0b5
hiyouga
2024-12-04 13:50:00 +00:00
parent 53edd62f8b
commit c1768cfb14
29 changed files with 148 additions and 407 deletions

@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import cal_effective_tokens, get_logits_processor
+from ...extras.logging import get_logger
+from ...extras.misc import calculate_tps, get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
@@ -33,6 +34,9 @@ if TYPE_CHECKING:
     from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
 
+logger = get_logger(__name__)
+
+
 def run_sft(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -65,11 +69,6 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
     training_args.remove_unused_columns = False  # important for multimodal dataset
 
-    effective_token_num = 0.0
-    if finetuning_args.include_effective_tokens_per_second:
-        for data in dataset_module["train_dataset"]:
-            effective_token_num += len(data["input_ids"])
-
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
@@ -99,12 +98,12 @@ def run_sft(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
-        trainer.save_model()
         if finetuning_args.include_effective_tokens_per_second:
-            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
-                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="sft"
             )
 
+        trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
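
Note on the hunk above: the token accounting that used to live inline in `run_sft` (summing `len(input_ids)` and calling `cal_effective_tokens`) is now delegated to `calculate_tps` in `...extras.misc`. A minimal sketch of what that helper presumably computes, inferred only from the removed inline code; the name `calculate_tps_sketch` and its body are illustrative, not the actual implementation:

```python
from typing import Any, Dict, Sequence


def calculate_tps_sketch(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, Any], stage: str = "sft") -> float:
    """Illustrative stand-in for `calculate_tps` (assumption, not the real helper).

    Reproduces the removed inline logic: count effective (non-padding) tokens,
    then scale by the number of epochs and divide by wall-clock training time.
    """
    if stage != "sft":
        raise NotImplementedError("Only the SFT path is sketched here.")

    # Same counting loop that was removed from run_sft.
    effective_token_num = sum(len(example["input_ids"]) for example in dataset)

    # Same formula as the old cal_effective_tokens(effective_token_num, epoch, train_runtime).
    return effective_token_num * metrics["epoch"] / metrics["train_runtime"]
```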
@@ -124,6 +123,7 @@ def run_sft(
 
     # Predict
     if training_args.do_predict:
+        logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
         if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
             predict_results.metrics.pop("predict_loss", None)
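
The new `warning_once` points batch prediction toward `scripts/vllm_infer.py`. That script's exact CLI is not shown in this diff; the underlying pattern it relies on is vLLM's offline batch-generation API, roughly as sketched below (model path, prompts, and sampling values are placeholders):

```python
from vllm import LLM, SamplingParams

# Placeholder checkpoint path and sampling settings; adjust to your model.
llm = LLM(model="path/to/your/sft/checkpoint")
sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=512)

prompts = [
    "Explain what supervised fine-tuning is.",
    "Summarize the benefits of batched inference.",
]

# vLLM schedules all prompts together, which is much faster than
# Trainer.predict-style batch generation for large eval sets.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```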