[feat] support all_exhausted_without_replacement in datasets.interleave_datasets (#10112)

This commit is contained in:
Meng WANG
2026-01-20 15:54:07 +08:00
committed by GitHub
parent db2f794f7b
commit e70651ac58
2 changed files with 9 additions and 3 deletions

View File

@@ -65,11 +65,17 @@ def merge_dataset(
if not data_args.streaming:
logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
strategy_map: str = {
"interleave_under": "first_exhausted",
"interleave_over": "all_exhausted",
"interleave_once": "all_exhausted_without_replacement",
}[data_args.mix_strategy]
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
stopping_strategy=strategy_map, # type: ignore
)
else:

View File

@@ -63,9 +63,9 @@ class DataArguments:
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
)
mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
mix_strategy: Literal["concat", "interleave_under", "interleave_over", "interleave_once"] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."},
)
interleave_probs: str | None = field(
default=None,