[model] add qwen3 (#7885)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
 
@@ -60,7 +61,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
@@ -76,7 +77,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 
     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
@@ -110,12 +111,18 @@ class Template:
 
         return token_ids
 
+    def _remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
+        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.
 
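
For reference, a minimal standalone sketch of what the new _remove_thought helper does, assuming thought_words = ("<think>", "</think>") as emitted by Qwen3-style reasoning models (the names outside the diff are illustrative, not library code):

import re

thought_words = ("<think>", "</think>")  # assumed delimiters

def remove_thought(content: str) -> str:
    # Non-greedy match between the delimiters; re.DOTALL lets the span cross newlines.
    pattern = re.compile(f"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
    return re.sub(pattern, "", content).lstrip("\n")

print(remove_thought("<think>\nreasoning...\n</think>\n\nThe answer is 4."))  # -> "The answer is 4."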
@@ -133,14 +140,18 @@ class Template:
                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                     elements += self.format_system.apply(content=(system + tool_text))
 
-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=content, idx=str(i // 2))
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
 
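
The net effect of the new branch, as a self-contained sketch: thought spans are stripped from every assistant turn except the final one, so earlier reasoning does not inflate the encoded history while the last turn keeps its span (Role is a str-based Enum in this codebase, so plain strings stand in for it here):

import re

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "<think>greet back</think>\nHello!"},
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "<think>2 + 2 = 4</think>\n4"},
]
for i, message in enumerate(messages):
    # Strip thought spans from every assistant turn except the final one.
    if message["role"] == "assistant" and i != len(messages) - 1:
        message["content"] = re.sub(r"<think>.*?</think>", "", message["content"], flags=re.DOTALL).lstrip("\n")

print(messages[1]["content"])  # "Hello!" -- earlier reasoning removed
print(messages[3]["content"])  # final turn keeps its <think> span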
@@ -317,6 +328,7 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
+        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
@@ -330,14 +342,18 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]
 
-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=system_text + message["content"])
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=system_text + content)
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
 
@@ -476,6 +492,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
 
     if len(user_slot) > len(user_slot_empty_system):
         default_system = find_diff(user_slot_empty_system, user_slot)
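
Why parse_template needs the new line: chat templates of reasoning models can inject empty thought tags into the rendered assistant slot, which would break slot detection. A small illustration (the slot value is an assumed example, not taken from a real tokenizer):

assistant_slot = "<think>\n\n</think>\n\n{{content}}<|im_end|>\n"  # hypothetical slot from a reasoning template
assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")
print(repr(assistant_slot))  # '{{content}}<|im_end|>\n'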
@@ -1411,6 +1428,21 @@ register_template(
 )
 
 
+# copied from qwen template
+register_template(
+    name="qwen3",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+)
+
+
 # copied from chatml template
 register_template(
     name="qwen2_audio",
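
To see what these ChatML-style slots produce, a hand-rendered single turn (illustrative only; the real rendering goes through the formatter classes above):

system_slot = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
user_slot = "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\n"
assistant_slot = "4<|im_end|>\n"
prompt = system_slot + user_slot  # model input, ends with the open assistant header
target = assistant_slot          # completion, terminated by the <|im_end|> stop word
print(prompt + target)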
@@ -2403,6 +2403,69 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Qwen3-0.6B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-Base",
+        },
+        "Qwen3-1.7B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-Base",
+        },
+        "Qwen3-4B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Base",
+        },
+        "Qwen3-8B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-Base",
+        },
+        "Qwen3-14B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-Base",
+        },
+        "Qwen3-30B-A3B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
+        },
+        "Qwen3-0.6B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
+        },
+        "Qwen3-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
+        },
+        "Qwen3-4B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
+        },
+        "Qwen3-8B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
+        },
+        "Qwen3-14B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
+        },
+        "Qwen3-32B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
+        },
+        "Qwen3-30B-A3B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
+        },
+        "Qwen3-235B-A22B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
+        },
+    },
+    template="qwen3",
+)
+
+
 register_model_group(
     models={
         "Qwen2-Audio-7B": {
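
Each registry entry is a plain per-source mapping. A sketch of how such an entry could be resolved; DownloadSource and resolve_repo here are stand-ins for illustration, not the project's API:

from enum import Enum

class DownloadSource(str, Enum):  # stand-in for the real enum
    DEFAULT = "hf"
    MODELSCOPE = "ms"

qwen3_8b = {
    DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
    DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
}

def resolve_repo(entry: dict, use_modelscope: bool = False) -> str:
    # Prefer the mirror when requested, falling back to the default hub id.
    key = DownloadSource.MODELSCOPE if use_modelscope else DownloadSource.DEFAULT
    return entry.get(key, entry[DownloadSource.DEFAULT])

print(resolve_repo(qwen3_8b, use_modelscope=True))  # Qwen/Qwen3-8B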
@@ -56,11 +56,11 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
     Inputs: top.quantization_method
     Outputs: top.quantization_bit
     """
-    if quantization_method == QuantizationMethod.BNB.value:
+    if quantization_method == QuantizationMethod.BNB:
         available_bits = ["none", "8", "4"]
-    elif quantization_method == QuantizationMethod.HQQ.value:
+    elif quantization_method == QuantizationMethod.HQQ:
         available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
-    elif quantization_method == QuantizationMethod.EETQ.value:
+    elif quantization_method == QuantizationMethod.EETQ:
         available_bits = ["none", "8"]
 
     return gr.Dropdown(choices=available_bits)
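
Dropping .value is safe here if QuantizationMethod is a str-based Enum (assumed in this sketch to mirror the real class), since such members compare equal to their raw string values:

from enum import Enum

class QuantizationMethod(str, Enum):  # assumption: mirrors the real enum's str base
    BNB = "bnb"
    HQQ = "hqq"
    EETQ = "eetq"

assert QuantizationMethod.BNB == "bnb"   # enum member == plain string
assert "hqq" == QuantizationMethod.HQQ   # comparison works in either order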