add docstrings, refactor logger

Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 93d4570a59
commit 7f71276ad8
30 changed files with 334 additions and 57 deletions

View File

@@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets to a unified dataset.
"""
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
@@ -67,14 +70,16 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy.")
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
"""
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)