support control eos, fix #6345
Former-commit-id: cb0f8399356bf372f3b7963f2565c3d504be0923
This commit is contained in:
@@ -205,7 +205,9 @@ class HuggingfaceEngine(BaseEngine):
|
||||
)
|
||||
generate_output = model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
response = tokenizer.batch_decode(
|
||||
response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
|
||||
)
|
||||
results = []
|
||||
for i in range(len(response)):
|
||||
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
|
||||
@@ -249,7 +251,9 @@ class HuggingfaceEngine(BaseEngine):
|
||||
videos,
|
||||
input_kwargs,
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
|
||||
)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
|
||||
thread.start()
|
||||
|
||||
Reference in New Issue
Block a user