Former-commit-id: 2ab449adbb160f339a0586edeb846fa311ad8382
hiyouga
2024-06-18 22:08:56 +08:00
parent 875270b851
commit 372da52d4a
4 changed files with 4 additions and 16 deletions


@@ -59,6 +59,7 @@ class HuggingfaceEngine(BaseEngine):
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         ) # must after fixing tokenizer to resize vocab
         self.generating_args = generating_args.to_dict()
+        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
 
     @staticmethod
     def _process_args(
@@ -259,9 +260,6 @@ class HuggingfaceEngine(BaseEngine):
         return scores
 
-    async def start(self) -> None:
-        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
-
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -286,7 +284,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._chat, *input_args)
@@ -314,7 +312,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 stream = self._stream_chat(*input_args)
                 while True:
@@ -333,6 +331,6 @@ class HuggingfaceEngine(BaseEngine):
         loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._get_scores, *input_args)
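
Taken together, the change creates the concurrency limit eagerly in __init__ (as self.semaphore) instead of lazily in a separate start() coroutine, so chat(), stream_chat(), and get_scores() can gate their blocking work as soon as the engine is constructed and callers no longer need to invoke start() first. Below is a minimal, self-contained sketch of the pattern; it is not the real HuggingfaceEngine (DemoEngine, _chat, and main are illustrative names), and only the MAX_CONCURRENT environment variable with its default of "1" is taken from the diff.

import asyncio
import concurrent.futures
import os


class DemoEngine:
    """Illustrative stand-in for the engine above, not the real HuggingfaceEngine."""

    def __init__(self) -> None:
        # Created in __init__, as in the diff: no separate start() step is needed,
        # and MAX_CONCURRENT defaults to "1" (one blocking generation at a time).
        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

    def _chat(self, message: str) -> str:
        # Stand-in for the blocking model call that actually runs generation.
        return message.upper()

    async def chat(self, message: str) -> str:
        loop = asyncio.get_running_loop()
        async with self.semaphore:  # cap the number of concurrent blocking calls
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._chat, message)


async def main() -> None:
    engine = DemoEngine()  # no await engine.start() required anymore
    print(await asyncio.gather(*(engine.chat(m) for m in ("hello", "world"))))


asyncio.run(main())

With MAX_CONCURRENT left at its default, the two chat() calls above run one after another; exporting MAX_CONCURRENT=2 before running would let them overlap, which mirrors how the environment variable bounds concurrent requests in the patched engine.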