support batch infer in vllm

Former-commit-id: 3ef5ed3b9a44eed2f7e3ff221dfc343d0a97c0b5
Author: hiyouga
Date: 2024-12-04 13:50:00 +00:00
Parent: 53edd62f8b
Commit: c1768cfb14
29 changed files with 148 additions and 407 deletions

@@ -17,7 +17,7 @@
 import gc
 import os
-from typing import TYPE_CHECKING, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -87,6 +87,21 @@ def check_dependencies() -> None:
         require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
 
 
+def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
+    r"""
+    Calculates effective tokens per second.
+    """
+    effective_token_num = 0
+    for data in dataset:
+        if stage == "sft":
+            effective_token_num += len(data["input_ids"])
+        elif stage == "rm":
+            effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
+
+    result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
+    return result / dist.get_world_size() if dist.is_initialized() else result
+
+
 def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
@@ -264,11 +279,3 @@ def use_modelscope() -> bool:
 
 def use_openmind() -> bool:
     return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
-
-
-def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
-    r"""
-    calculate effective tokens.
-    """
-    result = effective_token_num * epoch / train_runtime
-    return result / dist.get_world_size() if dist.is_initialized() else result
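
Taken together, the last two hunks replace the ad-hoc cal_effective_tokens helper with calculate_tps, which counts the effective tokens per stage itself and reads epoch and train_runtime from the trainer's metrics dict instead of taking them as separate arguments. A minimal usage sketch follows; the dataset rows and metric values are illustrative assumptions, not part of the diff, and only the dict keys follow the reads in the function body:

    # Hypothetical call site for the calculate_tps helper added above.
    dataset = [
        {"input_ids": [1, 2, 3, 4]},  # two tokenized sft examples; only "input_ids"
        {"input_ids": [5, 6, 7]},     # is read in the "sft" branch
    ]
    # Keys match the reads in the function: metrics["epoch"], metrics["train_runtime"].
    metrics = {"epoch": 3.0, "train_runtime": 120.0}

    tps = calculate_tps(dataset, metrics, stage="sft")
    print(tps)  # (4 + 3) * 3.0 / 120.0 = 0.175 when torch.distributed is not initialized

When torch.distributed is initialized, the result is additionally divided by dist.get_world_size(), so each rank reports a per-device figure rather than the aggregate throughput.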