support system column #1765

Former-commit-id: f425584a511c5e42bae8b3ba090eaa898b28adad
This commit is contained in:
hiyouga
2023-12-12 19:45:59 +08:00
parent c27675f70d
commit 934d00ea1e
14 changed files with 42 additions and 76 deletions

View File

@@ -17,7 +17,6 @@ class DatasetAttr:
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
system_prompt: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: Optional[bool] = False
@@ -30,6 +29,7 @@ class DatasetAttr:
messages: Optional[str] = "conversations"
role: Optional[str] = "from"
content: Optional[str] = "value"
system: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
@@ -104,10 +104,6 @@ class DataArguments:
default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
)
system_prompt: Optional[str] = field(
default=None,
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
)
val_size: Optional[float] = field(
default=0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
@@ -145,15 +141,11 @@ class DataArguments:
raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
dataset_info = None
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
if self.interleave_probs is not None:
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
self.dataset_list: List[DatasetAttr] = []
for i, name in enumerate(dataset_names):
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
@@ -191,10 +183,10 @@ class DataArguments:
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
dataset_attr.system = dataset_info[name]["columns"].get("system", None)
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.folder = dataset_info[name].get("folder", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr)