support batch inference in vLLM
Former-commit-id: 3ef5ed3b9a44eed2f7e3ff221dfc343d0a97c0b5
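For context, the title refers to vLLM's offline generation API, which accepts a whole list of prompts in a single generate() call and schedules them as one batch. A minimal illustrative sketch of that usage (not part of this diff; the model path is a placeholder):

from vllm import LLM, SamplingParams

# Placeholder model path; any model supported by vLLM works here.
llm = LLM(model="meta-llama/Llama-2-7b-hf")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

# Passing a list of prompts runs them as one batch.
prompts = [
    "What is the capital of France?",
    "Explain self-attention in one sentence.",
]
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)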
@@ -17,7 +17,7 @@
 import gc
 import os
-from typing import TYPE_CHECKING, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -87,6 +87,21 @@ def check_dependencies() -> None:
     require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
 
 
+def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
+    r"""
+    Calculates effective tokens per second.
+    """
+    effective_token_num = 0
+    for data in dataset:
+        if stage == "sft":
+            effective_token_num += len(data["input_ids"])
+        elif stage == "rm":
+            effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
+
+    result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
+    return result / dist.get_world_size() if dist.is_initialized() else result
+
+
 def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
@@ -264,11 +279,3 @@ def use_modelscope() -> bool:
 
 def use_openmind() -> bool:
     return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
-
-
-def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
-    r"""
-    calculate effective tokens.
-    """
-    result = effective_token_num * epoch / train_runtime
-    return result / dist.get_world_size() if dist.is_initialized() else result
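For reference, the new calculate_tps helper added above can be exercised as follows; the dataset rows and metric values here are hypothetical, made up purely for illustration:

# Hypothetical tokenized SFT samples and trainer metrics.
dataset = [
    {"input_ids": [1, 2, 3, 4, 5]},  # 5 effective tokens
    {"input_ids": [9, 8, 7]},        # 3 effective tokens
]
metrics = {"epoch": 2.0, "train_runtime": 4.0}  # runtime in seconds

# (5 + 3) tokens * 2 epochs / 4 s = 4.0 tokens/s on a single process;
# under torch.distributed, the result is further divided by the world size.
print(calculate_tps(dataset, metrics, stage="sft"))  # 4.0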