Merge branch 'main' into minicpmv
Former-commit-id: d8840ae416660e23f1d615ffd404f519360151d9
This commit is contained in:
@@ -1424,6 +1424,14 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
|
||||
},
|
||||
"Phi-3.5-4B-instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3.5-mini-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-mini-instruct",
|
||||
},
|
||||
"Phi-3.5-MoE-42B-A6.6B-instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3.5-MoE-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-MoE-instruct",
|
||||
},
|
||||
},
|
||||
template="phi",
|
||||
)
|
||||
@@ -1444,6 +1452,17 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
# Phi-4 group: one checkpoint, with its repository id on each download source.
register_model_group(
    models={
        "Phi-4-14B-Instruct": {
            DownloadSource.DEFAULT: "microsoft/phi-4",
            DownloadSource.MODELSCOPE: "LLM-Research/phi-4",
        },
    },
    template="phi4",
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Pixtral-12B-Instruct": {
|
||||
|
||||
@@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
|
||||
|
||||
class _Logger(logging.Logger):
|
||||
r"""
|
||||
A logger that supports info_rank0 and warning_once.
|
||||
A logger that supports rank0 logging.
|
||||
"""
|
||||
|
||||
def info_rank0(self, *args, **kwargs) -> None:
|
||||
@@ -77,7 +77,7 @@ class _Logger(logging.Logger):
|
||||
def warning_rank0(self, *args, **kwargs) -> None:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
def warning_once(self, *args, **kwargs) -> None:
|
||||
def warning_rank0_once(self, *args, **kwargs) -> None:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
|
||||
|
||||
@lru_cache(None)
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
    r"""
    Emits a warning only when `LOCAL_RANK` is 0 (or unset), at most once per distinct call.

    `lru_cache` memoises on the logger and the arguments, so repeating a call with the same
    logger and arguments is a no-op. All arguments must therefore be hashable.
    """
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)


# Former name of the function above; kept bound because it is still assigned onto `logging.Logger`.
warning_once = warning_rank0_once
|
||||
|
||||
|
||||
logging.Logger.info_rank0 = info_rank0
|
||||
logging.Logger.warning_rank0 = warning_rank0
|
||||
logging.Logger.warning_once = warning_once
|
||||
logging.Logger.warning_rank0_once = warning_rank0_once
|
||||
|
||||
@@ -73,19 +73,31 @@ class AverageMeter:
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def check_version(requirement: str, mandatory: bool = False) -> None:
    r"""
    Optionally checks the package version.

    Args:
        requirement: a pip-style requirement specifier, e.g. "peft>=0.11.1,<=0.12.0".
        mandatory: when True, the check runs even if `DISABLE_VERSION_CHECK` is set.

    The check itself is delegated to `require_version`; anything it raises propagates to the caller.
    """
    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    # The specifier is quoted in the hint: pasted unquoted, a shell treats `>` and `<` as redirections.
    if mandatory:
        hint = f'To fix: run `pip install "{requirement}"`.'
    else:
        hint = f'To fix: run `pip install "{requirement}"` or set `DISABLE_VERSION_CHECK=1` to skip this check.'

    require_version(requirement, hint)
|
||||
|
||||
|
||||
def check_dependencies() -> None:
    r"""
    Checks the version of the required packages.

    Every requirement goes through `check_version`, which already handles the
    `DISABLE_VERSION_CHECK` opt-out, so no separate guard is repeated here.
    """
    check_version("transformers>=4.41.2,<=4.46.1")
    check_version("datasets>=2.16.0,<=3.1.0")
    check_version("accelerate>=0.34.0,<=1.0.1")
    check_version("peft>=0.11.1,<=0.12.0")
    check_version("trl>=0.8.6,<=0.9.6")
|
||||
|
||||
|
||||
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
@@ -229,7 +241,7 @@ def skip_check_imports() -> None:
|
||||
r"""
|
||||
Avoids flash attention import error in custom model files.
|
||||
"""
|
||||
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
||||
if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
||||
transformers.dynamic_module_utils.check_imports = get_relative_imports
|
||||
|
||||
|
||||
@@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
return model_args.model_name_or_path
|
||||
|
||||
if use_modelscope():
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
check_version("modelscope>=1.11.0", mandatory=True)
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
|
||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||
@@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
)
|
||||
|
||||
if use_openmind():
|
||||
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||
check_version("openmind>=0.8.0", mandatory=True)
|
||||
from openmind.utils.hub import snapshot_download # type: ignore
|
||||
|
||||
return snapshot_download(
|
||||
@@ -275,8 +287,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
|
||||
|
||||
def use_modelscope() -> bool:
    r"""
    Returns True when the `USE_MODELSCOPE_HUB` environment variable is "true" or "1" (case-insensitive).

    An unset variable counts as "0", i.e. False.
    """
    return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
|
||||
|
||||
|
||||
def use_openmind() -> bool:
    r"""
    Returns True when the `USE_OPENMIND_HUB` environment variable is "true" or "1" (case-insensitive).

    An unset variable counts as "0", i.e. False.
    """
    return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
|
||||
|
||||
|
||||
def use_ray() -> bool:
    r"""
    Returns True when the `USE_RAY` environment variable is "true" or "1" (case-insensitive).

    An unset variable counts as "0", i.e. False.
    """
    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
|
||||
|
||||
@@ -62,6 +62,10 @@ def is_pillow_available():
|
||||
return _is_package_available("PIL")
|
||||
|
||||
|
||||
def is_ray_available():
    r"""
    Reports whether the `ray` package is available, as determined by `_is_package_available`.
    """
    return _is_package_available("ray")
|
||||
|
||||
|
||||
def is_requests_available():
    r"""
    Reports whether the `requests` package is available, as determined by `_is_package_available`.
    """
    return _is_package_available("requests")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user