fix paligemma inference

Former-commit-id: 46357b7a677e8ba2e0a7c9d4ec1974abd061569c
This commit is contained in:
hiyouga
2024-05-20 23:36:43 +08:00
parent 247cda4b68
commit b31d808655
4 changed files with 47 additions and 20 deletions

View File

@@ -290,10 +290,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set):
if "bos_token" in slot:
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
@@ -325,9 +325,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
@@ -614,6 +616,9 @@ _register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
efficient_eos=True,
force_system=True,
)