support function calling
Former-commit-id: 66533b3f65babf2429c92c0f8fafe4eff5e0ff63
@@ -1,40 +1,7 @@
import os
import json
from typing import List, Literal, Optional
from dataclasses import dataclass, field


DATA_CONFIG = "dataset_info.json"


def use_modelscope() -> bool:
    # Any non-zero integer string in USE_MODELSCOPE_HUB switches dataset downloads to the ModelScope hub.
    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))

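As a usage note (not part of the diff itself): hub selection is driven entirely by the USE_MODELSCOPE_HUB environment variable read above, so toggling it at runtime is enough — a minimal sketch, assuming the definitions in this file:

    # Illustrative only: exercise use_modelscope() by flipping the environment variable.
    os.environ["USE_MODELSCOPE_HUB"] = "1"
    assert use_modelscope() is True

    os.environ["USE_MODELSCOPE_HUB"] = "0"   # the default, i.e. use the Hugging Face hub
    assert use_modelscope() is False
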
@dataclass
class DatasetAttr:

    # where the dataset comes from and how it is stored
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None
    subset: Optional[str] = None
    folder: Optional[str] = None
    ranking: Optional[bool] = False
    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"

    # column names: prompt/query/response/history serve the alpaca format,
    # messages/role/content the sharegpt format
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    messages: Optional[str] = "conversations"
    role: Optional[str] = "from"
    content: Optional[str] = "value"
    system: Optional[str] = None

    def __repr__(self) -> str:
        return self.dataset_name  # assumes dataset_name has been set; __repr__ must return a str

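For orientation (an illustrative sketch, not part of the commit): the defaults mean an alpaca-style record can be addressed through a DatasetAttr without any column overrides. The record and file name below are invented:

    # Hypothetical alpaca-style record; the default column names
    # ("instruction", "input", "output") index straight into it.
    record = {"instruction": "Translate to French.", "input": "Good morning.", "output": "Bonjour."}

    attr = DatasetAttr("file", dataset_name="example.json")
    prompt_text = record[attr.prompt]      # -> "Translate to French."
    response_text = record[attr.response]  # -> "Bonjour."
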
@dataclass
class DataArguments:
    r"""
@@ -126,64 +93,3 @@ class DataArguments:
        if self.streaming and self.max_samples is not None:
            raise ValueError("`max_samples` is incompatible with `streaming`.")

    def init_for_training(self, seed: int):  # support mixing multiple datasets
        self.seed = seed

        # dataset is a comma-separated list of names defined in dataset_info.json
        dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
        try:
            with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
                dataset_info = json.load(f)
        except Exception as err:
            if self.dataset is not None:
                raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
            dataset_info = None

        if self.interleave_probs is not None:
            self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]

        self.dataset_list: List[DatasetAttr] = []
        for name in dataset_names:
            if name not in dataset_info:
                raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

            has_hf_url = "hf_hub_url" in dataset_info[name]
            has_ms_url = "ms_hub_url" in dataset_info[name]

            if has_hf_url or has_ms_url:
                # prefer ModelScope when it is requested via the env var, or when only an ms_hub_url exists
                if (use_modelscope() and has_ms_url) or (not has_hf_url):
                    dataset_attr = DatasetAttr(
                        "ms_hub",
                        dataset_name=dataset_info[name]["ms_hub_url"]
                    )
                else:
                    dataset_attr = DatasetAttr(
                        "hf_hub",
                        dataset_name=dataset_info[name]["hf_hub_url"]
                    )
            elif "script_url" in dataset_info[name]:
                dataset_attr = DatasetAttr(
                    "script",
                    dataset_name=dataset_info[name]["script_url"]
                )
            else:
                dataset_attr = DatasetAttr(
                    "file",
                    dataset_name=dataset_info[name]["file_name"],
                    dataset_sha1=dataset_info[name].get("file_sha1", None)
                )

            if "columns" in dataset_info[name]:
                dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
                dataset_attr.query = dataset_info[name]["columns"].get("query", None)
                dataset_attr.response = dataset_info[name]["columns"].get("response", None)
                dataset_attr.history = dataset_info[name]["columns"].get("history", None)
                dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
                dataset_attr.role = dataset_info[name]["columns"].get("role", None)
                dataset_attr.content = dataset_info[name]["columns"].get("content", None)
                dataset_attr.system = dataset_info[name]["columns"].get("system", None)

            dataset_attr.subset = dataset_info[name].get("subset", None)
            dataset_attr.folder = dataset_info[name].get("folder", None)
            dataset_attr.ranking = dataset_info[name].get("ranking", False)
            dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
            self.dataset_list.append(dataset_attr)
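To make the lookup above concrete, here is a hypothetical dataset_info.json with two entries, written as the Python dict json.load would produce; every name, URL, and hash below is invented for illustration:

    dataset_info = {
        "alpaca_demo": {
            "hf_hub_url": "someone/alpaca-demo"                # -> DatasetAttr("hf_hub", ...)
        },
        "local_chat": {
            "file_name": "local_chat.json",                    # -> DatasetAttr("file", ...)
            "file_sha1": "0000000000000000000000000000000000000000",  # placeholder checksum
            "formatting": "sharegpt",
            "columns": {"messages": "conversations", "role": "from", "content": "value"}
        }
    }

With that file in dataset_dir, passing dataset="alpaca_demo,local_chat" to init_for_training would append two DatasetAttr entries to self.dataset_list: the first downloaded from the Hugging Face hub (or from ModelScope when USE_MODELSCOPE_HUB=1 and an ms_hub_url is present), the second read from a local file with sharegpt-style column overrides.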