Merge branch 'hiyouga:main' into main
Former-commit-id: 0395d0aafb69e86645e6b0a36b8f8dadb82219e0
This commit is contained in:
@@ -37,7 +37,7 @@ class Template:
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
reserved_label_len: Optional[int] = 16,
|
||||
reserved_label_len: Optional[int] = 1,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
@@ -57,7 +57,7 @@ class Template:
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
reserved_label_len: Optional[int] = 16,
|
||||
reserved_label_len: Optional[int] = 1,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
@@ -144,10 +144,10 @@ class Template:
|
||||
max_len=(cutoff_len - total_length),
|
||||
reserved_label_len=reserved_label_len,
|
||||
)
|
||||
encoded_messages[i] = encoded_messages[i][:max_source_len]
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
|
||||
total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||
source_ids = encoded_messages[i][:max_source_len]
|
||||
target_ids = encoded_messages[i + 1][:max_target_len]
|
||||
total_length += len(source_ids) + len(target_ids)
|
||||
encoded_pairs.append((source_ids, target_ids))
|
||||
|
||||
return encoded_pairs
|
||||
|
||||
@@ -218,7 +218,7 @@ def register_template(
|
||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||
default_tool_formatter = ToolFormatter(slots="default")
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
templates[name] = template_class(
|
||||
format_user=format_user or default_user_formatter,
|
||||
@@ -356,6 +356,14 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="cpm",
|
||||
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseek",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
||||
@@ -464,7 +472,7 @@ register_template(
|
||||
|
||||
register_template(
|
||||
name="orion",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: </s>"]),
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user