[v1] model loader (#9613)

Yaowei Zheng
2025-12-14 11:50:52 +08:00
committed by GitHub
parent fdd24276ed
commit aeda079014
27 changed files with 449 additions and 305 deletions

View File

@@ -0,0 +1,92 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's verl library.
# https://github.com/volcengine/verl/blob/v0.6.1/verl/utils/torch_dtypes.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Union

import torch
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device

from ..accelerator.interface import DistributedInterface


class DtypeRegistry:
    HALF_LIST = ["fp16", "float16", "half", torch.float16]
    FLOAT_LIST = ["fp32", "float32", "float", torch.float32]
    BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]


class DtypeInterface:
    """Type of precision used."""

    _is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator)
    _is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator)
    _is_fp32_available = True

    @staticmethod
    def is_available(precision: Union[str, torch.dtype]) -> bool:
        if precision in DtypeRegistry.HALF_LIST:
            return DtypeInterface._is_fp16_available
        elif precision in DtypeRegistry.FLOAT_LIST:
            return DtypeInterface._is_fp32_available
        elif precision in DtypeRegistry.BFLOAT_LIST:
            return DtypeInterface._is_bf16_available
        else:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @staticmethod
    def is_fp16(precision: Union[str, torch.dtype]) -> bool:
        return precision in DtypeRegistry.HALF_LIST

    @staticmethod
    def is_fp32(precision: Union[str, torch.dtype]) -> bool:
        return precision in DtypeRegistry.FLOAT_LIST

    @staticmethod
    def is_bf16(precision: Union[str, torch.dtype]) -> bool:
        return precision in DtypeRegistry.BFLOAT_LIST

    @staticmethod
    def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
        if precision in DtypeRegistry.HALF_LIST:
            return torch.float16
        elif precision in DtypeRegistry.FLOAT_LIST:
            return torch.float32
        elif precision in DtypeRegistry.BFLOAT_LIST:
            return torch.bfloat16
        else:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @staticmethod
    def to_str(precision: torch.dtype) -> str:
        if precision == torch.float16:
            return "float16"
        elif precision == torch.float32:
            return "float32"
        elif precision == torch.bfloat16:
            return "bfloat16"
        else:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @contextmanager
    def set_dtype(self, precision: Union[str, torch.dtype]):
        original_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self.to_dtype(precision))
        try:
            yield
        finally:
            torch.set_default_dtype(original_dtype)
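
As a quick reference, the standalone sketch below mirrors the behavior of to_dtype and the set_dtype context manager using torch directly. It intentionally does not import the new module, since this view does not show its import path; it only illustrates the override-then-restore pattern implemented above.

import torch

# All aliases in DtypeRegistry normalize to one torch.dtype,
# e.g. DtypeInterface.to_dtype("bf16") returns torch.bfloat16.
# Mirror of set_dtype: temporarily override the default dtype, then restore it.
original_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    x = torch.empty(2, 2)                    # allocated in bfloat16 while the override is active
    assert x.dtype is torch.bfloat16
finally:
    torch.set_default_dtype(original_dtype)  # restored even on error, as in the contextmanager above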

View File

@@ -29,7 +29,7 @@ _default_log_level: "logging._Level" = logging.INFO
class _Logger(logging.Logger):
    r"""A logger that supports rank0 logging."""
    """A logger that supports rank0 logging."""

    def info_rank0(self, *args, **kwargs) -> None:
        self.info(*args, **kwargs)
@@ -42,7 +42,7 @@ class _Logger(logging.Logger):

def _get_default_logging_level() -> "logging._Level":
    r"""Return the default logging level."""
    """Return the default logging level."""
    env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
    if env_level_str:
        if env_level_str.upper() in logging._nameToLevel:
@@ -62,7 +62,7 @@ def _get_library_root_logger() -> "_Logger":

def _configure_library_root_logger() -> None:
    r"""Configure root logger using a stdout stream handler with an explicit format."""
    """Configure root logger using a stdout stream handler with an explicit format."""
    global _default_handler

    with _thread_lock:
@@ -82,7 +82,7 @@ def _configure_library_root_logger() -> None:

def get_logger(name: Optional[str] = None) -> "_Logger":
    r"""Return a logger with the specified name. It is not supposed to be accessed externally."""
    """Return a logger with the specified name. It is not supposed to be accessed externally."""
    if name is None:
        name = _get_library_name()
@@ -91,13 +91,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":

def add_handler(handler: "logging.Handler") -> None:
    r"""Add a handler to the root logger."""
    """Add a handler to the root logger."""
    _configure_library_root_logger()
    _get_library_root_logger().addHandler(handler)


def remove_handler(handler: logging.Handler) -> None:
    r"""Remove a handler from the root logger."""
    """Remove a handler from the root logger."""
    _configure_library_root_logger()
    _get_library_root_logger().removeHandler(handler)
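
For context, a short usage sketch of the helpers touched in these hunks. The import path is an assumption (this view does not show the file's location), and the rank-0 gating itself lives outside the lines shown here; the class docstring only states the intent.

import logging as std_logging

from llamafactory.v1.utils import logging  # assumed import path, not shown in this diff

logger = logging.get_logger(__name__)
logger.info_rank0("model loaded")  # per the class docstring, intended to emit only on rank 0

# add_handler/remove_handler attach or detach extra handlers on the library root logger.
file_handler = std_logging.FileHandler("train.log")
logging.add_handler(file_handler)
logging.remove_handler(file_handler)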

View File

@@ -38,7 +38,7 @@ class BasePlugin:
        self.name = name

    @property
    def register(self) -> Callable:
    def register(self):
        """Decorator to register a function as a plugin.

        Example usage:
@@ -60,7 +60,7 @@ class BasePlugin:
        return decorator

    def __call__(self, *args, **kwargs) -> Callable:
    def __call__(self, *args, **kwargs):
        """Call the registered function with the given arguments.

        Example usage:
@@ -75,6 +75,9 @@ class BasePlugin:

if __name__ == "__main__":
    """
    python -m llamafactory.v1.utils.plugin
    """

    class PrintPlugin(BasePlugin):
        pass
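
The "Example usage" sections of the two docstrings are truncated in this view. The sketch below is a minimal, self-contained illustration of the register-as-decorator pattern they describe; it is not the actual BasePlugin, whose registry internals fall outside these hunks.

from typing import Any, Callable, Optional


class MiniPlugin:
    """Illustrative stand-in for the decorator-based registration shown above."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._fn: Optional[Callable] = None

    @property
    def register(self):
        def decorator(fn: Callable) -> Callable:
            self._fn = fn  # remember the registered function
            return fn

        return decorator

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self._fn is None:
            raise RuntimeError(f"No function registered for plugin {self.name!r}.")
        return self._fn(*args, **kwargs)


printer = MiniPlugin("print")


@printer.register
def echo(message: str) -> None:
    print(f"[{printer.name}] {message}")


printer("hello")  # prints: [print] hello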

View File

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
    import torch
    import torch.utils.data
    import transformers
    from torch.distributed import ProcessGroup
    from torch.distributed.fsdp import FullyShardedDataParallel

    Tensor = torch.Tensor
@@ -37,6 +38,7 @@ if TYPE_CHECKING:
    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
    Optimizer = torch.optim.Optimizer
    Scheduler = torch.optim.lr_scheduler.LRScheduler
    ProcessGroup = ProcessGroup
else:
    Tensor = None
    TensorLike = None
@@ -50,6 +52,7 @@ else:
    Processor = None
    Optimizer = None
    Scheduler = None
    ProcessGroup = None


class DatasetInfo(TypedDict, total=False):
@@ -69,6 +72,19 @@ class DatasetInfo(TypedDict, total=False):
    """Is streaming dataset, default to False."""


class DistributedConfig(TypedDict, total=False):
    mp_replicate_size: NotRequired[int]
    """Model parallel replicate size, default to 1."""
    mp_shard_size: NotRequired[int]
    """Model parallel shard size, default to world_size // mp_replicate_size."""
    dp_size: NotRequired[int]
    """Data parallel size, default to world_size // cp_size."""
    cp_size: NotRequired[int]
    """Context parallel size, default to 1."""
    timeout: NotRequired[int]
    """Timeout for distributed communication, default to 600."""


class Content(TypedDict):
    type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
    value: str
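
Since TypedDicts are ordinary dicts at runtime, the new DistributedConfig (and a Content entry) can be written as plain literals. The values below are illustrative only, taken from the stated defaults in the field docstrings above.

# Illustrative literals; field semantics come from the docstrings above.
distributed_config = {
    "mp_replicate_size": 1,  # model parallel replicate size, default 1
    "cp_size": 1,            # context parallel size, default 1
    "timeout": 600,          # distributed communication timeout, default 600
    # mp_shard_size and dp_size are derived from world_size when omitted
}

content = {"type": "text", "value": "Hello, world."}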