mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-01-31 06:42:05 +00:00
[v1&WIP] dataloader init (#9645)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -22,5 +23,112 @@ class Template:
|
||||
assistant_template: str
|
||||
system_template: str
|
||||
|
||||
def render_message(self, message: dict[str, str]) -> str:
    """Fill the user template with the fields of *message*."""
    template = self.user_template
    return template.format(**message)
|
||||
|
||||
|
||||
@dataclass
class QwenTemplate:
    """Chat template for Qwen-style models (``<|im_start|>``/``<|im_end|>`` markup)."""

    message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n"  # FIXME if role: tool
    thinking_template: str = "<think>\n{content}\n</think>\n\n"

    def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str:
        """Flatten message content into plain text.

        Accepts either a raw string or an OpenAI-style list of content parts.
        Text parts are joined with newlines; ``image_url`` parts are skipped
        for now (multimodal support is not implemented). Any other input type
        yields an empty string.
        """
        if isinstance(content_data, str):
            return content_data.strip()

        if isinstance(content_data, list):
            parts = []
            for item in content_data:
                if item.get("type") == "text":
                    # BUG FIX: the callers in this file use the conventional
                    # "text" key, not "value"; keep "value" as a fallback.
                    parts.append(item.get("text", item.get("value", "")))
                elif item.get("type") == "image_url":
                    pass  # TODO: multimodal content is not supported yet.
            return "\n".join(parts).strip()

        return ""

    def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str:
        """Render one chat message into Qwen chat markup.

        For assistant turns, an optional ``reasoning_content`` field is
        stripped, wrapped in the thinking template, and prepended to the
        visible content.
        """
        role = message["role"]
        content = self._extract_content(message.get("content", ""))

        if role == "assistant":
            reasoning_content = message.get("reasoning_content", "")
            if reasoning_content:
                reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
            return self.message_template.format(role="assistant", content=reasoning_content + content)
        else:
            return self.message_template.format(role=role, content=content)

    def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> dict[str, list[int]]:
        """Encode a conversation into model inputs.

        Each message is rendered and tokenized; labels carry the token ids of
        loss-bearing messages (assistant turns by default, or any message with
        an explicit ``loss_weight == 1``) and ``-100`` elsewhere. All fields
        are truncated from the left to ``max_seq_len``.

        Returns a dict with ``input_ids``, ``attention_mask``, ``labels`` and
        ``position_ids``, all equal-length lists of ints.
        """
        input_ids, attention_mask, labels = [], [], []
        for message in messages:
            content_str = self.render_message(message)
            content_ids = tokenizer.encode(content_str, add_special_tokens=False)
            input_ids += content_ids
            attention_mask += [1] * len(content_ids)

            # BUG FIX: original used hasattr() on a dict, which never matches
            # a key, so an explicit per-message loss_weight was ignored.
            if "loss_weight" in message:
                loss_weight = message["loss_weight"]
            else:
                loss_weight = 1 if message["role"] == "assistant" else 0
            if loss_weight == 1:
                labels += content_ids
            else:
                labels += [-100] * len(content_ids)
        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        model_inputs.update({"position_ids": list(range(len(input_ids)))})
        # Left-truncate every field so they stay aligned with each other.
        model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
        return model_inputs
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
        """Flatten raw messages into the plain-text form HF's chat template expects."""
        converted = []
        for msg in messages:
            role = msg["role"]
            text = template._extract_content(msg.get("content", ""))
            if role == "assistant":
                thought = (msg.get("reasoning_content") or "").strip()
                if thought:
                    text = template.thinking_template.format(content=thought) + text
            converted.append({"role": role, "content": text})
        return converted

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        trust_remote_code=True,
    )

    test_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [{"type": "text", "text": "1+1等于几?"}, {"type": "text", "text": "2+2等于几?"}],
        },
        {
            "role": "assistant",
            "reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
            "content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
        },
    ]

    template = QwenTemplate()
    # Render with our template, then compare against HF's reference rendering.
    rendered_custom = "".join(template.render_message(m) for m in test_messages)

    qwen3_messages = to_qwen3_messages(template, test_messages)
    rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)

    print("==== custom ====")
    print(rendered_custom)
    print("==== hf ====")
    print(rendered_hf)

    assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"

    # Also compare at the token-id level for an exact match.
    ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
    ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
    assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
|
||||
|
||||
Reference in New Issue
Block a user