diff --git a/pyproject.toml b/pyproject.toml index 8d2c8f0..26625fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,3 +43,20 @@ testpaths = ["tests"] python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] + +# target torch to cuda 12.8 +[tool.uv.sources] +torch = [ + { index = "pytorch-cpu", marker = "sys_platform != 'linux'" }, + { index = "pytorch-cu128", marker = "sys_platform == 'linux'" }, +] + +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true \ No newline at end of file