modify style

Former-commit-id: 54b713d0c4ffdfc6a7faeb14471b58bb1cd8acf5
This commit is contained in:
BUAADreamer
2024-04-25 21:15:16 +08:00
parent 266fe908e3
commit c425436676
16 changed files with 374 additions and 502 deletions

View File

@@ -42,9 +42,7 @@ class Template:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_pairs = self._encode(
tokenizer, messages, system, tools, cutoff_len, reserved_label_len
)
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids += query_ids + resp_ids
@@ -64,9 +62,7 @@ class Template:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
return self._encode(
tokenizer, messages, system, tools, cutoff_len, reserved_label_len
)
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def _encode(
self,
@@ -93,9 +89,7 @@ class Template:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(
content=message["content"], idx=str(i // 2)
)
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
@@ -130,11 +124,7 @@ class Template:
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError(
"Input must be string, set[str] or dict[str, str], got {}".format(
type(elem)
)
)
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return token_ids
@@ -192,9 +182,7 @@ class Llama2Template(Template):
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(
content=system_text + message["content"]
)
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
@@ -257,9 +245,7 @@ def _register_template(
template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(
slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots
)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
@@ -295,9 +281,7 @@ def _jinja_escape(content: str) -> str:
return content.replace("\n", r"\n").replace("'", r"\'")
def _convert_slots_to_jinja(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
slot_items = []
for slot in slots:
if isinstance(slot, str):
@@ -311,9 +295,7 @@ def _convert_slots_to_jinja(
elif isinstance(slot, set):
if "bos_token" in slot:
slot_items.append("'" + tokenizer.bos_token + "'")
elif (
"eos_token" in slot
): # do not use {{ eos_token }} since it may be replaced
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
@@ -325,37 +307,25 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template = ""
if template.default_system:
jinja_template += (
"{% set system_message = '"
+ _jinja_escape(template.default_system)
+ "' %}"
)
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}"
"{% set system_message = messages[0]['content'] %}"
"{% endif %}"
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
)
system_message = _convert_slots_to_jinja(
template.format_system.apply(), tokenizer, placeholder="system_message"
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
if isinstance(template, Llama2Template):
pass
elif template.force_system:
jinja_template += "{{ " + system_message + " }}"
else:
jinja_template += (
"{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
)
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in messages %}"
jinja_template += "{% set content = message['content'] %}"
if isinstance(template, Llama2Template):
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += (
"{% set content = " + system_message + " + message['content'] %}"
)
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
@@ -403,9 +373,7 @@ def get_template_and_fix_tokenizer(
)
logger.info("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning(
"New tokens have been added, make sure `resize_vocab` is True."
)
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
@@ -417,9 +385,7 @@ def get_template_and_fix_tokenizer(
_register_template(
name="alpaca",
format_user=StringFormatter(
slots=["### Instruction:\n{{content}}\n\n### Response:\n"]
),
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. "
@@ -458,9 +424,7 @@ _register_template(
_register_template(
name="baichuan",
format_user=StringFormatter(
slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]
),
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
@@ -483,9 +447,7 @@ _register_template(
_register_template(
name="bluelm",
format_user=StringFormatter(
slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]
),
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
@@ -504,9 +466,7 @@ _register_template(
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
force_system=True,
@@ -515,13 +475,9 @@ _register_template(
_register_template(
name="chatglm3",
format_user=StringFormatter(
slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(
slots=[
@@ -539,9 +495,7 @@ _register_template(
_register_template(
name="chatglm3_system",
format_user=StringFormatter(
slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(
slots=[
@@ -572,15 +526,9 @@ _register_template(
_register_template(
name="chatml",
format_user=StringFormatter(
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_system=StringFormatter(
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_observation=StringFormatter(
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
@@ -589,15 +537,9 @@ _register_template(
_register_template(
name="chatml_de",
format_user=StringFormatter(
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_system=StringFormatter(
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_observation=StringFormatter(
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
@@ -607,9 +549,7 @@ _register_template(
_register_template(
name="codegeex2",
format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
)
@@ -639,15 +579,9 @@ _register_template(
_register_template(
name="dbrx",
format_user=StringFormatter(
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_system=StringFormatter(
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_observation=StringFormatter(
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
@@ -725,9 +659,7 @@ _register_template(
_register_template(
name="gemma",
format_user=StringFormatter(
slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
@@ -740,9 +672,7 @@ _register_template(
_register_template(
name="intern",
format_user=StringFormatter(
slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]
),
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
stop_words=["<eoa>"],
efficient_eos=True,
@@ -751,12 +681,8 @@ _register_template(
_register_template(
name="intern2",
format_user=StringFormatter(
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
@@ -859,9 +785,7 @@ _register_template(
_register_template(
name="orion",
format_user=StringFormatter(
slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]
),
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
@@ -869,15 +793,9 @@ _register_template(
_register_template(
name="phi",
format_user=StringFormatter(
slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]
),
format_observation=StringFormatter(
slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]
),
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.",
stop_words=["<|end|>"],
@@ -887,15 +805,9 @@ _register_template(
_register_template(
name="qwen",
format_user=StringFormatter(
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_system=StringFormatter(
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_observation=StringFormatter(
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
@@ -951,12 +863,8 @@ _register_template(
_register_template(
name="yayi",
format_user=StringFormatter(
slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]
),
format_system=StringFormatter(
slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]
),
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi "
@@ -975,9 +883,7 @@ _register_template(
_register_template(
name="yi",
format_user=StringFormatter(
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
@@ -995,9 +901,7 @@ _register_template(
_register_template(
name="zephyr",
format_user=StringFormatter(
slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]
),
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate",