[breaking] bump transformers to 4.45.0 & improve ci (#7746)

* update ci

* fix

* fix

* fix

* fix

* fix
This commit is contained in:
hoshi-hiyouga
2025-04-17 02:36:48 +08:00
committed by GitHub
parent d222f63cb7
commit 86ebb219d6
23 changed files with 211 additions and 140 deletions

View File

@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("transformers>=4.43.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.5.0")
check_version("accelerate>=0.34.0,<=1.6.0")
check_version("peft>=0.14.0,<=0.15.1")
@@ -141,13 +141,13 @@ def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
def get_current_device() -> "torch.device":
r"""Get the current available device."""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_npu_available():
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_mps_available():
device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_cuda_available():
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
else:
device = "cpu"
@@ -155,11 +155,13 @@ def get_current_device() -> "torch.device":
def get_device_count() -> int:
r"""Get the number of available GPU or NPU devices."""
r"""Get the number of available devices."""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
return torch.npu.device_count()
elif is_torch_mps_available():
return torch.mps.device_count()
elif is_torch_cuda_available():
return torch.cuda.device_count()
else:
@@ -175,10 +177,12 @@ def get_logits_processor() -> "LogitsProcessorList":
def get_peak_memory() -> tuple[int, int]:
r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_xpu_available():
if is_torch_xpu_available():
return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
elif is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_mps_available():
return torch.mps.current_allocated_memory(), -1
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
else:
@@ -200,9 +204,11 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
return torch.float32
def is_gpu_or_npu_available() -> bool:
r"""Check if the GPU or NPU is available."""
return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
def is_accelerator_available() -> bool:
r"""Check if the accelerator is available."""
return (
is_torch_xpu_available() or is_torch_npu_available() or is_torch_mps_available() or is_torch_cuda_available()
)
def is_env_enabled(env_var: str, default: str = "0") -> bool:
@@ -229,7 +235,7 @@ def skip_check_imports() -> None:
def torch_gc() -> None:
r"""Collect GPU or NPU memory."""
r"""Collect the device memory."""
gc.collect()
if is_torch_xpu_available():
torch.xpu.empty_cache()
@@ -280,7 +286,7 @@ def use_ray() -> bool:
def find_available_port() -> int:
"""Find an available port on the local machine."""
r"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
@@ -288,8 +294,8 @@ def find_available_port() -> int:
return port
def fix_proxy(ipv6_enabled: bool) -> None:
"""Fix proxy settings for gradio ui."""
def fix_proxy(ipv6_enabled: bool = False) -> None:
r"""Fix proxy settings for gradio ui."""
os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
if ipv6_enabled:
for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):