support batch infer in vllm
Former-commit-id: 3ef5ed3b9a44eed2f7e3ff221dfc343d0a97c0b5
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -87,6 +87,21 @@ def check_dependencies() -> None:
|
||||
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
|
||||
|
||||
|
||||
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
r"""
|
||||
Calculates effective tokens per second.
|
||||
"""
|
||||
effective_token_num = 0
|
||||
for data in dataset:
|
||||
if stage == "sft":
|
||||
effective_token_num += len(data["input_ids"])
|
||||
elif stage == "rm":
|
||||
effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
|
||||
|
||||
result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
|
||||
return result / dist.get_world_size() if dist.is_initialized() else result
|
||||
|
||||
|
||||
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the number of trainable parameters and number of all parameters in the model.
|
||||
@@ -264,11 +279,3 @@ def use_modelscope() -> bool:
|
||||
|
||||
def use_openmind() -> bool:
|
||||
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
|
||||
|
||||
|
||||
def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
|
||||
r"""
|
||||
calculate effective tokens.
|
||||
"""
|
||||
result = effective_token_num * epoch / train_runtime
|
||||
return result / dist.get_world_size() if dist.is_initialized() else result
|
||||
|
||||
@@ -122,7 +122,7 @@ def _check_extra_dependencies(
|
||||
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
|
||||
require_version("vllm>=0.4.3,<0.6.5", "To fix: pip install vllm>=0.4.3,<0.6.5")
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.misc import cal_effective_tokens
|
||||
from ...extras.misc import calculate_tps
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...hparams import ModelArguments
|
||||
from ...model import load_model, load_tokenizer
|
||||
@@ -65,12 +65,6 @@ def run_dpo(
|
||||
# Update arguments
|
||||
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
|
||||
|
||||
effective_token_num = 0.0
|
||||
if finetuning_args.include_effective_tokens_per_second:
|
||||
for data in dataset_module["train_dataset"]:
|
||||
effective_token_num += len(data["chosen_input_ids"])
|
||||
effective_token_num += len(data["rejected_input_ids"])
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomDPOTrainer(
|
||||
model=model,
|
||||
@@ -86,13 +80,12 @@ def run_dpo(
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
|
||||
trainer.save_model()
|
||||
if finetuning_args.include_effective_tokens_per_second:
|
||||
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
|
||||
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
|
||||
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
|
||||
dataset_module["train_dataset"], train_result.metrics, stage="rm"
|
||||
)
|
||||
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
|
||||
@@ -161,12 +161,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
|
||||
|
||||
decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
|
||||
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
|
||||
res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
|
||||
|
||||
writer.write("\n".join(res))
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as f:
|
||||
for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
|
||||
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
|
||||
|
||||
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.misc import cal_effective_tokens, get_logits_processor
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import calculate_tps, get_logits_processor
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
@@ -33,6 +34,9 @@ if TYPE_CHECKING:
|
||||
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_sft(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
@@ -65,11 +69,6 @@ def run_sft(
|
||||
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
|
||||
training_args.remove_unused_columns = False # important for multimodal dataset
|
||||
|
||||
effective_token_num = 0.0
|
||||
if finetuning_args.include_effective_tokens_per_second:
|
||||
for data in dataset_module["train_dataset"]:
|
||||
effective_token_num += len(data["input_ids"])
|
||||
|
||||
# Metric utils
|
||||
metric_module = {}
|
||||
if training_args.predict_with_generate:
|
||||
@@ -99,12 +98,12 @@ def run_sft(
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
if finetuning_args.include_effective_tokens_per_second:
|
||||
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
|
||||
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
|
||||
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
|
||||
dataset_module["train_dataset"], train_result.metrics, stage="sft"
|
||||
)
|
||||
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
@@ -124,6 +123,7 @@ def run_sft(
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
|
||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||
predict_results.metrics.pop("predict_loss", None)
|
||||
|
||||
@@ -35,7 +35,7 @@ if is_gradio_available():
|
||||
|
||||
def create_ui(demo_mode: bool = False) -> "gr.Blocks":
|
||||
engine = Engine(demo_mode=demo_mode, pure_chat=False)
|
||||
hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split('.')[0]
|
||||
hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]
|
||||
|
||||
with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
|
||||
if demo_mode:
|
||||
|
||||
Reference in New Issue
Block a user