summarylogtreecommitdiffstats
path: root/mpi4jax.diff
blob: 29a62a80eb3c6d8abcb7c9227f8037b20f8d8046 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
--- a/setup.py	2025-03-22 17:45:31.333454882 +0300
+++ b/setup.py	2025-03-22 17:47:07.398306761 +0300
@@ -149,6 +149,8 @@
 
 # Taken from CUPY (MIT License)
 def get_cuda_path():
+    if os.getenv("ENABLE_CUDA") is None:
+        return None
     nvcc_path = search_on_path(("nvcc", "nvcc.exe"))
     cuda_path_default = None
     if nvcc_path is not None:
@@ -180,6 +182,8 @@
 
 
 def get_sycl_path():
+    if os.getenv("ENABLE_SYCL") is None:
+        return None
     sycl_path = os.getenv("CMPLR_ROOT", "")
     if len(sycl_path) > 0 and os.path.exists(sycl_path):
         _sycl_path = sycl_path