follow #5115
Former-commit-id: 7d917e03e2df570139bae18227d9c7303a12de2a
@@ -163,11 +163,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
-    if finetuning_args.stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
+    if finetuning_args.stage != "sft":
+        if training_args.predict_with_generate:
+            raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
 
-    if finetuning_args.stage != "sft" and data_args.neat_packing:
-        raise ValueError("`neat_packing` cannot be set as True except SFT.")
+        if data_args.neat_packing:
+            raise ValueError("`neat_packing` cannot be set as True except SFT.")
+
+        if data_args.train_on_prompt or data_args.mask_history:
+            raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")
 
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
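The first hunk hoists the shared `finetuning_args.stage != "sft"` condition into a single branch and places the new `train_on_prompt`/`mask_history` check under it. Below is a minimal, self-contained sketch of the same nested-validation pattern, assuming simplified stand-in dataclasses rather than LLaMA-Factory's real hparam classes:

# Minimal sketch of the nested validation pattern from the hunk above.
# The dataclasses are simplified stand-ins for illustration, not the real
# DataArguments / TrainingArguments / FinetuningArguments classes.
from dataclasses import dataclass


@dataclass
class DataArguments:
    neat_packing: bool = False
    train_on_prompt: bool = False
    mask_history: bool = False


@dataclass
class TrainingArguments:
    predict_with_generate: bool = False


@dataclass
class FinetuningArguments:
    stage: str = "sft"


def check_sft_only_args(finetuning_args, data_args, training_args):
    # One stage check guards every SFT-only option, mirroring the new code.
    if finetuning_args.stage != "sft":
        if training_args.predict_with_generate:
            raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

        if data_args.neat_packing:
            raise ValueError("`neat_packing` cannot be set as True except SFT.")

        if data_args.train_on_prompt or data_args.mask_history:
            raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")


# An SFT-only flag combined with a non-SFT stage fails fast:
try:
    check_sft_only_args(FinetuningArguments(stage="dpo"), DataArguments(neat_packing=True), TrainingArguments())
except ValueError as err:
    print(err)  # `neat_packing` cannot be set as True except SFT.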
@@ -175,21 +179,18 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
         raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
 
-    if finetuning_args.stage == "ppo" and not training_args.do_train:
-        raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
+    if finetuning_args.stage == "ppo":
+        if not training_args.do_train:
+            raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
 
-    if finetuning_args.stage == "ppo" and model_args.shift_attn:
-        raise ValueError("PPO training is incompatible with S^2-Attn.")
+        if model_args.shift_attn:
+            raise ValueError("PPO training is incompatible with S^2-Attn.")
 
-    if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
-        raise ValueError("Unsloth does not support lora reward model.")
+        if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
+            raise ValueError("Unsloth does not support lora reward model.")
 
-    if (
-        finetuning_args.stage == "ppo"
-        and training_args.report_to
-        and training_args.report_to[0] not in ["wandb", "tensorboard"]
-    ):
-        raise ValueError("PPO only accepts wandb or tensorboard logger.")
+        if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
+            raise ValueError("PPO only accepts wandb or tensorboard logger.")
 
     if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
         raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
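The second hunk applies the same refactor to the PPO-only checks, which also lets the multi-line `if (...)` for `report_to` collapse to a single line once the stage condition is hoisted. A minimal sketch of that branch, using `SimpleNamespace` objects as hypothetical stand-ins for the parsed argument classes:

# Minimal sketch of the PPO-only grouping from the hunk above; only the
# condition structure mirrors the diff, the argument objects are stand-ins.
from types import SimpleNamespace


def check_ppo_only_args(finetuning_args, model_args, training_args):
    if finetuning_args.stage == "ppo":
        if not training_args.do_train:
            raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")

        if model_args.shift_attn:
            raise ValueError("PPO training is incompatible with S^2-Attn.")

        if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
            raise ValueError("Unsloth does not support lora reward model.")

        # An empty `report_to` list short-circuits; only a non-empty list
        # whose first entry is an unsupported logger is rejected.
        if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
            raise ValueError("PPO only accepts wandb or tensorboard logger.")


# Example: an unsupported logger under PPO is rejected.
try:
    check_ppo_only_args(
        SimpleNamespace(stage="ppo", reward_model_type="full"),
        SimpleNamespace(shift_attn=False, use_unsloth=False),
        SimpleNamespace(do_train=True, report_to=["mlflow"]),
    )
except ValueError as err:
    print(err)  # PPO only accepts wandb or tensorboard logger.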