summarylogtreecommitdiffstats
path: root/python-jaxlib-cuda.diff
blob: ccfa57faf91ed47cb5f498a01bc41174928f35dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
--- a/build/build.py.bak	2023-07-28 04:10:24.273760188 +0300
+++ b/build/build.py	2023-07-28 04:10:57.007308693 +0300
@@ -552,7 +552,7 @@
     command += ["--editable"]
   print(" ".join(command))
   shell(command)
-  shell([bazel_path] + args.bazel_startup_options + ["shutdown"])
+  subprocess.run([bazel_path] + args.bazel_startup_options + ["shutdown"])
 
 
 if __name__ == "__main__":

--- a/jaxlib/tools/build_wheel.py	2023-07-28 03:00:56.894081609 +0300
+++ b/jaxlib/tools/build_wheel.py	2023-07-28 03:32:53.284008361 +0300
@@ -311,10 +311,14 @@
   platform_tag_arg = f"-C=--build-option=--plat-name={platform_name}_{cpu_name}"
   if os.environ.get('JAXLIB_NIGHTLY'):
     edit_jaxlib_version(sources_path)
-  subprocess.run(
-    [sys.executable, "-m", "build", "-n", "-w",
-     python_tag_arg, platform_tag_arg],
-    check=True, cwd=sources_path)
+  try:
+    subprocess.run(
+      [sys.executable, "-m", "build", "-n", "-w",
+       python_tag_arg, platform_tag_arg],
+      check=True, capture_output=True, cwd=sources_path)
+  except subprocess.CalledProcessError as e:
+    raise RuntimeError(f'Command {e.cmd} returned non-zero exit status.',
+                       {'stdout': e.stdout, 'stderr': stderr}) from e
   for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
     output_file = os.path.join(output_path, os.path.basename(wheel))
     sys.stderr.write(f"Output wheel: {output_file}\n\n")