support controlling eos via skip_special_tokens, fix #6345

Former-commit-id: cb0f8399356bf372f3b7963f2565c3d504be0923
hiyouga
2024-12-17 10:42:05 +00:00
parent 6522467ddb
commit 19ebc0e7a2
5 changed files with 21 additions and 7 deletions

@@ -205,7 +205,9 @@ class HuggingfaceEngine(BaseEngine):
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
-        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        response = tokenizer.batch_decode(
+            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+        )
         results = []
         for i in range(len(response)):
             eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
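
In the batch path the change simply forwards a user-controlled flag into `batch_decode` instead of hard-coding `skip_special_tokens=True`. A minimal sketch of what that toggles, assuming a GPT-2 tokenizer and a plain dict standing in for `generating_args` (the real generating-arguments object is presumably defined in one of this commit's other changed files, not shown here):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
generating_args = {"skip_special_tokens": False}  # hypothetical user setting

# Mimic response_ids: a short reply followed by the EOS token.
response_ids = tokenizer("Hello there").input_ids + [tokenizer.eos_token_id]

# Old behavior (hard-coded True): the EOS token is stripped from the text.
print(tokenizer.decode(response_ids, skip_special_tokens=True))  # Hello there

# New behavior when the user opts out of skipping: EOS stays visible.
print(tokenizer.decode(response_ids, skip_special_tokens=generating_args["skip_special_tokens"]))  # Hello there<|endoftext|>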
@@ -249,7 +251,9 @@ class HuggingfaceEngine(BaseEngine):
             videos,
             input_kwargs,
         )
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+        )
         gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()
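
The streaming path gets the same treatment: the flag is forwarded into `TextIteratorStreamer`, whose extra keyword arguments are passed through to the tokenizer's decode call. A hedged, self-contained sketch of that wiring (the model name and the `generating_args` dict are placeholders, not the engine's actual setup):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
generating_args = {"skip_special_tokens": False}  # hypothetical user setting

inputs = tokenizer("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
)

# Run generation in a background thread and consume decoded text as it arrives.
gen_kwargs = {**inputs, "max_new_tokens": 8, "streamer": streamer}
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)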