Merge pull request #99 from burtenshaw/cpu-mps-dev-ben
Add mps and cpu dependency management
This commit is contained in:
@@ -19,6 +19,15 @@ dependencies = [
|
||||
"wandb>=0.21.3",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
# Optional groups to control PyTorch source selection
|
||||
cpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
cuda = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["maturin>=1.7,<2.0"]
|
||||
build-backend = "maturin"
|
||||
@@ -34,6 +43,31 @@ dev = [
|
||||
"maturin>=1.9.4",
|
||||
"pytest>=8.0.0",
|
||||
]
|
||||
cuda = [
|
||||
"cuda", # refers to the above optional dependency group
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
default-groups = ["cuda"]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "platform_system == 'Darwin'"},
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-cu128", extra = "cuda"},
|
||||
]
|
||||
|
||||
# CPU-only index
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
# CUDA 12.8 index
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
|
||||
Reference in New Issue
Block a user