[model] add qwen3 (#7885)

This commit is contained in:
hoshi-hiyouga
2025-04-29 09:34:05 +08:00
committed by GitHub
parent db9559456c
commit 98f23c6584
7 changed files with 171 additions and 33 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
@@ -60,7 +61,7 @@ class Template:
tools: Optional[str] = None,
) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
prompt_ids += encoded_ids
@@ -76,7 +77,7 @@ class Template:
tools: Optional[str] = None,
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
@@ -110,12 +111,18 @@ class Template:
return token_ids
def _remove_thought(self, content: str) -> str:
r"""Remove thought from assistant message."""
pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
return re.sub(pattern, "", content).lstrip("\n")
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
remove_thought: bool,
) -> list[list[int]]:
r"""Encode formatted inputs to pairs of token ids.
@@ -133,14 +140,18 @@ class Template:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
content = message["content"]
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
content = self._remove_thought(content)
if message["role"] == Role.USER:
elements += self.format_user.apply(content=content, idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=content)
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=content)
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=content)
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -317,6 +328,7 @@ class Llama2Template(Template):
messages: list[dict[str, str]],
system: str,
tools: str,
remove_thought: bool,
) -> list[list[int]]:
system = system or self.default_system
encoded_messages = []
@@ -330,14 +342,18 @@ class Llama2Template(Template):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
content = message["content"]
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
content = self._remove_thought(content)
if message["role"] == Role.USER:
elements += self.format_user.apply(content=system_text + content)
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=content)
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=content)
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=content)
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -476,6 +492,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n") # remove thought tags
if len(user_slot) > len(user_slot_empty_system):
default_system = find_diff(user_slot_empty_system, user_slot)
@@ -1411,6 +1428,21 @@ register_template(
)
# copied from qwen template
register_template(
name="qwen3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
)
# copied from chatml template
register_template(
name="qwen2_audio",