diff options
author | Daniel Bershatsky | 2023-07-28 04:20:26 +0300 |
---|---|---|
committer | Daniel Bershatsky | 2023-07-28 04:20:26 +0300 |
commit | e9fb3b285e2967e6369493d334d639f37bab30ed (patch) | |
tree | 941d7b820d96a13cc8cc4cdc7d418003b6ae6a79 | |
parent | 8da2e52ebc134c3b764726dfdf973fae90a63d35 (diff) | |
download | aur-e9fb3b285e2967e6369493d334d639f37bab30ed.tar.gz |
Bump version to 0.4.14
-rw-r--r-- | .SRCINFO | 10 | ||||
-rw-r--r-- | PKGBUILD | 14 | ||||
-rw-r--r-- | python-jaxlib-cuda.diff | 33 |
3 files changed, 49 insertions, 8 deletions
@@ -1,6 +1,6 @@ pkgbase = python-jaxlib-cuda pkgdesc = XLA library for JAX - pkgver = 0.4.13 + pkgver = 0.4.14 pkgrel = 1 url = https://github.com/google/jax/ arch = x86_64 @@ -19,11 +19,13 @@ pkgbase = python-jaxlib-cuda depends = python-ml-dtypes depends = python-numpy depends = python-scipy - provides = python-jaxlib=0.4.13 + provides = python-jaxlib=0.4.14 conflicts = python-jaxlib - source = jaxlib-0.4.13.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.13.tar.gz + source = jaxlib-0.4.14.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.14.tar.gz source = bazelrc.user - sha256sums = 45766238b57b992851763c64bc943858aebafe4cad7b3df6cde844690bc34293 + source = python-jaxlib-cuda.diff + sha256sums = 9f309476a8f6337717b059b8d10b5859b4134c30cf8f1220bb70379b5e2744a4 sha256sums = 07da4c3594dad382ee02748b860c629ffa083ba37ad22a892291bdc72efbac5e + sha256sums = c70c257544f4f17ee4f6668e880e96154850f592adff7a817e1db3d8ab7feb49 pkgname = python-jaxlib-cuda @@ -1,7 +1,7 @@ # Maintainer: Daniel Bershatsky <bepshatsky@yandex.ru> pkgname=python-jaxlib-cuda -pkgver=0.4.13 +pkgver=0.4.14 pkgrel=1 pkgdesc='XLA library for JAX' arch=('x86_64') @@ -19,9 +19,11 @@ makedepends=('bazel' 'gcc12' 'pybind11' 'python-installer' 'python-setuptools' ' conflicts=('python-jaxlib') provides=("python-jaxlib=$pkgver") source=("jaxlib-${pkgver}.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v${pkgver}.tar.gz" - 'bazelrc.user') -sha256sums=('45766238b57b992851763c64bc943858aebafe4cad7b3df6cde844690bc34293' - '07da4c3594dad382ee02748b860c629ffa083ba37ad22a892291bdc72efbac5e') + 'bazelrc.user' + 'python-jaxlib-cuda.diff') +sha256sums=('9f309476a8f6337717b059b8d10b5859b4134c30cf8f1220bb70379b5e2744a4' + '07da4c3594dad382ee02748b860c629ffa083ba37ad22a892291bdc72efbac5e' + 'c70c257544f4f17ee4f6668e880e96154850f592adff7a817e1db3d8ab7feb49') prepare() { # Allow any bazel version @@ -58,8 +60,12 @@ prepare() { # add latest PTX for future compatibility # Valid values can be discovered from nvcc --help export TF_CUDA_COMPUTE_CAPABILITIES=sm_70,sm_72,sm_75,sm_80,sm_86,sm_87,sm_89,sm_90,compute_90 + + cd $srcdir/jax-jaxlib-v$pkgver + patch -p1 -i ../python-jaxlib-cuda.diff } + build() { cd $srcdir/jax-jaxlib-v$pkgver bazel run --verbose_failures=true \ diff --git a/python-jaxlib-cuda.diff b/python-jaxlib-cuda.diff new file mode 100644 index 000000000000..ccfa57faf91e --- /dev/null +++ b/python-jaxlib-cuda.diff @@ -0,0 +1,33 @@ +--- a/build/build.py.bak 2023-07-28 04:10:24.273760188 +0300 ++++ b/build/build.py 2023-07-28 04:10:57.007308693 +0300 +@@ -552,7 +552,7 @@ + command += ["--editable"] + print(" ".join(command)) + shell(command) +- shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) ++ subprocess.run([bazel_path] + args.bazel_startup_options + ["shutdown"]) + + + if __name__ == "__main__": + +--- a/jaxlib/tools/build_wheel.py 2023-07-28 03:00:56.894081609 +0300 ++++ b/jaxlib/tools/build_wheel.py 2023-07-28 03:32:53.284008361 +0300 +@@ -311,10 +311,14 @@ + platform_tag_arg = f"-C=--build-option=--plat-name={platform_name}_{cpu_name}" + if os.environ.get('JAXLIB_NIGHTLY'): + edit_jaxlib_version(sources_path) +- subprocess.run( +- [sys.executable, "-m", "build", "-n", "-w", +- python_tag_arg, platform_tag_arg], +- check=True, cwd=sources_path) ++ try: ++ subprocess.run( ++ [sys.executable, "-m", "build", "-n", "-w", ++ python_tag_arg, platform_tag_arg], ++ check=True, capture_output=True, cwd=sources_path) ++ except subprocess.CalledProcessError as e: ++ raise RuntimeError(f'Command {e.cmd} returned non-zero exit status.', ++ {'stdout': e.stdout, 'stderr': stderr}) from e + for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")): + output_file = os.path.join(output_path, os.path.basename(wheel)) + sys.stderr.write(f"Output wheel: {output_file}\n\n") |