mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 08:13:38 +00:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -13,11 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
|
||||
from ...utils import logging
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import DPOSample, Sample, SFTSample
|
||||
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
|
||||
return super().__call__(raw_sample)
|
||||
|
||||
|
||||
@DataConverterPlugin("alpaca").register
|
||||
@DataConverterPlugin("alpaca").register()
|
||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
"""Convert Alpaca sample to SFT sample.
|
||||
|
||||
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@DataConverterPlugin("sharegpt").register
|
||||
@DataConverterPlugin("sharegpt").register()
|
||||
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"""Convert ShareGPT sample to SFT sample.
|
||||
|
||||
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"function_call": "assistant",
|
||||
}
|
||||
messages = []
|
||||
tools = raw_sample.get("tools", "")
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
tools = []
|
||||
|
||||
for message in raw_sample.get("conversations", []):
|
||||
tag = message["from"]
|
||||
if tag not in tag_mapping:
|
||||
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
|
||||
elif tag == "function_call":
|
||||
try:
|
||||
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
|
||||
continue
|
||||
|
||||
if not isinstance(tool_calls, list):
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_calls", "value": message["value"]}],
|
||||
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
)
|
||||
|
||||
if tools:
|
||||
if messages and messages[0]["role"] == "system":
|
||||
messages[0]["content"].append({"type": "tools", "value": tools})
|
||||
else:
|
||||
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
|
||||
|
||||
return {"messages": messages}
|
||||
return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
|
||||
else:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@DataConverterPlugin("pair").register
|
||||
@DataConverterPlugin("pair").register()
|
||||
def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
"""Convert Pair sample to DPO sample.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user