[misc] support export ollama modelfile (#6899)
* support export ollama modelfile * update config * add system and num ctx Former-commit-id: 8c2af7466f4015f300b51841db11bcd2505ebf20
This commit is contained in:
@@ -239,11 +239,9 @@ class Template:
|
||||
Returns the jinja template.
|
||||
"""
|
||||
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
|
||||
system_message = self._convert_slots_to_jinja(
|
||||
self.format_system.apply(), tokenizer, placeholder="system_message"
|
||||
)
|
||||
user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
|
||||
assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
|
||||
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
|
||||
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
|
||||
assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
|
||||
jinja_template = ""
|
||||
if prefix:
|
||||
jinja_template += "{{ " + prefix + " }}"
|
||||
@@ -254,13 +252,13 @@ class Template:
|
||||
jinja_template += (
|
||||
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
|
||||
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
|
||||
"{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
||||
"{% if system_message is defined %}{{ " + system + " }}{% endif %}"
|
||||
"{% for message in loop_messages %}"
|
||||
"{% set content = message['content'] %}"
|
||||
"{% if message['role'] == 'user' %}"
|
||||
"{{ " + user_message + " }}"
|
||||
"{{ " + user + " }}"
|
||||
"{% elif message['role'] == 'assistant' %}"
|
||||
"{{ " + assistant_message + " }}"
|
||||
"{{ " + assistant + " }}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
)
|
||||
@@ -276,6 +274,64 @@ class Template:
|
||||
except ValueError as e:
|
||||
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
|
||||
|
||||
@staticmethod
|
||||
def _convert_slots_to_ollama(
|
||||
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
|
||||
) -> str:
|
||||
r"""
|
||||
Converts slots to ollama template.
|
||||
"""
|
||||
slot_items = []
|
||||
for slot in slots:
|
||||
if isinstance(slot, str):
|
||||
slot_pieces = slot.split("{{content}}")
|
||||
if slot_pieces[0]:
|
||||
slot_items.append(slot_pieces[0])
|
||||
if len(slot_pieces) > 1:
|
||||
slot_items.append("{{ " + placeholder + " }}")
|
||||
if slot_pieces[1]:
|
||||
slot_items.append(slot_pieces[1])
|
||||
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
|
||||
if "bos_token" in slot and tokenizer.bos_token_id is not None:
|
||||
slot_items.append(tokenizer.bos_token)
|
||||
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
|
||||
slot_items.append(tokenizer.eos_token)
|
||||
elif isinstance(slot, dict):
|
||||
raise ValueError("Dict is not supported.")
|
||||
|
||||
return "".join(slot_items)
|
||||
|
||||
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
|
||||
r"""
|
||||
Returns the ollama template.
|
||||
"""
|
||||
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
|
||||
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
|
||||
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
|
||||
assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
|
||||
return (
|
||||
f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
|
||||
f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
|
||||
f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
|
||||
)
|
||||
|
||||
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
|
||||
r"""
|
||||
Returns the ollama modelfile.
|
||||
|
||||
TODO: support function calling.
|
||||
"""
|
||||
modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
|
||||
|
||||
if self.default_system:
|
||||
modelfile += f'SYSTEM system "{self.default_system}"\n\n'
|
||||
|
||||
for stop_token_id in self.get_stop_token_ids(tokenizer):
|
||||
modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
|
||||
|
||||
modelfile += "PARAMETER num_ctx 4096\n"
|
||||
return modelfile
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
@@ -1020,7 +1076,7 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from chatml template
|
||||
# copied from minicpm_v template
|
||||
_register_template(
|
||||
name="minicpm_o",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
|
||||
Reference in New Issue
Block a user