Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -16,10 +16,10 @@ class DatasetAttr:
|
||||
return self.dataset_name
|
||||
|
||||
def __post_init__(self):
|
||||
self.prompt_column = "instruction"
|
||||
self.query_column = "input"
|
||||
self.response_column = "output"
|
||||
self.history_column = None
|
||||
self.prompt = "instruction"
|
||||
self.query = "input"
|
||||
self.response = "output"
|
||||
self.history = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -27,8 +27,11 @@ class DataArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
template: str = field(
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
default="alpaca_zh",
|
||||
default="alpaca_en",
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
@@ -39,6 +42,18 @@ class DataArguments:
|
||||
default="train",
|
||||
metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||
)
|
||||
streaming: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable streaming mode."}
|
||||
)
|
||||
buffer_size: Optional[int] = field(
|
||||
default=16384,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
|
||||
)
|
||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing."}
|
||||
)
|
||||
overwrite_cache: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Overwrite the cached training and evaluation sets."}
|
||||
@@ -75,10 +90,6 @@ class DataArguments:
|
||||
default=0,
|
||||
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
|
||||
)
|
||||
prompt_template: Optional[str] = field(
|
||||
default="default",
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
|
||||
def init_for_training(self): # support mixing multiple datasets
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
||||
@@ -111,9 +122,9 @@ class DataArguments:
|
||||
dataset_attr.source_prefix = prefix_list[i]
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
|
||||
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
|
||||
|
||||
self.dataset_list.append(dataset_attr)
|
||||
|
||||
Reference in New Issue
Block a user