Merge pull request #99 from burtenshaw/cpu-mps-dev-ben
Add mps and cpu dependency management
This commit is contained in:
@@ -19,6 +19,15 @@ dependencies = [
|
|||||||
"wandb>=0.21.3",
|
"wandb>=0.21.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
# Optional groups to control PyTorch source selection
|
||||||
|
cpu = [
|
||||||
|
"torch>=2.8.0",
|
||||||
|
]
|
||||||
|
cuda = [
|
||||||
|
"torch>=2.8.0",
|
||||||
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["maturin>=1.7,<2.0"]
|
requires = ["maturin>=1.7,<2.0"]
|
||||||
build-backend = "maturin"
|
build-backend = "maturin"
|
||||||
@@ -34,6 +43,31 @@ dev = [
|
|||||||
"maturin>=1.9.4",
|
"maturin>=1.9.4",
|
||||||
"pytest>=8.0.0",
|
"pytest>=8.0.0",
|
||||||
]
|
]
|
||||||
|
cuda = [
|
||||||
|
"cuda", # refers to the above optional dependency group
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
default-groups = ["cuda"]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
torch = [
|
||||||
|
{ index = "pytorch-cpu", marker = "platform_system == 'Darwin'"},
|
||||||
|
{ index = "pytorch-cpu", extra = "cpu" },
|
||||||
|
{ index = "pytorch-cu128", extra = "cuda"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# CPU-only index
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cpu"
|
||||||
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
|
# CUDA 12.8 index
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cu128"
|
||||||
|
url = "https://download.pytorch.org/whl/cu128"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
markers = [
|
markers = [
|
||||||
|
|||||||
Reference in New Issue
Block a user