fix version checking

Former-commit-id: 5780da8d640609cca388f55983d0251e5547209a
This commit is contained in:
hiyouga
2024-03-06 14:51:51 +08:00
parent 6b407092d9
commit 73d9dfc7ab
18 changed files with 49 additions and 33 deletions

View File

@@ -75,8 +75,7 @@ class Formatter(ABC):
tool_format: Literal["default"] = "default"
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
...
def apply(self, **kwargs) -> SLOTS: ...
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
raise NotImplementedError

View File

@@ -14,6 +14,7 @@ from transformers.utils import (
is_torch_npu_available,
is_torch_xpu_available,
)
from transformers.utils.versions import require_version
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
@@ -56,6 +57,17 @@ class AverageMeter:
self.avg = self.sum / self.count
def check_dependencies() -> None:
if int(os.environ.get("DISABLE_VERSION_CHECK", "0")):
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.

View File

@@ -173,10 +173,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
)
disable_version_checking: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to disable version checking."},
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},

View File

@@ -7,7 +7,6 @@ import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_unsloth_available
@@ -29,17 +28,6 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _check_dependencies(disabled: bool) -> None:
if disabled:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
@@ -152,7 +140,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Unsloth does not support DoRA.")
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if (
training_args.do_train
@@ -249,7 +236,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
@@ -262,7 +248,6 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
model_args.aqlm_optimization = True
if data_args.template is None:

View File

@@ -5,7 +5,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from ..extras.misc import check_dependencies, count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass
@@ -20,6 +20,9 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
check_dependencies()
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
return {
"trust_remote_code": True,