Former-commit-id: 5b069a967823e659dbc70b0d50361b3ad248087e
This commit is contained in:
hiyouga
2023-10-14 19:20:11 +08:00
parent 8659084ab0
commit 27dd87c890
4 changed files with 46 additions and 27 deletions

View File

@@ -31,7 +31,7 @@ class DataArguments:
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
dataset: Optional[str] = field(
default="alpaca_en",
default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
@@ -46,13 +46,17 @@ class DataArguments:
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."}
)
train_on_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable streaming mode."}
metadata={"help": "Enable dataset streaming."}
)
buffer_size: Optional[int] = field(
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat",
@@ -95,10 +99,20 @@ class DataArguments:
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
)
def __post_init__(self):
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")
if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")
def init_for_training(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
try:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
except Exception:
dataset_info = None
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))