summarylogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniel Bershatsky2023-07-28 04:20:26 +0300
committerDaniel Bershatsky2023-07-28 04:20:26 +0300
commite9fb3b285e2967e6369493d334d639f37bab30ed (patch)
tree941d7b820d96a13cc8cc4cdc7d418003b6ae6a79
parent8da2e52ebc134c3b764726dfdf973fae90a63d35 (diff)
downloadaur-e9fb3b285e2967e6369493d334d639f37bab30ed.tar.gz
Bump version to 0.4.14
-rw-r--r--.SRCINFO10
-rw-r--r--PKGBUILD14
-rw-r--r--python-jaxlib-cuda.diff33
3 files changed, 49 insertions, 8 deletions
diff --git a/.SRCINFO b/.SRCINFO
index bd3f08b5cd9a..d291869c2452 100644
--- a/.SRCINFO
+++ b/.SRCINFO
@@ -1,6 +1,6 @@
pkgbase = python-jaxlib-cuda
pkgdesc = XLA library for JAX
- pkgver = 0.4.13
+ pkgver = 0.4.14
pkgrel = 1
url = https://github.com/google/jax/
arch = x86_64
@@ -19,11 +19,13 @@ pkgbase = python-jaxlib-cuda
depends = python-ml-dtypes
depends = python-numpy
depends = python-scipy
- provides = python-jaxlib=0.4.13
+ provides = python-jaxlib=0.4.14
conflicts = python-jaxlib
- source = jaxlib-0.4.13.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.13.tar.gz
+ source = jaxlib-0.4.14.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.14.tar.gz
source = bazelrc.user
- sha256sums = 45766238b57b992851763c64bc943858aebafe4cad7b3df6cde844690bc34293
+ source = python-jaxlib-cuda.diff
+ sha256sums = 9f309476a8f6337717b059b8d10b5859b4134c30cf8f1220bb70379b5e2744a4
sha256sums = 07da4c3594dad382ee02748b860c629ffa083ba37ad22a892291bdc72efbac5e
+ sha256sums = c70c257544f4f17ee4f6668e880e96154850f592adff7a817e1db3d8ab7feb49
pkgname = python-jaxlib-cuda
diff --git a/PKGBUILD b/PKGBUILD
index 0e5ab72f286c..d8eab65e1c7d 100644
--- a/PKGBUILD
+++ b/PKGBUILD
@@ -1,7 +1,7 @@
# Maintainer: Daniel Bershatsky <bepshatsky@yandex.ru>
pkgname=python-jaxlib-cuda
-pkgver=0.4.13
+pkgver=0.4.14
pkgrel=1
pkgdesc='XLA library for JAX'
arch=('x86_64')
@@ -19,9 +19,11 @@ makedepends=('bazel' 'gcc12' 'pybind11' 'python-installer' 'python-setuptools' '
conflicts=('python-jaxlib')
provides=("python-jaxlib=$pkgver")
source=("jaxlib-${pkgver}.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v${pkgver}.tar.gz"
- 'bazelrc.user')
-sha256sums=('45766238b57b992851763c64bc943858aebafe4cad7b3df6cde844690bc34293'
- '07da4c3594dad382ee02748b860c629ffa083ba37ad22a892291bdc72efbac5e')
+ 'bazelrc.user'
+ 'python-jaxlib-cuda.diff')
+sha256sums=('9f309476a8f6337717b059b8d10b5859b4134c30cf8f1220bb70379b5e2744a4'
+ '07da4c3594dad382ee02748b860c629ffa083ba37ad22a892291bdc72efbac5e'
+ 'c70c257544f4f17ee4f6668e880e96154850f592adff7a817e1db3d8ab7feb49')
prepare() {
# Allow any bazel version
@@ -58,8 +60,12 @@ prepare() {
# add latest PTX for future compatibility
# Valid values can be discovered from nvcc --help
export TF_CUDA_COMPUTE_CAPABILITIES=sm_70,sm_72,sm_75,sm_80,sm_86,sm_87,sm_89,sm_90,compute_90
+
+ cd $srcdir/jax-jaxlib-v$pkgver
+ patch -p1 -i ../python-jaxlib-cuda.diff
}
+
build() {
cd $srcdir/jax-jaxlib-v$pkgver
bazel run --verbose_failures=true \
diff --git a/python-jaxlib-cuda.diff b/python-jaxlib-cuda.diff
new file mode 100644
index 000000000000..ccfa57faf91e
--- /dev/null
+++ b/python-jaxlib-cuda.diff
@@ -0,0 +1,33 @@
+--- a/build/build.py.bak 2023-07-28 04:10:24.273760188 +0300
++++ b/build/build.py 2023-07-28 04:10:57.007308693 +0300
+@@ -552,7 +552,7 @@
+ command += ["--editable"]
+ print(" ".join(command))
+ shell(command)
+- shell([bazel_path] + args.bazel_startup_options + ["shutdown"])
++ subprocess.run([bazel_path] + args.bazel_startup_options + ["shutdown"])
+
+
+ if __name__ == "__main__":
+
+--- a/jaxlib/tools/build_wheel.py 2023-07-28 03:00:56.894081609 +0300
++++ b/jaxlib/tools/build_wheel.py 2023-07-28 03:32:53.284008361 +0300
+@@ -311,10 +311,14 @@
+ platform_tag_arg = f"-C=--build-option=--plat-name={platform_name}_{cpu_name}"
+ if os.environ.get('JAXLIB_NIGHTLY'):
+ edit_jaxlib_version(sources_path)
+- subprocess.run(
+- [sys.executable, "-m", "build", "-n", "-w",
+- python_tag_arg, platform_tag_arg],
+- check=True, cwd=sources_path)
++ try:
++ subprocess.run(
++ [sys.executable, "-m", "build", "-n", "-w",
++ python_tag_arg, platform_tag_arg],
++ check=True, capture_output=True, cwd=sources_path)
++ except subprocess.CalledProcessError as e:
++ raise RuntimeError(f'Command {e.cmd} returned non-zero exit status.',
++ {'stdout': e.stdout, 'stderr': stderr}) from e
+ for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
+ output_file = os.path.join(output_path, os.path.basename(wheel))
+ sys.stderr.write(f"Output wheel: {output_file}\n\n")