[deps] update to transformers 4.52 (#8125)

This commit is contained in:
hoshi-hiyouga
2025-05-21 05:16:18 +08:00
committed by GitHub
parent 56926d76f9
commit 9ae17cd173
28 changed files with 365 additions and 109 deletions

View File

@@ -52,7 +52,7 @@ class Template:
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
enable_thinking: bool
enable_thinking: Optional[bool]
mm_plugin: "BasePlugin"
def encode_oneturn(
@@ -411,14 +411,17 @@ class ReasoningTemplate(Template):
for i in range(1, len(messages) - 2, 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"])
if self.enable_thinking is False: # remove all cot
messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
if (
self.thought_words[0] not in messages[-1]["content"]
and self.thought_words[1] not in messages[-1]["content"]
):
if not self.enable_thinking:
prompt_ids = prompt_ids + self.get_thought_word_ids(tokenizer)
else:
): # add empty cot
if not self.enable_thinking: # do not compute loss
prompt_ids += self.get_thought_word_ids(tokenizer)
else: # do compute loss
response_ids = self.get_thought_word_ids(tokenizer) + response_ids
return prompt_ids, response_ids
@@ -431,15 +434,20 @@ class ReasoningTemplate(Template):
system: Optional[str] = None,
tools: Optional[str] = None,
) -> list[tuple[list[int], list[int]]]:
messages = deepcopy(messages)
if self.enable_thinking is False: # remove all cot
for i in range(1, len(messages), 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"])
encoded_messages = self._encode(tokenizer, messages, system, tools)
for i in range(0, len(messages), 2):
if (
self.thought_words[0] not in messages[i + 1]["content"]
and self.thought_words[1] not in messages[i + 1]["content"]
):
if not self.enable_thinking:
): # add empty cot
if not self.enable_thinking: # do not compute loss
encoded_messages[i] += self.get_thought_word_ids(tokenizer)
else:
else: # do compute loss
encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
@@ -463,7 +471,7 @@ def register_template(
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
enable_thinking: bool = True,
enable_thinking: Optional[bool] = True,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: type["Template"] = Template,
) -> None: