fix mask_history tiny bug

Former-commit-id: cac07aac6196be026f723b2397a343d4fb675973
This commit is contained in:
“Wzw”
2024-08-08 10:09:33 +08:00
parent eada49e56b
commit d71446e387
2 changed files with 13 additions and 8 deletions

View File

@@ -69,12 +69,16 @@ class Template:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
mask_history: bool = False,
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
if not mask_history:
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
else:
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(len(encoded_messages)-2, -1, -2)]
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
r"""