@@ -8,7 +8,6 @@ import transformers
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import check_dependencies
|
||||
@@ -119,6 +118,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support lora reward model.")
|
||||
|
||||
if (
|
||||
finetuning_args.stage == "ppo"
|
||||
and training_args.report_to is not None
|
||||
and training_args.report_to[0] not in ["wandb", "tensorboard"]
|
||||
):
|
||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
@@ -128,12 +134,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
|
||||
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
|
||||
|
||||
if finetuning_args.use_dora:
|
||||
if model_args.quantization_bit is not None:
|
||||
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
|
||||
|
||||
if model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support DoRA.")
|
||||
if finetuning_args.use_dora and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support DoRA.")
|
||||
|
||||
if finetuning_args.pure_bf16:
|
||||
if not is_torch_bf16_gpu_available():
|
||||
|
||||
Reference in New Issue
Block a user