@@ -59,6 +59,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
) # must after fixing tokenizer to resize vocab
|
||||
self.generating_args = generating_args.to_dict()
|
||||
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
||||
|
||||
@staticmethod
|
||||
def _process_args(
|
||||
@@ -259,9 +260,6 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
return scores
|
||||
|
||||
async def start(self) -> None:
|
||||
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@@ -286,7 +284,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
image,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self._semaphore:
|
||||
async with self.semaphore:
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return await loop.run_in_executor(pool, self._chat, *input_args)
|
||||
|
||||
@@ -314,7 +312,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
image,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self._semaphore:
|
||||
async with self.semaphore:
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
stream = self._stream_chat(*input_args)
|
||||
while True:
|
||||
@@ -333,6 +331,6 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
|
||||
async with self._semaphore:
|
||||
async with self.semaphore:
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return await loop.run_in_executor(pool, self._get_scores, *input_args)
|
||||
|
||||
Reference in New Issue
Block a user