add option to disable version check

Former-commit-id: fd769cb2de696aee3c5e882237e16eace6a9d675
This commit is contained in:
hiyouga
2024-02-10 22:31:23 +08:00
parent 62b6a7971a
commit 5f83860aa1
4 changed files with 47 additions and 26 deletions

View File

@@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
@@ -21,13 +20,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
@@ -63,7 +55,6 @@ def load_model_and_tokenizer(
model = None
if is_trainable and model_args.use_unsloth:
require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
unsloth_kwargs = {