mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 08:13:38 +00:00
224 lines
6.9 KiB
Python
224 lines
6.9 KiB
Python
# Copyright 2025 the LlamaFactory team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import json
|
|
from typing import Any, Literal, NotRequired, TypedDict
|
|
|
|
from ...utils import logging
|
|
from ...utils.plugin import BasePlugin
|
|
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
|
|
|
|
|
|
# Module-level logger; the converters below call its ``warning_rank0`` method to
# report malformed records (unknown role tags, invalid JSON) instead of raising.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class AlpacaSample(TypedDict, total=False):
|
|
system: NotRequired[str]
|
|
instruction: str
|
|
input: NotRequired[str]
|
|
output: str
|
|
|
|
|
|
SharegptMessage = TypedDict(
|
|
"SharegptMessage",
|
|
{"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
|
|
)
|
|
|
|
|
|
class SharegptSample(TypedDict, total=False):
    """Raw record in the ShareGPT conversation layout.

    Declared with ``total=False``, so both keys may be absent.
    """

    # Ordered turns of the conversation.
    conversations: list[SharegptMessage]
    # Optional tool definitions, stored as a JSON-encoded string.
    tools: NotRequired[str]
|
|
|
|
|
|
class OpenaiMessage(TypedDict, total=False):
|
|
role: Literal["user", "assistant", "tool"]
|
|
content: str
|
|
|
|
|
|
class OpenaiSample(TypedDict, total=False):
    """Raw record holding a conversation as a list of OpenAI-style messages.

    Declared with ``total=False``, so the key may be absent.
    """

    # Ordered messages of the conversation.
    messages: list[OpenaiMessage]
|
|
|
|
|
|
class PairSample(TypedDict, total=False):
    """Raw preference-pair record: one preferred and one rejected conversation.

    Declared with ``total=False``, so every key may be absent.
    """

    # Messages of the preferred conversation.
    chosen: list[OpenaiMessage]
    # Messages of the rejected conversation.
    rejected: list[OpenaiMessage]
    # Optional tool definitions as a JSON-encoded string. The pair converter
    # reads this key with ``raw_sample.get("tools")``, so it is declared here
    # to keep the schema in line with what the converter consumes (and with
    # ``SharegptSample``, which declares the same field).
    tools: NotRequired[str]
|
|
|
|
|
|
class DataConverterPlugin(BasePlugin):
    """Plugin for data converters.

    Converter functions in this module are attached through
    ``DataConverterPlugin("<name>").register()`` used as a decorator.
    """

    def __call__(self, raw_sample: dict[str, Any]) -> Sample:
        """Convert one raw dataset record by delegating to ``BasePlugin.__call__``.

        Args:
            raw_sample (dict[str, Any]): Raw record to convert.

        Returns:
            Sample: Whatever the base-class call returns; presumably the output
                of the converter registered under this plugin's name (confirm
                against ``BasePlugin``).
        """
        converted: Sample = super().__call__(raw_sample)
        return converted
|
|
|
|
|
|
@DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
    """Turn an Alpaca-style record into an SFT sample.

    Up to three messages are produced, in this order: a ``system`` message
    (when the ``system`` key is present), a ``user`` message (when either
    ``instruction`` or ``input`` is present; the two strings are concatenated
    directly, without a separator) and an ``assistant`` message (when
    ``output`` is present). Only the assistant message carries a non-zero
    loss weight.

    See raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en

    Args:
        raw_sample (AlpacaSample): Alpaca sample.

    Returns:
        SFTSample: SFT sample with a single ``messages`` key.
    """

    def _text_turn(role: str, text: str, weight: float) -> dict[str, Any]:
        # Every message in this converter is a single text part.
        return {"role": role, "content": [{"type": "text", "value": text}], "loss_weight": weight}

    turns = []

    if "system" in raw_sample:
        turns.append(_text_turn("system", raw_sample["system"], 0.0))

    has_prompt = "instruction" in raw_sample or "input" in raw_sample
    if has_prompt:
        prompt = raw_sample.get("instruction", "") + raw_sample.get("input", "")
        turns.append(_text_turn("user", prompt, 0.0))

    if "output" in raw_sample:
        turns.append(_text_turn("assistant", raw_sample["output"], 1.0))

    return {"messages": turns}
|
|
|
|
|
|
@DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
    """Turn a ShareGPT-style record into an SFT sample.

    Each entry of ``conversations`` is mapped by its ``from`` tag:

    * ``system`` / ``human`` / ``gpt`` / ``observation`` become text messages
      with roles ``system`` / ``user`` / ``assistant`` / ``tool``; only ``gpt``
      turns get a loss weight of 1.0.
    * ``function_call`` turns have their ``value`` parsed as JSON (a single
      call or a list of calls) and become one assistant message with one
      ``tool_call`` part per call and a loss weight of 1.0.
    * Turns with an unknown tag or an undecodable tool call are skipped with
      a warning.

    A truthy ``tools`` field is JSON-decoded and re-encoded into the result;
    if it is not valid JSON a warning is logged and ``tools`` is omitted.

    See raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en

    Args:
        raw_sample (SharegptSample): ShareGPT sample.

    Returns:
        SFTSample: SFT sample.
    """
    role_by_tag = {
        "system": "system",
        "human": "user",
        "gpt": "assistant",
        "observation": "tool",
        "function_call": "assistant",
    }

    turns = []
    for turn in raw_sample.get("conversations", []):
        tag = turn["from"]

        if tag not in role_by_tag:
            logger.warning_rank0(f"Unsupported role tag {tag} in message: {turn}")
            continue

        if tag != "function_call":
            # Plain text turn; only model ("gpt") replies contribute to the loss.
            turns.append(
                {
                    "role": role_by_tag[tag],
                    "content": [{"type": "text", "value": turn["value"]}],
                    "loss_weight": 1.0 if tag == "gpt" else 0.0,
                }
            )
            continue

        try:
            parsed: ToolCall | list[ToolCall] = json.loads(turn["value"])
        except json.JSONDecodeError:
            logger.warning_rank0(f"Invalid tool call format: {str(turn['value'])}")
            continue

        # A single call is normalised to a one-element list.
        calls = parsed if isinstance(parsed, list) else [parsed]
        turns.append(
            {
                "role": "assistant",
                "content": [{"type": "tool_call", "value": json.dumps(call)} for call in calls],
                "loss_weight": 1.0,
            }
        )

    converted = {"messages": turns}

    raw_tools = raw_sample.get("tools")
    if raw_tools:
        try:
            # Decode then re-encode so the stored string is normalised JSON.
            converted["tools"] = json.dumps(json.loads(raw_tools))
        except json.JSONDecodeError:
            logger.warning_rank0(f"Invalid tools format: {str(raw_tools)}")

    return converted
|
|
|
|
|
|
@DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample:
    """Turn a preference-pair record into a DPO sample.

    The ``chosen`` and ``rejected`` message lists are converted independently
    with the same rules:

    * A message whose role is ``tool`` has its ``content`` parsed as JSON (a
      single call or a list of calls) and becomes a message with one
      ``tool_call`` part per call; undecodable content is skipped with a
      warning.
    * Any other message becomes a single text part.
    * Only ``assistant`` messages get a loss weight of 1.0.

    A truthy ``tools`` field is JSON-decoded and re-encoded into the result;
    if it is not valid JSON a warning is logged and ``tools`` is omitted.

    See raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs

    Args:
        raw_sample (PairSample): pair sample with chosen, rejected fields.

    Returns:
        DPOSample: DPO sample with chosen_messages and rejected_messages.
    """

    def _to_messages(raw_messages: list[OpenaiMessage]):
        converted = []
        for entry in raw_messages:
            role = entry["role"]

            if role != "tool":
                converted.append(
                    {
                        "role": role,
                        "content": [{"type": "text", "value": entry["content"]}],
                        "loss_weight": 1.0 if role == "assistant" else 0.0,
                    }
                )
                continue

            try:
                parsed: ToolCall | list[ToolCall] = json.loads(entry["content"])
            except json.JSONDecodeError:
                logger.warning_rank0(f"Invalid tool call format: {str(entry['content'])}")
                continue

            # A single call is normalised to a one-element list.
            calls = parsed if isinstance(parsed, list) else [parsed]
            converted.append(
                {
                    "role": role,
                    "content": [{"type": "tool_call", "value": json.dumps(call)} for call in calls],
                    # The role is "tool" on this path, never "assistant", so no loss.
                    "loss_weight": 0.0,
                }
            )

        return converted

    result = {
        "chosen_messages": _to_messages(raw_sample.get("chosen", [])),
        "rejected_messages": _to_messages(raw_sample.get("rejected", [])),
    }

    raw_tools = raw_sample.get("tools")
    if raw_tools:
        try:
            # Decode then re-encode so the stored string is normalised JSON.
            result["tools"] = json.dumps(json.loads(raw_tools))
        except json.JSONDecodeError:
            logger.warning_rank0(f"Invalid tools format: {str(raw_tools)}")

    return result
|