add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
@@ -48,6 +48,9 @@ def _load_single_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Loads a single dataset and aligns it to the standard format.
+    """
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
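
For context, _load_single_dataset dispatches on dataset_attr.load_from to decide where the raw data comes from before aligning it. A minimal sketch of such a dispatch with the datasets library follows; everything beyond the load_from check visible in the hunk (the helper name, the dataset_name attribute, the file-based fallback) is an illustrative assumption, not the project's full logic:

    from datasets import load_dataset

    def load_raw_dataset(dataset_attr):
        # Hypothetical simplification of the dispatch in the hunk above.
        if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
            # Fetch by name from a hub (hf_hub shown; ms_hub would go
            # through modelscope rather than huggingface).
            return load_dataset(dataset_attr.dataset_name, split="train")
        # Fallback: treat the name as local files, e.g. a JSON dataset.
        return load_dataset("json", data_files=dataset_attr.dataset_name, split="train")
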
@@ -117,7 +120,7 @@ def _load_single_dataset(
 
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
-        indexes = np.random.permutation(len(dataset))[:target_num]
+        indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
         target_num -= len(indexes)
         if target_num > 0:
             expand_indexes = np.random.choice(len(dataset), target_num)
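
The one-line change above only adds a comment, but the comment captures the point of the two-step sampling: np.random.permutation draws without replacement, so when num_samples exceeds the dataset size every example is included once before np.random.choice (which draws with replacement) tops up the rest. A self-contained sketch of the same scheme (the function name and the final concatenation are illustrative):

    import numpy as np

    def sample_indexes(dataset_len: int, target_num: int) -> np.ndarray:
        # Without replacement first: a truncated permutation covers every
        # sample once whenever target_num >= dataset_len.
        indexes = np.random.permutation(dataset_len)[:target_num]
        remaining = target_num - len(indexes)
        if remaining > 0:
            # With replacement for the overshoot beyond the dataset size.
            expand_indexes = np.random.choice(dataset_len, remaining)
            indexes = np.concatenate((indexes, expand_indexes))
        return indexes

    print(sample_indexes(5, 8))  # all 5 indexes appear, plus 3 repeats
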
@@ -141,6 +144,9 @@ def _get_merged_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Gets the merged datasets in the standard format.
+    """
     if dataset_names is None:
         return None
 
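
_get_merged_dataset returns None when no dataset names are given; otherwise each named dataset is loaded and the results are merged. A minimal sketch of the merge step using datasets.concatenate_datasets (plain concatenation is one merge strategy; the function name is illustrative, and the real code may also support interleaving):

    from datasets import concatenate_datasets

    def merge_datasets(all_datasets):
        if len(all_datasets) == 1:
            return all_datasets[0]
        # Row-wise concatenation requires every dataset to share the
        # standard format (same columns) produced by alignment.
        return concatenate_datasets(all_datasets)
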
@@ -164,6 +170,9 @@ def _get_preprocessed_dataset(
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Preprocesses the dataset, including format checking and tokenization.
+    """
     if dataset is None:
         return None
 
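
_get_preprocessed_dataset short-circuits on None in the same way; the preprocessing itself is typically a batched Dataset.map that tokenizes each example. A minimal sketch, assuming a plain "text" column (the column name, function, and cutoff are illustrative, not the project's own pipeline):

    def preprocess(dataset, tokenizer, cutoff_len=1024):
        def tokenize_fn(examples):
            # Batched tokenization, truncated to the context budget.
            return tokenizer(examples["text"], truncation=True, max_length=cutoff_len)

        # Drop the raw text so only token ids and masks remain.
        return dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
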
@@ -209,6 +218,9 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> "DatasetModule":
+    r"""
+    Gets the train dataset and optionally gets the evaluation dataset.
+    """
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
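
The tokenized_path branch lets get_dataset reuse a preprocessed copy from disk instead of re-tokenizing. A minimal sketch of that cache pattern with the datasets library (has_tokenized_data appears in the hunk; the save/load pairing below is an assumption about how such a check is typically used):

    import os
    from datasets import load_from_disk

    def load_or_build(tokenized_path, build_fn):
        # Reload a copy saved earlier via Dataset.save_to_disk, if any.
        if os.path.isdir(tokenized_path) and os.listdir(tokenized_path):
            return load_from_disk(tokenized_path)
        dataset = build_fn()
        dataset.save_to_disk(tokenized_path)
        return dataset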