summarylogtreecommitdiffstats
path: root/python-jaxlib.diff
blob: 8e0b38a2dc1833be4031475fb29ab6e682d158b2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
--- a/jaxlib/tools/build_wheel.py	2023-07-28 03:00:56.894081609 +0300
+++ b/jaxlib/tools/build_wheel.py	2023-07-28 03:04:32.978292616 +0300
@@ -29,6 +29,7 @@
 import subprocess
 import sys
 import tempfile
+from subprocess import PIPE, STDOUT, CalledProcessError
 
 from bazel_tools.tools.python.runfiles import runfiles
 
@@ -311,10 +312,17 @@
   platform_tag_arg = f"-C=--build-option=--plat-name={platform_name}_{cpu_name}"
   if os.environ.get('JAXLIB_NIGHTLY'):
     edit_jaxlib_version(sources_path)
-  subprocess.run(
-    [sys.executable, "-m", "build", "-n", "-w",
-     python_tag_arg, platform_tag_arg],
-    check=True, cwd=sources_path)
+  try:
+    subprocess.run(
+      [sys.executable, "-m", "build", "-n", "-w",
+       python_tag_arg, platform_tag_arg],
+      check=True, cwd=sources_path, stdout=PIPE, stderr=STDOUT)
+  except CalledProcessError as e:
+    print('STDOUT')
+    print(e.stdout)
+    print('STDERR')
+    print(e.stderr)
+    raise
   for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
     output_file = os.path.join(output_path, os.path.basename(wheel))
     sys.stderr.write(f"Output wheel: {output_file}\n\n")