[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -13,11 +13,12 @@
# limitations under the License.
import json
from typing import Any, Literal, NotRequired, TypedDict
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
logger = logging.get_logger(__name__)
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register
@DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample.
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages}
@DataConverterPlugin("sharegpt").register
@DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample.
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"function_call": "assistant",
}
messages = []
tools = raw_sample.get("tools", "")
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
tools = []
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call":
try:
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
continue
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append(
{
"role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}],
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0,
}
)
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
)
if tools:
if messages and messages[0]["role"] == "system":
messages[0]["content"].append({"type": "tools", "value": tools})
else:
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
return {"messages": messages}
return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
else:
return {"messages": messages}
@DataConverterPlugin("pair").register
@DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to DPO sample.