[breaking] bump transformers to 4.45.0 & improve ci (#7746)
* update ci * fix * fix * fix * fix * fix
This commit is contained in:
@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("transformers>=4.43.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("datasets>=2.16.0,<=3.5.0")
|
||||
check_version("accelerate>=0.34.0,<=1.6.0")
|
||||
check_version("peft>=0.14.0,<=0.15.1")
|
||||
@@ -141,13 +141,13 @@ def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
|
||||
def get_current_device() -> "torch.device":
|
||||
r"""Get the current available device."""
|
||||
if is_torch_xpu_available():
|
||||
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
elif is_torch_npu_available():
|
||||
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
elif is_torch_mps_available():
|
||||
device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
elif is_torch_cuda_available():
|
||||
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
@@ -155,11 +155,13 @@ def get_current_device() -> "torch.device":
|
||||
|
||||
|
||||
def get_device_count() -> int:
|
||||
r"""Get the number of available GPU or NPU devices."""
|
||||
r"""Get the number of available devices."""
|
||||
if is_torch_xpu_available():
|
||||
return torch.xpu.device_count()
|
||||
elif is_torch_npu_available():
|
||||
return torch.npu.device_count()
|
||||
elif is_torch_mps_available():
|
||||
return torch.mps.device_count()
|
||||
elif is_torch_cuda_available():
|
||||
return torch.cuda.device_count()
|
||||
else:
|
||||
@@ -175,10 +177,12 @@ def get_logits_processor() -> "LogitsProcessorList":
|
||||
|
||||
def get_peak_memory() -> tuple[int, int]:
|
||||
r"""Get the peak memory usage for the current device (in Bytes)."""
|
||||
if is_torch_npu_available():
|
||||
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
|
||||
elif is_torch_xpu_available():
|
||||
if is_torch_xpu_available():
|
||||
return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
|
||||
elif is_torch_npu_available():
|
||||
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
|
||||
elif is_torch_mps_available():
|
||||
return torch.mps.current_allocated_memory(), -1
|
||||
elif is_torch_cuda_available():
|
||||
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
|
||||
else:
|
||||
@@ -200,9 +204,11 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
|
||||
return torch.float32
|
||||
|
||||
|
||||
def is_gpu_or_npu_available() -> bool:
|
||||
r"""Check if the GPU or NPU is available."""
|
||||
return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
|
||||
def is_accelerator_available() -> bool:
|
||||
r"""Check if the accelerator is available."""
|
||||
return (
|
||||
is_torch_xpu_available() or is_torch_npu_available() or is_torch_mps_available() or is_torch_cuda_available()
|
||||
)
|
||||
|
||||
|
||||
def is_env_enabled(env_var: str, default: str = "0") -> bool:
|
||||
@@ -229,7 +235,7 @@ def skip_check_imports() -> None:
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""Collect GPU or NPU memory."""
|
||||
r"""Collect the device memory."""
|
||||
gc.collect()
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
@@ -280,7 +286,7 @@ def use_ray() -> bool:
|
||||
|
||||
|
||||
def find_available_port() -> int:
|
||||
"""Find an available port on the local machine."""
|
||||
r"""Find an available port on the local machine."""
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.bind(("", 0))
|
||||
port = sock.getsockname()[1]
|
||||
@@ -288,8 +294,8 @@ def find_available_port() -> int:
|
||||
return port
|
||||
|
||||
|
||||
def fix_proxy(ipv6_enabled: bool) -> None:
|
||||
"""Fix proxy settings for gradio ui."""
|
||||
def fix_proxy(ipv6_enabled: bool = False) -> None:
|
||||
r"""Fix proxy settings for gradio ui."""
|
||||
os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
|
||||
if ipv6_enabled:
|
||||
for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
|
||||
|
||||
Reference in New Issue
Block a user