mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-01-30 06:12:04 +00:00
[feat] support all_exhausted_without_replacement in datasets.interleave_datasets (#10112)
This commit is contained in:
@@ -65,11 +65,17 @@ def merge_dataset(
|
||||
if not data_args.streaming:
|
||||
logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
|
||||
strategy_map: str = {
|
||||
"interleave_under": "first_exhausted",
|
||||
"interleave_over": "all_exhausted",
|
||||
"interleave_once": "all_exhausted_without_replacement",
|
||||
}[data_args.mix_strategy]
|
||||
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
stopping_strategy=strategy_map, # type: ignore
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -63,9 +63,9 @@ class DataArguments:
|
||||
default=16384,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
|
||||
)
|
||||
mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
|
||||
mix_strategy: Literal["concat", "interleave_under", "interleave_over", "interleave_once"] = field(
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."},
|
||||
)
|
||||
interleave_probs: str | None = field(
|
||||
default=None,
|
||||
|
||||
Reference in New Issue
Block a user