[inference] support sglang backend (#7278)

* Mimic SGLang offline Engine

* Add more tests and args

* Pass all current tests

* Clean Code

* fix sample_params

* clean code

* Fix Stream Chat

* change sglang from engine mode to server mode

* fix

* Fix Review Issues

* Use SGLang Built-In Utilities

* Fix test SGLang

* Some Doc Issue

* fix sglang engine

* add readme

---------

Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
Qiaolin Yu
2025-03-14 16:37:58 -04:00
committed by GitHub
parent 93e6184cbe
commit a44a53ebec
15 changed files with 433 additions and 27 deletions

View File

@@ -25,6 +25,7 @@ from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .sglang_engine import SGLangEngine
from .vllm_engine import VllmEngine
@@ -52,6 +53,8 @@ class ChatModel:
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.VLLM:
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.SGLANG:
self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

View File

@@ -13,7 +13,6 @@
# limitations under the License.
import asyncio
import concurrent.futures
import os
from collections.abc import AsyncGenerator
from threading import Thread
@@ -349,7 +348,6 @@ class HuggingfaceEngine(BaseEngine):
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
@@ -365,8 +363,7 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
return await asyncio.to_thread(self._chat, *input_args)
@override
async def stream_chat(
@@ -382,7 +379,6 @@ class HuggingfaceEngine(BaseEngine):
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
@@ -398,13 +394,12 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
try:
yield await loop.run_in_executor(pool, stream)
except StopAsyncIteration:
break
stream = self._stream_chat(*input_args)
while True:
try:
yield await asyncio.to_thread(stream)
except StopAsyncIteration:
break
@override
async def get_scores(
@@ -415,8 +410,6 @@ class HuggingfaceEngine(BaseEngine):
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._get_scores, *input_args)
return await asyncio.to_thread(self._get_scores, *input_args)

View File

@@ -0,0 +1,282 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import atexit
import json
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
import requests
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_sglang_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from .base_engine import BaseEngine, Response
if is_sglang_available():
from sglang.utils import launch_server_cmd, terminate_process, wait_for_server
if TYPE_CHECKING:
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class SGLangEngine(BaseEngine):
"""Inference engine for SGLang models.
This class wraps the SGLang engine to provide a consistent interface for text generation
that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
better interaction and performance. The engine launches a server process and communicates
with it via HTTP requests.
For more details on the SGLang HTTP server approach, see:
https://docs.sglang.ai/backend/send_request.html
"""
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.SGLANG
self.model_args = model_args
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
self.generating_args = generating_args.to_dict()
launch_cmd = [
"python3 -m sglang.launch_server",
f"--model-path {model_args.model_name_or_path}",
f"--dtype {model_args.infer_dtype}",
f"--context-length {model_args.sglang_maxlen}",
f"--mem-fraction-static {model_args.sglang_mem_fraction}",
f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
f"--download-dir {model_args.cache_dir}",
"--log-level error",
]
launch_cmd = " ".join(launch_cmd)
logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
try:
torch_gc()
self.server_process, port = launch_server_cmd(launch_cmd)
self.base_url = f"http://localhost:{port}"
atexit.register(self._cleanup_server)
logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
wait_for_server(self.base_url, timeout=300)
logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")
try:
response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
if response.status_code == 200:
model_info = response.json()
logger.info(f"SGLang server model info: {model_info}")
except Exception as e:
logger.debug(f"Note: could not get model info: {str(e)}")
except Exception as e:
logger.error(f"Failed to start SGLang server: {str(e)}")
self._cleanup_server() # make sure to clean up any started process
raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")
def _cleanup_server(self):
r"""Clean up the server process when the engine is destroyed."""
if hasattr(self, "server_process") and self.server_process:
try:
logger.info("Terminating SGLang server process")
terminate_process(self.server_process)
logger.info("SGLang server process terminated")
except Exception as e:
logger.warning(f"Error terminating SGLang server: {str(e)}")
async def _generate(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncIterator[dict[str, Any]]:
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if videos is not None:
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if audios is not None:
mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if num_return_sequences != 1:
raise NotImplementedError("SGLang only supports n=1.")
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
if max_new_tokens:
max_tokens = max_new_tokens
sampling_params = {
"temperature": temperature if temperature is not None else self.generating_args["temperature"],
"top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
"top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0
"stop": stop,
"stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
"max_new_tokens": max_tokens,
"repetition_penalty": (
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
)
or 1.0, # repetition_penalty must > 0
"skip_special_tokens": skip_special_tokens
if skip_special_tokens is not None
else self.generating_args["skip_special_tokens"],
}
def stream_request():
json_data = {
"input_ids": prompt_ids,
"sampling_params": sampling_params,
"stream": True,
}
response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
if response.status_code != 200:
raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
for chunk in response.iter_lines(decode_unicode=False):
chunk = str(chunk.decode("utf-8"))
if chunk == "data: [DONE]":
break
if chunk and chunk.startswith("data:"):
yield json.loads(chunk[5:].strip("\n"))
return await asyncio.to_thread(stream_request)
@override
async def chat(
self,
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> list["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
for request_output in generator:
final_output = request_output
results = [
Response(
response_text=final_output["text"],
response_length=final_output["meta_info"]["completion_tokens"],
prompt_length=final_output["meta_info"]["prompt_tokens"],
finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
)
]
return results
@override
async def stream_chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
for result in generator:
delta_text = result["text"][len(generated_text) :]
generated_text = result["text"]
yield delta_text
@override
async def get_scores(
self,
batch_input: list[str],
**input_kwargs,
) -> list[float]:
raise NotImplementedError("SGLang engine does not support `get_scores`.")
def __del__(self):
r"""Ensure server is cleaned up when object is deleted."""
self._cleanup_server()
try:
atexit.unregister(self._cleanup_server)
except Exception:
pass

View File

@@ -252,4 +252,4 @@ class VllmEngine(BaseEngine):
batch_input: list[str],
**input_kwargs,
) -> list[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")
raise NotImplementedError("vLLM engine does not support `get_scores`.")