support system column #1765
Former-commit-id: f425584a511c5e42bae8b3ba090eaa898b28adad
This commit is contained in:
@@ -83,7 +83,7 @@ def get_dataset(
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
||||
)
|
||||
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"):
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if max_samples is not None: # truncate dataset
|
||||
@@ -91,8 +91,8 @@ def get_dataset(
|
||||
|
||||
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
# convert dataset from sharegpt format to alpaca format
|
||||
outputs = {"prompt": [], "query": [], "response": [], "history": []}
|
||||
for msg_list in examples[dataset_attr.messages]:
|
||||
outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
|
||||
for i, msg_list in enumerate(examples[dataset_attr.messages]):
|
||||
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
|
||||
if len(msg_list) == 0:
|
||||
continue
|
||||
@@ -116,6 +116,7 @@ def get_dataset(
|
||||
outputs["query"].append("")
|
||||
outputs["response"].append(msg_pairs[-1][1])
|
||||
outputs["history"].append(msg_pairs[:-1])
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -136,17 +137,10 @@ def get_dataset(
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
for column_name in ["prompt", "query", "response", "history"]: # align dataset
|
||||
for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
|
||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
||||
if dataset_attr.system_prompt: # add system prompt
|
||||
system_prompt = dataset_attr.system_prompt
|
||||
if data_args.streaming:
|
||||
dataset = dataset.map(lambda _: {"system": system_prompt})
|
||||
else:
|
||||
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
|
||||
|
||||
all_datasets.append(dataset)
|
||||
|
||||
if len(data_args.dataset_list) == 1:
|
||||
|
||||
Reference in New Issue
Block a user