support streaming data, fix #284 #274 #268

Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
This commit is contained in:
hiyouga
2023-07-31 23:33:00 +08:00
parent 124f61b404
commit dd3f3e9749
28 changed files with 478 additions and 344 deletions

View File

@@ -1,6 +1,6 @@
import os
import json
from typing import List, Optional
from typing import List, Literal, Optional
from dataclasses import dataclass, field
@@ -16,10 +16,10 @@ class DatasetAttr:
return self.dataset_name
def __post_init__(self):
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
self.prompt = "instruction"
self.query = "input"
self.response = "output"
self.history = None
@dataclass
@@ -27,8 +27,11 @@ class DataArguments:
"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: str = field(
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
dataset: Optional[str] = field(
default="alpaca_zh",
default="alpaca_en",
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
@@ -39,6 +42,18 @@ class DataArguments:
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable streaming mode."}
)
buffer_size: Optional[int] = field(
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing."}
)
overwrite_cache: Optional[bool] = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}
@@ -75,10 +90,6 @@ class DataArguments:
default=0,
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
)
prompt_template: Optional[str] = field(
default="default",
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
def init_for_training(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
@@ -111,9 +122,9 @@ class DataArguments:
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)