mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 08:13:38 +00:00
[v1] add batch generator (#9744)
This commit is contained in:
@@ -32,7 +32,8 @@ class AlpacaSample(TypedDict, total=False):
|
||||
|
||||
|
||||
SharegptMessage = TypedDict(
|
||||
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
|
||||
"SharegptMessage",
|
||||
{"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
|
||||
)
|
||||
|
||||
|
||||
@@ -118,15 +119,8 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"observation": "tool",
|
||||
"function_call": "assistant",
|
||||
}
|
||||
sample = {}
|
||||
messages = []
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
tools = []
|
||||
|
||||
for message in raw_sample.get("conversations", []):
|
||||
tag = message["from"]
|
||||
if tag not in tag_mapping:
|
||||
@@ -157,10 +151,17 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
}
|
||||
)
|
||||
|
||||
sample["messages"] = messages
|
||||
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
return {"messages": messages, "tools": json.dumps(tools)}
|
||||
else:
|
||||
return {"messages": messages}
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
sample["tools"] = json.dumps(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
@DataConverterPlugin("pair").register()
|
||||
@@ -179,17 +180,44 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
def process_message(raw_messages: list[OpenaiMessage]):
|
||||
messages = []
|
||||
for message in raw_messages:
|
||||
messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": [{"type": "text", "value": message["content"]}],
|
||||
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||
}
|
||||
)
|
||||
if message["role"] == "tool":
|
||||
try:
|
||||
tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
|
||||
continue
|
||||
|
||||
if not isinstance(tool_calls, list):
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
|
||||
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||
}
|
||||
)
|
||||
else:
|
||||
messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": [{"type": "text", "value": message["content"]}],
|
||||
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
chosen_messages = process_message(raw_sample.get("chosen", []))
|
||||
rejected_messages = process_message(raw_sample.get("rejected", []))
|
||||
sample = {}
|
||||
sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
|
||||
sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))
|
||||
|
||||
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
sample["tools"] = json.dumps(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
|
||||
return sample
|
||||
|
||||
Reference in New Issue
Block a user