add tool test
Former-commit-id: 639a355a9ceb2e4585b81aea71fc810f4b510776
This commit is contained in:
@@ -33,13 +33,13 @@ class Template:
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tool: str,
|
||||
cutoff_len: int
|
||||
tools: str,
|
||||
cutoff_len: Optional[int] = 1_000_000
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
"""
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
|
||||
prompt_ids = []
|
||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||
prompt_ids = prompt_ids + query_ids + resp_ids
|
||||
@@ -52,13 +52,13 @@ class Template:
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tool: str,
|
||||
cutoff_len: int
|
||||
tools: str,
|
||||
cutoff_len: Optional[int] = 1_000_000
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
"""
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
|
||||
return encoded_pairs
|
||||
|
||||
def _encode(
|
||||
@@ -66,7 +66,7 @@ class Template:
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tool: str,
|
||||
tools: str,
|
||||
cutoff_len: int
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
@@ -78,8 +78,8 @@ class Template:
|
||||
encoded_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
elements = []
|
||||
if i == 0 and (system or tool):
|
||||
tool_text = self.format_tool(content=tool)[0] if tool else ""
|
||||
if i == 0 and (system or tools):
|
||||
tool_text = self.format_tool(content=tools)[0] if tools else ""
|
||||
elements += self.format_system(content=(system + tool_text))
|
||||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.separator
|
||||
@@ -131,7 +131,7 @@ class Llama2Template(Template):
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tool: str,
|
||||
tools: str,
|
||||
cutoff_len: int
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
@@ -144,8 +144,8 @@ class Llama2Template(Template):
|
||||
for i, message in enumerate(messages):
|
||||
elements = []
|
||||
system_text = ""
|
||||
if i == 0 and (system or tool):
|
||||
tool_text = self.format_tool(content=tool)[0] if tool else ""
|
||||
if i == 0 and (system or tools):
|
||||
tool_text = self.format_tool(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system(content=(system + tool_text))[0]
|
||||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.separator
|
||||
|
||||
Reference in New Issue
Block a user