imporve log
Former-commit-id: a6abf375975ffea3d51e1b944c9855b5f62ffac8
This commit is contained in:
@@ -29,11 +29,10 @@ from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.training_args import ParallelMode
|
||||
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import CHECKPOINT_NAMES
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from ..extras.misc import check_dependencies, check_version, get_current_device
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
@@ -124,38 +123,35 @@ def _check_extra_dependencies(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
training_args: Optional["TrainingArguments"] = None,
|
||||
) -> None:
|
||||
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
|
||||
if model_args.use_unsloth:
|
||||
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
|
||||
check_version("unsloth", mandatory=True)
|
||||
|
||||
if model_args.enable_liger_kernel:
|
||||
require_version("liger-kernel", "To fix: pip install liger-kernel")
|
||||
check_version("liger-kernel", mandatory=True)
|
||||
|
||||
if model_args.mixture_of_depths is not None:
|
||||
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
|
||||
check_version("mixture-of-depth>=1.1.6", mandatory=True)
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
|
||||
check_version("vllm>=0.4.3,<0.6.7")
|
||||
check_version("vllm", mandatory=True)
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
check_version("galore_torch", mandatory=True)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
|
||||
check_version("badam>=1.2.1", mandatory=True)
|
||||
|
||||
if finetuning_args.use_adam_mini:
|
||||
require_version("adam-mini", "To fix: pip install adam-mini")
|
||||
check_version("adam-mini", mandatory=True)
|
||||
|
||||
if finetuning_args.plot_loss:
|
||||
require_version("matplotlib", "To fix: pip install matplotlib")
|
||||
check_version("matplotlib", mandatory=True)
|
||||
|
||||
if training_args is not None and training_args.predict_with_generate:
|
||||
require_version("jieba", "To fix: pip install jieba")
|
||||
require_version("nltk", "To fix: pip install nltk")
|
||||
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
|
||||
check_version("jieba", mandatory=True)
|
||||
check_version("nltk", mandatory=True)
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
|
||||
|
||||
Reference in New Issue
Block a user