imporve log

Former-commit-id: a6abf375975ffea3d51e1b944c9855b5f62ffac8
This commit is contained in:
hiyouga
2025-01-08 09:56:10 +00:00
parent 3b843ac9d4
commit 647c51a772
16 changed files with 78 additions and 67 deletions

View File

@@ -73,19 +73,31 @@ class AverageMeter:
self.avg = self.sum / self.count
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""
Optionally checks the package version.
"""
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
if mandatory:
hint = f"To fix: run `pip install {requirement}`."
else:
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
require_version(requirement, hint)
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
check_version("transformers>=4.41.2,<=4.46.1")
check_version("datasets>=2.16.0,<=3.1.0")
check_version("accelerate>=0.34.0,<=1.0.1")
check_version("peft>=0.11.1,<=0.12.0")
check_version("trl>=0.8.6,<=0.9.6")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
@@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
return model_args.model_name_or_path
if use_modelscope():
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
@@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
)
if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
check_version("openmind>=0.8.0", mandatory=True)
from openmind.utils.hub import snapshot_download # type: ignore
return snapshot_download(