fix ppo dataset bug #4012

Former-commit-id: 7fc51b2e93698ae5e012566af8481f4d861c873d
Author: hiyouga
Date:   2024-06-06 19:03:20 +08:00
Parent: d5559461c1
Commit: ca95e98ca0
4 changed files with 4 additions and 4 deletions


@@ -130,7 +130,7 @@ def get_dataset(
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "kto"],
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
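
For context, a minimal toy reproduction (not the project's code; function name and bodies are hypothetical) of why the missing literal mattered: when "ppo" is absent from the `Literal` type of `stage`, a static type checker such as mypy rejects callers that pass `stage="ppo"`, even though the call may execute at runtime. Adding "ppo" to the accepted literals, as this commit does, makes the PPO call site type-check.

```python
from typing import Literal

# Before the fix: "ppo" is not among the accepted stage literals.
def load_stage(stage: Literal["pt", "sft", "rm", "kto"]) -> str:
    return f"loading dataset for stage: {stage}"

# mypy error: Argument 1 to "load_stage" has incompatible type
# "Literal['ppo']"; expected "Literal['pt', 'sft', 'rm', 'kto']"
# load_stage("ppo")

# After the fix: "ppo" is a valid stage, so the PPO caller type-checks.
def load_stage_fixed(stage: Literal["pt", "sft", "rm", "ppo", "kto"]) -> str:
    return f"loading dataset for stage: {stage}"

print(load_stage_fixed("ppo"))  # loading dataset for stage: ppo
```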