[v1&WIP] dataloader init (#9645)

This commit is contained in:
Kingsley
2025-12-23 16:29:47 +08:00
committed by GitHub
parent 7901b2f32e
commit 1c8a42d2f8
7 changed files with 981 additions and 18 deletions

View File

@@ -14,6 +14,7 @@
from dataclasses import dataclass
from typing import Union
@dataclass
@@ -22,5 +23,112 @@ class Template:
assistant_template: str
system_template: str
def render_message(self, message: "dict[str, str]") -> str:
def render_message(self, message: dict[str, str]) -> str:
return self.user_template.format(**message)
@dataclass
class QwenTemplate:
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
thinking_template: str = "<think>\n{content}\n</think>\n\n"
def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str:
if isinstance(content_data, str):
return content_data.strip()
if isinstance(content_data, list):
parts = []
for item in content_data:
if item.get("type") == "text":
parts.append(item.get("value", ""))
elif item.get("type") == "image_url":
pass
return "\n".join(parts).strip()
return ""
def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str:
role = message["role"]
content = self._extract_content(message.get("content", ""))
if role == "assistant":
reasoning_content = message.get("reasoning_content", "")
if reasoning_content:
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
return self.message_template.format(role="assistant", content=reasoning_content + content)
else:
return self.message_template.format(role=role, content=content)
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
"""Encode one message."""
input_ids, attention_mask, labels = [], [], []
for message in messages:
content_str = self.render_message(message)
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
input_ids += content_ids
attention_mask += [1] * len(content_ids)
if hasattr(message, "loss_weight"):
loss_weight = message["loss_weight"]
else:
loss_weight = 1 if message["role"] == "assistant" else 0
if loss_weight == 1:
labels += content_ids
else:
labels += [-100] * len(content_ids)
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
model_inputs.update({"position_ids": list(range(len(input_ids)))})
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
return model_inputs
if __name__ == "__main__":
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
out = []
for m in messages:
role = m["role"]
content = template._extract_content(m.get("content", ""))
if role == "assistant":
reasoning = (m.get("reasoning_content") or "").strip()
if reasoning:
content = template.thinking_template.format(content=reasoning) + content
out.append({"role": role, "content": content})
return out
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-30B-A3B-Thinking-2507",
trust_remote_code=True,
)
test_messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [{"type": "text", "text": "1+1等于几"}, {"type": "text", "text": "2+2等于几"}],
},
{
"role": "assistant",
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
},
]
template = QwenTemplate()
rendered_custom = "".join([template.render_message(m) for m in test_messages])
qwen3_messages = to_qwen3_messages(template, test_messages)
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
print("==== custom ====")
print(rendered_custom)
print("==== hf ====")
print(rendered_hf)
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"