@@ -120,7 +120,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
|
||||
if (
|
||||
finetuning_args.stage == "ppo"
|
||||
and training_args.report_to is not None
|
||||
and training_args.report_to
|
||||
and training_args.report_to[0] not in ["wandb", "tensorboard"]
|
||||
):
|
||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
||||
|
||||
Reference in New Issue
Block a user