[inference] support sglang backend (#7278)
* Mimic SGLang offline Engine
* Add more tests and args
* Pass all current tests
* Clean code
* Fix sample_params
* Clean code
* Fix stream chat
* Change SGLang from engine mode to server mode
* Fix
* Fix review issues
* Use SGLang built-in utilities
* Fix test SGLang
* Fix some doc issues
* Fix SGLang engine
* Add readme

---------

Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -302,7 +302,7 @@ class VllmArguments:
         metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
     )
     vllm_gpu_util: float = field(
-        default=0.9,
+        default=0.7,
         metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
     )
     vllm_enforce_eager: bool = field(
@@ -324,7 +324,35 @@ class VllmArguments:
 
 
 @dataclass
-class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
+class SGLangArguments:
+    r"""Arguments pertaining to the SGLang worker."""
+
+    sglang_maxlen: int = field(
+        default=4096,
+        metadata={"help": "Maximum sequence (prompt + response) length of the SGLang engine."},
+    )
+    sglang_mem_fraction: float = field(
+        default=0.7,
+        metadata={"help": "The memory fraction (0-1) to be used for the SGLang engine."},
+    )
+    sglang_tp_size: int = field(
+        default=-1,
+        metadata={"help": "Tensor parallel size for the SGLang engine."},
+    )
+    sglang_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
+    )
+
+    def __post_init__(self):
+        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
+            self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
+
+
+@dataclass
+class ModelArguments(
+    SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
+):
     r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
 
     The class on the most right will be displayed first.
@@ -356,6 +384,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
         ProcessorArguments.__post_init__(self)
         ExportArguments.__post_init__(self)
         VllmArguments.__post_init__(self)
+        SGLangArguments.__post_init__(self)
 
     @classmethod
     def copyfrom(cls, source: "Self", **kwargs) -> "Self":
@@ -31,7 +31,7 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 
 from ..extras import logging
-from ..extras.constants import CHECKPOINT_NAMES
+from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -134,9 +134,12 @@ def _check_extra_dependencies(
     if model_args.mixture_of_depths is not None:
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
-    if model_args.infer_backend == "vllm":
+    if model_args.infer_backend == EngineName.VLLM:
         check_version("vllm>=0.4.3,<=0.7.3")
         check_version("vllm", mandatory=True)
+    elif model_args.infer_backend == EngineName.SGLANG:
+        check_version("sglang>=0.4.4")
+        check_version("sglang", mandatory=True)
 
     if finetuning_args.use_galore:
         check_version("galore_torch", mandatory=True)
Reference in New Issue
Block a user