Merge branch 'main' into minicpmv

Former-commit-id: d8840ae416660e23f1d615ffd404f519360151d9
This commit is contained in:
Zhangchi Feng
2025-01-10 20:12:07 +08:00
committed by GitHub
41 changed files with 647 additions and 357 deletions

View File

@@ -1424,6 +1424,14 @@ register_model_group(
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
},
"Phi-3.5-4B-instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3.5-mini-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-mini-instruct",
},
"Phi-3.5-MoE-42B-A6.6B-instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3.5-MoE-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-MoE-instruct",
},
},
template="phi",
)
@@ -1444,6 +1452,17 @@ register_model_group(
)
register_model_group(
models={
"Phi-4-14B-Instruct": {
DownloadSource.DEFAULT: "microsoft/phi-4",
DownloadSource.MODELSCOPE: "LLM-Research/phi-4",
},
},
template="phi4",
)
register_model_group(
models={
"Pixtral-12B-Instruct": {

View File

@@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger):
r"""
A logger that supports info_rank0 and warning_once.
A logger that supports rank0 logging.
"""
def info_rank0(self, *args, **kwargs) -> None:
@@ -77,7 +77,7 @@ class _Logger(logging.Logger):
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def warning_once(self, *args, **kwargs) -> None:
def warning_rank0_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
@@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
@lru_cache(None)
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_once = warning_once
logging.Logger.warning_rank0_once = warning_rank0_once

View File

@@ -73,19 +73,31 @@ class AverageMeter:
self.avg = self.sum / self.count
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""
Optionally checks the package version.
"""
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
if mandatory:
hint = f"To fix: run `pip install {requirement}`."
else:
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
require_version(requirement, hint)
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
check_version("transformers>=4.41.2,<=4.46.1")
check_version("datasets>=2.16.0,<=3.1.0")
check_version("accelerate>=0.34.0,<=1.0.1")
check_version("peft>=0.11.1,<=0.12.0")
check_version("trl>=0.8.6,<=0.9.6")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
@@ -229,7 +241,7 @@ def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
@@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
return model_args.model_name_or_path
if use_modelscope():
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
@@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
)
if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
check_version("openmind>=0.8.0", mandatory=True)
from openmind.utils.hub import snapshot_download # type: ignore
return snapshot_download(
@@ -275,8 +287,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
def use_modelscope() -> bool:
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
def use_ray() -> bool:
return os.getenv("USE_RAY", "0").lower() in ["true", "1"]

View File

@@ -62,6 +62,10 @@ def is_pillow_available():
return _is_package_available("PIL")
def is_ray_available():
return _is_package_available("ray")
def is_requests_available():
return _is_package_available("requests")