imporve log
Former-commit-id: a6abf375975ffea3d51e1b944c9855b5f62ffac8
This commit is contained in:
@@ -73,19 +73,31 @@ class AverageMeter:
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
r"""
|
||||
Optionally checks the package version.
|
||||
"""
|
||||
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
|
||||
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
|
||||
if mandatory:
|
||||
hint = f"To fix: run `pip install {requirement}`."
|
||||
else:
|
||||
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
|
||||
|
||||
require_version(requirement, hint)
|
||||
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""
|
||||
Checks the version of the required packages.
|
||||
"""
|
||||
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
|
||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
||||
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
|
||||
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
|
||||
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
||||
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
|
||||
check_version("transformers>=4.41.2,<=4.46.1")
|
||||
check_version("datasets>=2.16.0,<=3.1.0")
|
||||
check_version("accelerate>=0.34.0,<=1.0.1")
|
||||
check_version("peft>=0.11.1,<=0.12.0")
|
||||
check_version("trl>=0.8.6,<=0.9.6")
|
||||
|
||||
|
||||
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
@@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
return model_args.model_name_or_path
|
||||
|
||||
if use_modelscope():
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
check_version("modelscope>=1.11.0", mandatory=True)
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
|
||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||
@@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
)
|
||||
|
||||
if use_openmind():
|
||||
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||
check_version("openmind>=0.8.0", mandatory=True)
|
||||
from openmind.utils.hub import snapshot_download # type: ignore
|
||||
|
||||
return snapshot_download(
|
||||
|
||||
Reference in New Issue
Block a user