mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 20:23:37 +00:00
[deps] goodbye python 3.9 (#9677)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com> Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -16,7 +16,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
|
||||
@@ -38,7 +37,7 @@ class DtypeInterface:
|
||||
_is_fp32_available = True
|
||||
|
||||
@staticmethod
|
||||
def is_available(precision: Union[str, torch.dtype]) -> bool:
|
||||
def is_available(precision: str | torch.dtype) -> bool:
|
||||
if precision in DtypeRegistry.HALF_LIST:
|
||||
return DtypeInterface._is_fp16_available
|
||||
elif precision in DtypeRegistry.FLOAT_LIST:
|
||||
@@ -49,19 +48,19 @@ class DtypeInterface:
|
||||
raise RuntimeError(f"Unexpected precision: {precision}")
|
||||
|
||||
@staticmethod
|
||||
def is_fp16(precision: Union[str, torch.dtype]) -> bool:
|
||||
def is_fp16(precision: str | torch.dtype) -> bool:
|
||||
return precision in DtypeRegistry.HALF_LIST
|
||||
|
||||
@staticmethod
|
||||
def is_fp32(precision: Union[str, torch.dtype]) -> bool:
|
||||
def is_fp32(precision: str | torch.dtype) -> bool:
|
||||
return precision in DtypeRegistry.FLOAT_LIST
|
||||
|
||||
@staticmethod
|
||||
def is_bf16(precision: Union[str, torch.dtype]) -> bool:
|
||||
def is_bf16(precision: str | torch.dtype) -> bool:
|
||||
return precision in DtypeRegistry.BFLOAT_LIST
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
|
||||
def to_dtype(precision: str | torch.dtype) -> torch.dtype:
|
||||
if precision in DtypeRegistry.HALF_LIST:
|
||||
return torch.float16
|
||||
elif precision in DtypeRegistry.FLOAT_LIST:
|
||||
@@ -83,7 +82,7 @@ class DtypeInterface:
|
||||
raise RuntimeError(f"Unexpected precision: {precision}")
|
||||
|
||||
@contextmanager
|
||||
def set_dtype(self, precision: Union[str, torch.dtype]):
|
||||
def set_dtype(self, precision: str | torch.dtype):
|
||||
original_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(self.to_dtype(precision))
|
||||
try:
|
||||
|
||||
@@ -81,7 +81,7 @@ def _configure_library_root_logger() -> None:
|
||||
library_root_logger.propagate = False
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> "_Logger":
|
||||
def get_logger(name: str | None = None) -> "_Logger":
|
||||
"""Return a logger with the specified name. It it not supposed to be accessed externally."""
|
||||
if name is None:
|
||||
name = _get_library_name()
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
from . import logging
|
||||
|
||||
@@ -29,7 +29,7 @@ class BasePlugin:
|
||||
|
||||
_registry: dict[str, Callable] = {}
|
||||
|
||||
def __init__(self, name: Optional[str] = None):
|
||||
def __init__(self, name: str | None = None):
|
||||
"""Initialize the plugin with a name.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -12,9 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Union
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
Reference in New Issue
Block a user