add tool test

Former-commit-id: 639a355a9ceb2e4585b81aea71fc810f4b510776
Author: hiyouga
Date: 2024-01-18 10:26:26 +08:00
Parent: a423274fd9
Commit: d8affd3967
9 changed files with 63 additions and 37 deletions


@@ -1,6 +1,6 @@
 from .loader import get_dataset
 from .template import get_template_and_fix_tokenizer, templates
-from .utils import split_dataset
+from .utils import split_dataset, Role
 
-__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset"]
+__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]
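Re-exporting Role lets callers build role-tagged messages straight from the data package. A minimal sketch, assuming the package resolves as llmtuner.data and that Role is a string-valued enum with USER and ASSISTANT members (the enum body is not shown in this diff):

from llmtuner.data import Role

# Hypothetical two-turn conversation built with the re-exported enum.
messages = [
    {"role": Role.USER, "content": "What is the weather in Paris?"},
    {"role": Role.ASSISTANT, "content": "Let me check the forecast."},
]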


@@ -97,7 +97,7 @@ def preprocess_packed_supervised_dataset(
         messages = examples["prompt"][i] + examples["response"][i]
         for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tool"][i], 1_000_000
+            tokenizer, messages, examples["system"][i], examples["tool"][i]
         )):
             if data_args.train_on_prompt:
                 source_mask = source_ids
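With the default now living on encode_multiturn, the packing path no longer pins the cutoff at the call site. A sketch of the new call shape, where template, tokenizer, examples and i stand in for the objects the preprocessing loop already holds:

pairs = template.encode_multiturn(
    tokenizer, messages, examples["system"][i], examples["tool"][i]
)  # cutoff_len is omitted, so the template's 1_000_000 default applies
for turn_idx, (source_ids, target_ids) in enumerate(pairs):
    pass  # masking logic as in the hunk above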


@@ -33,13 +33,13 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
-        cutoff_len: int
+        tools: str,
+        cutoff_len: Optional[int] = 1_000_000
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
+        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
         prompt_ids = []
         for query_ids, resp_ids in encoded_pairs[:-1]:
             prompt_ids = prompt_ids + query_ids + resp_ids
@@ -52,13 +52,13 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
-        cutoff_len: int
+        tools: str,
+        cutoff_len: Optional[int] = 1_000_000
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
+        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
         return encoded_pairs
 
     def _encode(
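Both public entry points now take a tools string and an optional cutoff. A hedged usage sketch; the template instance would come from get_template_and_fix_tokenizer, and the tool schema string below is purely illustrative:

prompt_ids, answer_ids = template.encode_oneturn(
    tokenizer,
    messages,
    system="You are a helpful assistant.",
    tools='[{"name": "get_weather", "description": "Query the forecast."}]',
)
# cutoff_len was omitted, so the new 1_000_000 default applies.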
@@ -66,7 +66,7 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
+        tools: str,
         cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -78,8 +78,8 @@ class Template:
         encoded_messages = []
         for i, message in enumerate(messages):
             elements = []
-            if i == 0 and (system or tool):
-                tool_text = self.format_tool(content=tool)[0] if tool else ""
+            if i == 0 and (system or tools):
+                tool_text = self.format_tool(content=tools)[0] if tools else ""
                 elements += self.format_system(content=(system + tool_text))
             elif i > 0 and i % 2 == 0:
                 elements += self.separator
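The first turn folds the rendered tool schema into the system element. A minimal standalone sketch of that composition, assuming format_tool and format_system each return a list whose first entry is a string (first_turn_elements is a hypothetical helper, not part of the class):

def first_turn_elements(template, system: str, tools: str) -> list:
    # Render the tool schema only when a tools string was supplied.
    tool_text = template.format_tool(content=tools)[0] if tools else ""
    # The system slot carries the system prompt plus tool text as one string.
    return template.format_system(content=(system + tool_text))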
@@ -131,7 +131,7 @@ class Llama2Template(Template):
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
+        tools: str,
         cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -144,8 +144,8 @@ class Llama2Template(Template):
         for i, message in enumerate(messages):
             elements = []
             system_text = ""
-            if i == 0 and (system or tool):
-                tool_text = self.format_tool(content=tool)[0] if tool else ""
+            if i == 0 and (system or tools):
+                tool_text = self.format_tool(content=tools)[0] if tools else ""
                 system_text = self.format_system(content=(system + tool_text))[0]
             elif i > 0 and i % 2 == 0:
                 elements += self.separator
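Unlike the base class, the Llama2 variant keeps the combined text in system_text and later splices it into the first user turn instead of emitting a standalone system element. An illustrative sketch, assuming a Llama-2 style <<SYS>> wrapper in the template's system formatter (not shown in this hunk):

# system_text, as returned by format_system, is later prepended to the
# first user message; for Llama-2 style templates it would look like:
#     "<<SYS>>\n{system + tool_text}\n<</SYS>>\n\n"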