mirror of https://github.com/hiyouga/LlamaFactory.git (synced 2026-02-01 08:13:38 +00:00)
[inference] support sglang backend (#7278)
* Mimic SGLang offline Engine
* Add more tests and args
* Pass all current tests
* Clean code
* Fix sample_params
* Clean code
* Fix stream chat
* Change SGLang from engine mode to server mode
* Fix
* Fix review issues
* Use SGLang built-in utilities
* Fix SGLang tests
* Fix some doc issues
* Fix SGLang engine
* Add README

---------

Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
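Note on "server mode": with this change the SGLang backend no longer embeds SGLang's offline engine in-process; it talks to a separately launched SGLang server instead. The sketch below is an illustrative, minimal version of that flow using SGLang's standard launcher and its OpenAI-compatible endpoint; the model path, port, and sampling parameters are placeholders, not the values used by LLaMA-Factory.

# Minimal sketch (assumptions: a local SGLang server started with the standard
# launcher; model path and port are placeholders):
#
#   python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --port 30000
#
from openai import OpenAI

# The SGLang server exposes an OpenAI-compatible API under /v1.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.7,
)
print(response.choices[0].message.content)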
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import asyncio
-import concurrent.futures
 import os
 from collections.abc import AsyncGenerator
 from threading import Thread
@@ -349,7 +348,6 @@ class HuggingfaceEngine(BaseEngine):
         if not self.can_generate:
             raise ValueError("The current model does not support `chat`.")
 
-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,
@@ -365,8 +363,7 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._chat, *input_args)
+            return await asyncio.to_thread(self._chat, *input_args)
 
     @override
     async def stream_chat(
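The hunk above is the core of the refactor for `chat`: the explicit event-loop handle and the per-call `ThreadPoolExecutor` are replaced by `asyncio.to_thread`, which runs the blocking call in the loop's default executor. A standalone sketch of the before/after pattern, using a stand-in blocking function rather than LLaMA-Factory's `_chat`:

import asyncio
import concurrent.futures
import time


def blocking_generate(prompt: str) -> str:
    # Stand-in for a blocking call such as HuggingfaceEngine._chat.
    time.sleep(0.1)
    return f"response to {prompt!r}"


async def old_style(prompt: str) -> str:
    # Pre-change pattern: grab the running loop and spin up a throwaway thread pool.
    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        return await loop.run_in_executor(pool, blocking_generate, prompt)


async def new_style(prompt: str) -> str:
    # Post-change pattern: asyncio.to_thread (Python 3.9+) uses the loop's
    # default executor and forwards positional arguments directly.
    return await asyncio.to_thread(blocking_generate, prompt)


print(asyncio.run(new_style("hello")))

Besides being shorter, the new form avoids constructing and tearing down a fresh thread pool for every request and no longer needs to capture the running loop by hand.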
@@ -382,7 +379,6 @@ class HuggingfaceEngine(BaseEngine):
         if not self.can_generate:
             raise ValueError("The current model does not support `stream_chat`.")
 
-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,
@@ -398,13 +394,12 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                stream = self._stream_chat(*input_args)
-                while True:
-                    try:
-                        yield await loop.run_in_executor(pool, stream)
-                    except StopAsyncIteration:
-                        break
+            stream = self._stream_chat(*input_args)
+            while True:
+                try:
+                    yield await asyncio.to_thread(stream)
+                except StopAsyncIteration:
+                    break
 
     @override
     async def get_scores(
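In the streaming hunk above, `self._stream_chat(*input_args)` returns a blocking callable that produces the next chunk on each call and raises `StopAsyncIteration` when exhausted; the refactor only changes how that callable is awaited. A self-contained sketch of the same bridge, with a hypothetical token list standing in for the model:

import asyncio
from collections.abc import AsyncGenerator, Callable


def make_stream(tokens: list[str]) -> Callable[[], str]:
    # Hypothetical stand-in for HuggingfaceEngine._stream_chat: returns a
    # blocking callable that hands back one chunk per call and signals
    # exhaustion with StopAsyncIteration, matching the protocol in the diff.
    iterator = iter(tokens)

    def stream() -> str:
        try:
            return next(iterator)
        except StopIteration:
            raise StopAsyncIteration() from None

    return stream


async def stream_chat(tokens: list[str]) -> AsyncGenerator[str, None]:
    stream = make_stream(tokens)
    while True:
        try:
            # Each blocking "next chunk" call runs in a worker thread, so the
            # event loop stays responsive between tokens.
            yield await asyncio.to_thread(stream)
        except StopAsyncIteration:
            break


async def main() -> None:
    async for chunk in stream_chat(["Hel", "lo", "!"]):
        print(chunk, end="")


asyncio.run(main())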
@@ -415,8 +410,6 @@ class HuggingfaceEngine(BaseEngine):
         if self.can_generate:
             raise ValueError("Cannot get scores using an auto-regressive model.")
 
-        loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._get_scores, *input_args)
+            return await asyncio.to_thread(self._get_scores, *input_args)
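`get_scores` gets the same treatment as `chat`. Worth noting: the `async with self.semaphore:` line is kept in every method, so blocking work still runs under a concurrency cap even though each call now goes through `asyncio.to_thread`. A minimal illustration of that interplay (semaphore size and workload are made up for the example):

import asyncio
import time

MAX_CONCURRENT = 2  # Made-up limit for illustration.
semaphore = asyncio.Semaphore(MAX_CONCURRENT)


def blocking_score(batch_id: int) -> float:
    # Stand-in for a blocking reward-model forward pass.
    time.sleep(0.1)
    return float(batch_id)


async def get_scores(batch_id: int) -> float:
    async with semaphore:
        # At most MAX_CONCURRENT of these worker threads run at once;
        # other callers wait on the semaphore, not in a thread.
        return await asyncio.to_thread(blocking_score, batch_id)


async def main() -> None:
    results = await asyncio.gather(*(get_scores(i) for i in range(5)))
    print(results)


asyncio.run(main())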