summarylogtreecommitdiffstats
path: root/mpi4jax.diff
diff options
context:
space:
mode:
Diffstat (limited to 'mpi4jax.diff')
-rw-r--r--mpi4jax.diff20
1 files changed, 20 insertions, 0 deletions
diff --git a/mpi4jax.diff b/mpi4jax.diff
new file mode 100644
index 000000000000..29a62a80eb3c
--- /dev/null
+++ b/mpi4jax.diff
@@ -0,0 +1,20 @@
+--- a/setup.py 2025-03-22 17:45:31.333454882 +0300
++++ b/setup.py 2025-03-22 17:47:07.398306761 +0300
+@@ -149,6 +149,8 @@
+
+ # Taken from CUPY (MIT License)
+ def get_cuda_path():
++ if os.getenv("ENABLE_CUDA") is None:
++ return None
+ nvcc_path = search_on_path(("nvcc", "nvcc.exe"))
+ cuda_path_default = None
+ if nvcc_path is not None:
+@@ -180,6 +182,8 @@
+
+
+ def get_sycl_path():
++ if os.getenv("ENABLE_SYCL") is None:
++ return None
+ sycl_path = os.getenv("CMPLR_ROOT", "")
+ if len(sycl_path) > 0 and os.path.exists(sycl_path):
+ _sycl_path = sycl_path