Former-commit-id: 4715e5c5b8040b21e5f401f7e969b9fd2757d520
This commit is contained in:
@@ -41,17 +41,17 @@ class DataArguments:
|
||||
default="data",
|
||||
metadata={"help": "Path to the folder containing the datasets."},
|
||||
)
|
||||
train_last_turn_only: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to train the last turn only."},
|
||||
)
|
||||
cutoff_len: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
||||
)
|
||||
train_on_prompt: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to disable the mask on the prompt or not."},
|
||||
metadata={"help": "Whether or not to disable the mask on the prompt."},
|
||||
)
|
||||
mask_history: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to mask the history and train on the last turn only."},
|
||||
)
|
||||
streaming: bool = field(
|
||||
default=False,
|
||||
|
||||
@@ -162,9 +162,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
# Check arguments
|
||||
if finetuning_args.stage != "pt" and data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if finetuning_args.stage == "pt" and data_args.train_last_turn_only:
|
||||
raise ValueError("PT stage does not support `train_last_turn_only`.")
|
||||
|
||||
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
Reference in New Issue
Block a user