summarylogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.SRCINFO2
-rw-r--r--PKGBUILD11
-rw-r--r--python-jaxlib.diff32
3 files changed, 43 insertions, 2 deletions
diff --git a/.SRCINFO b/.SRCINFO
index df7121dc7a72..d62ae02adaec 100644
--- a/.SRCINFO
+++ b/.SRCINFO
@@ -16,6 +16,8 @@ pkgbase = python-jaxlib
depends = python-scipy
conflicts = python-jaxlib
source = jaxlib-0.4.14.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.14.tar.gz
+ source = python-jaxlib.diff
sha256sums = 9f309476a8f6337717b059b8d10b5859b4134c30cf8f1220bb70379b5e2744a4
+ sha256sums = SKIP
pkgname = python-jaxlib
diff --git a/PKGBUILD b/PKGBUILD
index 654108692433..659391eccad7 100644
--- a/PKGBUILD
+++ b/PKGBUILD
@@ -15,8 +15,15 @@ depends=('python-absl'
'python-scipy')
makedepends=('python-installer' 'python-setuptools' 'python-wheel')
conflicts=('python-jaxlib')
-source=("jaxlib-${pkgver}.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v${pkgver}.tar.gz")
-sha256sums=('9f309476a8f6337717b059b8d10b5859b4134c30cf8f1220bb70379b5e2744a4')
+source=("jaxlib-${pkgver}.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v${pkgver}.tar.gz"
+ 'python-jaxlib.diff')
+sha256sums=('9f309476a8f6337717b059b8d10b5859b4134c30cf8f1220bb70379b5e2744a4'
+ 'SKIP')
+
+prepare() {
+ cd $srcdir/jax-jaxlib-v$pkgver
+ patch -p1 -i ../python-jaxlib.diff
+}
build() {
cd $srcdir/jax-jaxlib-v$pkgver
diff --git a/python-jaxlib.diff b/python-jaxlib.diff
new file mode 100644
index 000000000000..8e0b38a2dc18
--- /dev/null
+++ b/python-jaxlib.diff
@@ -0,0 +1,32 @@
+--- a/jaxlib/tools/build_wheel.py 2023-07-28 03:00:56.894081609 +0300
++++ b/jaxlib/tools/build_wheel.py 2023-07-28 03:04:32.978292616 +0300
+@@ -29,6 +29,7 @@
+ import subprocess
+ import sys
+ import tempfile
++from subprocess import PIPE, STDOUT, CalledProcessError
+
+ from bazel_tools.tools.python.runfiles import runfiles
+
+@@ -311,10 +312,17 @@
+ platform_tag_arg = f"-C=--build-option=--plat-name={platform_name}_{cpu_name}"
+ if os.environ.get('JAXLIB_NIGHTLY'):
+ edit_jaxlib_version(sources_path)
+- subprocess.run(
+- [sys.executable, "-m", "build", "-n", "-w",
+- python_tag_arg, platform_tag_arg],
+- check=True, cwd=sources_path)
++ try:
++ subprocess.run(
++ [sys.executable, "-m", "build", "-n", "-w",
++ python_tag_arg, platform_tag_arg],
++ check=True, cwd=sources_path, stdout=PIPE, stderr=STDOUT)
++ except CalledProcessError as e:
++ print('STDOUT')
++ print(e.stdout)
++ print('STDERR')
++ print(e.stderr)
++ raise
+ for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
+ output_file = os.path.join(output_path, os.path.basename(wheel))
+ sys.stderr.write(f"Output wheel: {output_file}\n\n")