add option to disable version check

Former-commit-id: fd769cb2de696aee3c5e882237e16eace6a9d675
This commit is contained in:
hiyouga
2024-02-10 22:31:23 +08:00
parent 62b6a7971a
commit 5f83860aa1
4 changed files with 47 additions and 26 deletions
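The hunks below gate the dependency pins on `finetuning_args.disable_version_checking`; the field itself is declared in one of the other changed files, which is not shown in this excerpt. A minimal sketch of how such a flag is typically declared in an HfArgumentParser dataclass (the field name follows the diff, everything else is an assumption):

from dataclasses import dataclass, field

@dataclass
class FinetuningArguments:
    # Hypothetical declaration; the real finetuning_args.py change is not part of this excerpt.
    disable_version_checking: bool = field(
        default=False,
        metadata={"help": "Whether or not to disable version checking."},
    )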


@@ -8,8 +8,10 @@ import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_unsloth_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -28,6 +30,14 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _check_dependencies():
    require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
    require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
    require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
    require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
    require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")


def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
    if args is not None:
        return parser.parse_dict(args)
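For reference, `require_version` from `transformers.utils.versions` raises an ImportError when the installed package does not satisfy the pin, with the "To fix: ..." hint appended to the message, so `_check_dependencies` fails fast on stale environments. A small sketch of that behavior:

from transformers.utils.versions import require_version

try:
    require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
except ImportError as err:
    # The error message names the unmet requirement and carries the "To fix" hint.
    print(f"Dependency check failed: {err}")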
@@ -123,8 +133,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
    raise ValueError("Please specify `lora_target` in LoRA training.")

if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
    raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")

_verify_model_args(model_args, finetuning_args)

if not finetuning_args.disable_version_checking:
    _check_dependencies()

if (
    training_args.do_train
    and finetuning_args.finetuning_type == "lora"
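Because the option lives on FinetuningArguments, it can be supplied like any other argument, e.g. `--disable_version_checking true` on the command line or as a key in the dict passed to `get_train_args`. A hypothetical sketch (only the last key is relevant to this diff; the other entries are placeholders and a real run also needs model, dataset and output arguments):

# Hypothetical usage sketch, not part of the diff.
args = {
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "lora_target": "q_proj,v_proj",
    "disable_version_checking": True,  # skip the require_version pins in _check_dependencies()
}
# model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)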
@@ -145,7 +161,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
    logger.warning("Specify `ref_model` for computing rewards at evaluation.")

# Post-process training arguments
if (
    training_args.local_rank != -1
    and training_args.ddp_find_unused_parameters is None
@@ -158,7 +174,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
    can_resume_from_checkpoint = False
    if training_args.resume_from_checkpoint is not None:
        logger.warning("Cannot resume from checkpoint in current stage.")
        training_args.resume_from_checkpoint = None
else:
    can_resume_from_checkpoint = True
@@ -194,7 +212,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
)
)
# Post-process model arguments
model_args.compute_dtype = (
    torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
@@ -212,7 +230,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
@@ -220,24 +237,30 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
    _set_transformers_logging()

    if data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

    _verify_model_args(model_args, finetuning_args)

    if not finetuning_args.disable_version_checking:
        _check_dependencies()

    return model_args, data_args, finetuning_args, generating_args

def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
    model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
    _set_transformers_logging()

    if data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

    _verify_model_args(model_args, finetuning_args)

    if not finetuning_args.disable_version_checking:
        _check_dependencies()

    transformers.set_seed(eval_args.seed)
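The same gate is applied on the inference and evaluation paths, after the mandatory `template` check. A hypothetical evaluation-time call might look like this (the values are placeholders; the exact accepted keys come from ModelArguments, DataArguments and EvaluationArguments, which are not shown here):

# Hypothetical usage sketch, not part of the diff.
cfg = {
    "model_name_or_path": "path/to/model",  # placeholder
    "template": "default",                  # required, otherwise a ValueError is raised
    "task": "mmlu",                         # placeholder evaluation task
    "disable_version_checking": True,       # skip _check_dependencies() here as well
}
# model_args, data_args, eval_args, finetuning_args = get_eval_args(cfg)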