imporve log

Former-commit-id: a6abf375975ffea3d51e1b944c9855b5f62ffac8
This commit is contained in:
hiyouga
2025-01-08 09:56:10 +00:00
parent 3b843ac9d4
commit 647c51a772
16 changed files with 78 additions and 67 deletions

View File

@@ -29,11 +29,10 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.misc import check_dependencies, get_current_device
from ..extras.misc import check_dependencies, check_version, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -124,38 +123,35 @@ def _check_extra_dependencies(
finetuning_args: "FinetuningArguments",
training_args: Optional["TrainingArguments"] = None,
) -> None:
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
check_version("unsloth", mandatory=True)
if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel")
check_version("liger-kernel", mandatory=True)
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
check_version("vllm>=0.4.3,<0.6.7")
check_version("vllm", mandatory=True)
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
check_version("galore_torch", mandatory=True)
if finetuning_args.use_badam:
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
check_version("badam>=1.2.1", mandatory=True)
if finetuning_args.use_adam_mini:
require_version("adam-mini", "To fix: pip install adam-mini")
check_version("adam-mini", mandatory=True)
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
check_version("matplotlib", mandatory=True)
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
check_version("jieba", mandatory=True)
check_version("nltk", mandatory=True)
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: