Package Details: python-jaxlib-cuda 0.6.0-1

Git Clone URL: https://aur.archlinux.org/python-jaxlib-cuda.git (read-only)
Package Base: python-jaxlib-cuda
Description: XLA library for JAX
Upstream URL: https://github.com/jax-ml/jax/
Keywords: deep-learning google jax machine-learning xla
Licenses: Apache-2.0
Groups: jax
Conflicts: python-jaxlib
Provides: python-jaxlib
Submitter: daskol
Maintainer: daskol
Last Packager: daskol
Votes: 9
Popularity: 0.003181
First Submitted: 2023-02-12 23:18 (UTC)
Last Updated: 2025-04-20 20:54 (UTC)

Latest Comments

medaminezghal commented on 2025-10-03 08:09 (UTC)

@daskol @truncs Using the PKGBUILD file below and this merge request, I was able to bypass all the problems related to building JAX against a local CUDA, but I ran into problems related to architectures and options:

clang-20: error: unknown argument: '-Xcuda-fatbinary=--compress-all'
clang-20: error: unknown argument: '-nvcc_options=expt-relaxed-constexpr'
clang-20: warning: CUDA version is newer than the latest partially supported version 12.8 [-Wunknown-cuda-version]
clang-20: error: unsupported CUDA gpu architecture: sm_88

I think that if I find a way to use NVCC instead of Clang, the problem will be solved.
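One thing I plan to try first (just a sketch, not verified): trim the capability list in build() to architectures this clang actually recognises, which should at least get rid of the "unsupported CUDA gpu architecture: sm_88" error:

# A trimmed starting point -- exactly which newer sm_* values clang-20 accepts is an assumption:
--cuda_compute_capabilities=$(echo sm_{75,80,86,89,90} | tr ' ' ','),compute_90 \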

This is the PKGBUILD file:

# Maintainer: Daniel Bershatsky <bepshatsky@yandex.ru>

pkgname=python-jaxlib
pkgver=0.7.2
pkgrel=1
pkgdesc='XLA library for JAX'
arch=('x86_64')
url='https://github.com/jax-ml/jax/'
license=('Apache-2.0')
groups=('jax')
depends=(
    'python-ml-dtypes'
    'python-numpy'
    'python-scipy'
)
makedepends=('clang' 'python-build' 'python-installer' 'python-setuptools' 'python-wheel')
_bazel_ver=7.4.1
_xla_commit=0fccb8a6037019b20af2e502ba4b8f5e0f98c8f6 #https://github.com/jax-ml/jax/blob/jax-v0.7.2/third_party/xla/revision.bzl
source=("jax-${pkgver}.tar.gz::$url/archive/refs/tags/jax-v${pkgver}.tar.gz"
        "openxla-xla-${_xla_commit:0:7}.tar.gz::https://api.github.com/repos/openxla/xla/tarball/$_xla_commit"
        "bazel-${_bazel_ver}-linux-x86_64::https://github.com/bazelbuild/bazel/releases/download/${_bazel_ver}/bazel-${_bazel_ver}-linux-x86_64")
noextract=("bazel-${_bazel_ver}-linux-x86_64")
sha256sums=('56d92604f1bb60bb3dbd7dc7c7dc21502d10b3474b8b905ce29ce06db6a26e45'
            '504315851ae676bf27122f20f68980fafb2a2c37e10113f58b03f6c284c55cfd'
            'c97f02133adce63f0c28678ac1f21d65fa8255c80429b588aeeba8a1fac6202b')

prepare() {
    ln -sf $(readlink bazel-${_bazel_ver}-linux-x86_64) $srcdir/jax-jax-v${pkgver}/build
    chmod +x $srcdir/bazel-${_bazel_ver}-linux-x86_64
    cd $srcdir/openxla-xla-${_xla_commit:0:7}
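    # Repoint XLA's pinned gloo dependency: swap the commit hash and the matching sha256 in workspace.bzl.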
    sed -i 's/5354032ea08eadd7fc4456477f7f7c6308818509/54cbae0d3a67fa890b4c3d9ee162b7860315e341/g' third_party/gloo/workspace.bzl
    sed -i 's/5759a06e6c8863c58e8ceadeb56f7c701fec89b2559ba33a103a447207bf69c7/61089361dbdbc9d6f75e297148369b13f615a3e6b78de1be56cce74ca2f64940/g' third_party/gloo/workspace.bzl
}

build() {
    local CUDA_MAJOR_VERSION=$(/opt/cuda/bin/nvcc --version | sed -n 's/^.*release \([0-9]\+\).*/\1/p')
    # Override default version.
    export JAXLIB_RELEASE=$pkgver

    cd $srcdir/jax-jax-v$pkgver
    build/build.py build \
        --bazel_path="$srcdir/bazel-${_bazel_ver}-linux-x86_64" \
        --bazel_startup_options="--output_user_root=$srcdir/bazel" \
        --bazel_options='--repo_env=LOCAL_CUDA_PATH="/opt/cuda"' \
        --bazel_options='--repo_env=LOCAL_CUDNN_PATH="/opt/cuda"' \
        --bazel_options='--repo_env=LOCAL_NCCL_PATH="/usr"' \
        --bazel_options='--repo_env=LOCAL_NVSHMEM_PATH="/usr"' \
        --bazel_options='--action_env=TF_NVCC_CLANG="0"' \
        --verbose \
        --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt \
        --cuda_major_version=$CUDA_MAJOR_VERSION \
        --cuda_compute_capabilities=$(echo sm_{75,80,86,87,88,89,90,90a,100,100a,103,103a,110,110a,120,120a,121,121a} | tr ' ' ','),compute_121 \
        --clang_path=/usr/bin/clang \
        --target_cpu_features=release \
        --local_xla_path=$srcdir/openxla-xla-${_xla_commit:0:7}
}

package() {
    cd $srcdir/jax-jax-v$pkgver
    install -Dm 644 LICENSE "$pkgdir/usr/share/licenses/$pkgname/LICENSE"
    # Re-derive the CUDA major version here: the variable set in build() is local to that function.
    local CUDA_MAJOR_VERSION=$(/opt/cuda/bin/nvcc --version | sed -n 's/^.*release \([0-9]\+\).*/\1/p')
    python -m installer --compile-bytecode=1 --destdir=$pkgdir \
        $srcdir/jax-jax-v$pkgver/dist/jaxlib-$pkgver-*.whl
    python -m installer --compile-bytecode=1 --destdir=$pkgdir \
        $srcdir/jax-jax-v$pkgver/dist/jax_cuda${CUDA_MAJOR_VERSION}_pjrt-$pkgver-*.whl
    python -m installer --compile-bytecode=1 --destdir=$pkgdir \
        $srcdir/jax-jax-v$pkgver/dist/jax_cuda${CUDA_MAJOR_VERSION}_plugin-$pkgver-*.whl
}

Could you help me fix the problem?

truncs commented on 2025-09-11 00:16 (UTC)

Using jax 0.7.0 and clang-20, and including <cstdint> as @wheelsofindustry mentioned, I was able to build jaxlib-cuda, but the jaxlib-plugin build still fails. This one also looks like it is related to the toolchain:

# Configuration: a1e1d9399d833527724a5906f125d5456703d1b9bc091a7b2e8a1e4f7ddfaf4c
# Execution platform: @@local_execution_config_platform//:platform
clang-20: warning: CUDA version is newer than the latest partially supported version 12.8 [-Wunknown-cuda-version]
In file included from <built-in>:1:
In file included from /usr/lib/clang/20/include/__clang_cuda_runtime_wrapper.h:41:
In file included from /usr/lib/clang/20/include/cuda_wrappers/cmath:28:
/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/15.2.1/../../../../include/c++/15.2.1/cmath:55:15: fatal error: 'math.h' file not found
   55 | #include_next <math.h>
      |               ^~~~~~~~
1 error generated when compiling for sm_100.
Target //jaxlib/tools:build_gpu_kernels_wheel failed to build
INFO: Elapsed time: 1.658s, Critical Path: 0.91s
INFO: 37 processes: 34 internal, 3 local.
ERROR: Build did NOT complete successfully
ERROR: Build failed. Not running target
2025-09-10 17:09:22,680 - DEBUG - Command finished with return code 1
Traceback (most recent call last):

I will probably remove the arch jax and install jax using pip in a virtualenv.
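For reference, the pip route I have in mind is roughly this (the jax[cuda12] extra is the upstream-documented way to get prebuilt CUDA wheels; the venv path is just an example):

python -m venv ~/.venvs/jax
source ~/.venvs/jax/bin/activate
# Installs jax, jaxlib and the CUDA 12 plugin/PJRT wheels with bundled CUDA libraries:
pip install --upgrade "jax[cuda12]"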

truncs commented on 2025-08-27 23:36 (UTC)

This still doesn't build

Traceback (most recent call last):
  File "/home/aditya/.cache/yay/python-jaxlib-cuda/src/jax-jax-v0.6.0/build/build.py", line 778, in <module>
    asyncio.run(main())
    ~~~~~~~~~~~^^^^^^^^
  File "/usr/lib/python3.13/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ~~~~~~~~~~^^^^^^
  File "/usr/lib/python3.13/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
  File "/usr/lib/python3.13/asyncio/base_events.py", line 725, in run_until_complete
    return future.result()
           ~~~~~~~~~~~~~^^
  File "/home/aditya/.cache/yay/python-jaxlib-cuda/src/jax-jax-v0.6.0/build/build.py", line 723, in main
    raise RuntimeError(f"Command failed with return code {result.return_code}")
RuntimeError: Command failed with return code 1
==> ERROR: A failure occurred in build().
    Aborting...
 -> error making: python-jaxlib-cuda-exit status 4

What is the alternative for getting JAX working with CUDA?

wheelsofindustry commented on 2025-05-24 20:50 (UTC) (edited on 2025-05-24 20:52 (UTC) by wheelsofindustry)

@truncs Adding an include of <cstdint> to types.h got me past the same issue:

--- /python-jaxlib/src/bazel/*/gloo/gloo/types.h    2023-12-02 17:32:51.000000000 -0800
+++ /python-jaxlib-cuda/src/bazel/*/external/gloo/gloo/types.h  2025-05-24 06:28:59.597042640 -0700
@@ -5,6 +5,7 @@
 #pragma once

 #include <iostream>
+#include <cstdint>

 #ifdef __CUDA_ARCH__
 #include <cuda.h>

This works for both python-jaxlib and python-jaxlib-cuda, but after that it complains that the installed CUDA version is newer than the latest partially supported one (12.8). Rolling CUDA back to 12.2 requires gcc12; trying that now.
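In case it helps anyone scripting this, the same edit can be done with sed; a minimal sketch, assuming the gloo sources have already been fetched into the bazel tree under src (the exact path is a guess, adjust as needed):

# Add the missing <cstdint> include right after <iostream> in the fetched gloo header:
sed -i '/#include <iostream>/a #include <cstdint>' src/bazel/*/external/gloo/gloo/types.h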

daskol commented on 2025-05-07 08:27 (UTC)

@truncs It seems that the issue is gcc-libs>=15 again (other AUR packages do not build now either). Clang uses GCC's standard library, which sometimes causes odd issues like broken include ordering, missing type definitions, unexpected #warning directives that turn into errors, etc.

I'll try to build it and patch the sources if possible.
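To check which GCC headers clang actually picks up, something like this shows the selected installation (a quick diagnostic only):

# clang -v reports "Found candidate" and "Selected GCC installation" lines on stderr:
clang -v -x c++ -E /dev/null 2>&1 | grep -i 'gcc installation'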

truncs commented on 2025-05-06 17:46 (UTC)

Actually, forcing it to use gcc is something I am not able to do. I tried removing the clang options and setting repo_env to point to gcc, but it would still use clang.

Additional Bazel build options: ['--action_env=JAXLIB_RELEASE', '--action_env=TF_CUDA_COMPUTE_CAPABILITIES=sm_80,sm_86,sm_89,sm_90,compute_90', '--repo_env=HERMETIC_CUDA_VERSION=12.8.0', '--repo_env=LOCAL_NCCL_PATH=/usr', '--repo_env=CC=/usr/bin/gcc', '--repo_env=CXX=/usr/bin/g++']
2025-05-06 10:41:34,195 - INFO - Bazel options written to .jax_configure.bazelrc
2025-05-06 10:41:34,195 - DEBUG - Artifacts output directory: /home/aditya/.cache/yay/python-jaxlib-cuda/src/jax-jax-v0.6.0/dist


2025-05-06 10:41:34,195 - INFO - Building jaxlib for linux x86_64...
2025-05-06 10:41:34,195 - INFO - [EXECUTING] .cache/yay/python-jaxlib-cuda/src/bazel-7.4.1-linux-x86_64 --output_user_root=.cache/yay/python-jaxlib-cuda/src/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.13 --verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/bin/clang-19" --repo_env=CC="/usr/bin/clang-19" --repo_env=CXX="/usr/bin/clang++-19" --repo_env=BAZEL_COMPILER="/usr/bin/clang-19" --config=clang --config=mkl_open_source_only --config=avx_posix --config=cuda --config=cuda_libraries_from_stubs --action_env=CLANG_CUDA_COMPILER_PATH="/usr/bin/clang-19" --config=build_cuda_with_clang --repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --action_env=JAXLIB_RELEASE --action_env=TF_CUDA_COMPUTE_CAPABILITIES=sm_80,sm_86,sm_89,sm_90,compute_90 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 --repo_env=LOCAL_NCCL_PATH=/usr --repo_env=CC=/usr/bin/gcc --repo_env=CXX=/usr/bin/g++ --config=cuda_libraries_from_stubs //jaxlib/tools:build_wheel -- --output_path=".cache/yay/python-jaxlib-cuda/src/jax-jax-v0.6.0/dist" --cpu=x86_64 --jaxlib_git_hash=72273970ff20769d87749c73d893ea28171fd53d

daskol commented on 2025-05-06 15:12 (UTC) (edited on 2025-05-06 15:13 (UTC) by daskol)

@truncs It seems to be a compiler issue. Try to enforce the use of gcc13; that could solve it.
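Roughly what I mean, in the same option style the PKGBUILD already uses (a sketch only; the gcc13 package ships the compiler as /usr/bin/gcc-13, and whether build.py honours these over its clang defaults is exactly the open question):

# pacman -S gcc13
--bazel_options='--repo_env=CC=/usr/bin/gcc-13' \
--bazel_options='--repo_env=CXX=/usr/bin/g++-13' \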

truncs commented on 2025-05-05 20:25 (UTC)

Getting this error for 0.6.0-1

 TF_CUDA_COMPUTE_CAPABILITIES=sm_80,sm_86,sm_89,sm_90,compute_90 \
  /usr/lib/llvm18/bin/clang-18 -MD -MF bazel-out/k8-opt/bin/external/gloo/_objs/gloo/types.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/gloo/_objs/gloo/types.pic.o' -iquote external/gloo -iquote bazel-out/k8-opt/bin/external/gloo -isystem external/gloo -isystem bazel-out/k8-opt/bin/external/gloo -fmerge-all-constants -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -Wno-invalid-partial-specialization -fno-omit-frame-pointer -no-canonical-prefixes -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '--cuda-path=external/cuda_nvcc' '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' '-DNB_DOMAIN=jax' -Wno-gnu-offsetof-extensions -Qunused-arguments '-Werror=mismatched-tags' '-Wno-error=c23-extensions' -mavx -Wno-gnu-offsetof-extensions -Qunused-arguments '-Werror=mismatched-tags' '-Wno-error=c23-extensions' -mavx '-std=c++17' -fexceptions -Wno-unused-variable -c external/gloo/gloo/types.cc -o bazel-out/k8-opt/bin/external/gloo/_objs/gloo/types.pic.o)
# Configuration: 6f9d4bb27a6c9bc488a0cd126309e9bf2df92d2e54fc24d1a8d5d26084fd8413
# Execution platform: @@local_execution_config_platform//:platform
In file included from external/gloo/gloo/types.cc:9:
external/gloo/gloo/types.h:66:11: error: unknown type name 'uint8_t'
   66 | constexpr uint8_t kGatherSlotPrefix = 0x01;
      |           ^
external/gloo/gloo/types.h:67:11: error: unknown type name 'uint8_t'
   67 | constexpr uint8_t kAllgatherSlotPrefix = 0x02;
      |           ^
external/gloo/gloo/types.h:68:11: error: unknown type name 'uint8_t'
   68 | constexpr uint8_t kReduceSlotPrefix = 0x03;
      |           ^
external/gloo/gloo/types.h:69:11: error: unknown type name 'uint8_t'
   69 | constexpr uint8_t kAllreduceSlotPrefix = 0x04;
      |           ^
external/gloo/gloo/types.h:70:11: error: unknown type name 'uint8_t'
   70 | constexpr uint8_t kScatterSlotPrefix = 0x05;
      |           ^
external/gloo/gloo/types.h:71:11: error: unknown type name 'uint8_t'
   71 | constexpr uint8_t kBroadcastSlotPrefix = 0x06;
      |           ^
external/gloo/gloo/types.h:72:11: error: unknown type name 'uint8_t'
   72 | constexpr uint8_t kBarrierSlotPrefix = 0x07;
      |           ^
external/gloo/gloo/types.h:73:11: error: unknown type name 'uint8_t'
   73 | constexpr uint8_t kAlltoallSlotPrefix = 0x08;
      |           ^
external/gloo/gloo/types.h:77:21: error: unknown type name 'uint8_t'
   77 |   static Slot build(uint8_t prefix, uint32_t tag);
      |                     ^
external/gloo/gloo/types.h:77:37: error: unknown type name 'uint32_t'
   77 |   static Slot build(uint8_t prefix, uint32_t tag);
      |                                     ^
external/gloo/gloo/types.h:79:12: error: unknown type name 'uint64_t'
   79 |   operator uint64_t() const {
      |            ^
external/gloo/gloo/types.h:86:17: error: unknown type name 'uint64_t'
   86 |   explicit Slot(uint64_t base, uint64_t delta) : base_(base), delta_(delta) {}
      |                 ^
external/gloo/gloo/types.h:86:32: error: unknown type name 'uint64_t'
   86 |   explicit Slot(uint64_t base, uint64_t delta) : base_(base), delta_(delta) {}
      |                                ^
external/gloo/gloo/types.h:88:9: error: unknown type name 'uint64_t'
   88 |   const uint64_t base_;
      |         ^
external/gloo/gloo/types.h:89:9: error: unknown type name 'uint64_t'
   89 |   const uint64_t delta_;
      |         ^
external/gloo/gloo/types.h:97:3: error: unknown type name 'uint16_t'
   97 |   uint16_t x;
      |   ^
external/gloo/gloo/types.cc:16:18: error: unknown type name 'uint8_t'
   16 | Slot Slot::build(uint8_t prefix, uint32_t tag) {
      |                  ^
external/gloo/gloo/types.cc:16:34: error: unknown type name 'uint32_t'
   16 | Slot Slot::build(uint8_t prefix, uint32_t tag) {
      |                                  ^
external/gloo/gloo/types.cc:17:3: error: unknown type name 'uint64_t'
   17 |   uint64_t u64prefix = ((uint64_t)prefix) << 56;
      |   ^
fatal error: too many errors emitted, stopping now [-ferror-limit=]
20 errors generated.

I did a quick search upstream and can't find anything related to this. Thoughts?

medaminezghal commented on 2025-04-22 19:02 (UTC)

@daskol The build fails because you use --bazel_options="--action_env=TF_CUDA_COMPUTE_CAPABILITIES=${CUDA_COMPUTE_CAPABILITIES}".

It will not succeed because clang can't compile for sm_100 and sm_120, which are listed in .bazelrc.

Instead, use --bazel_options="--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES=${CUDA_COMPUTE_CAPABILITIES}" (see the sketch after the suggestions below).

And I have some suggestions:

1- Use the default Clang provided by the system (the previous problem was fixed).

2- Add support for more cards by using: export CUDA_COMPUTE_CAPABILITIES=sm_50,sm_52,sm_53,sm_60,sm_61,sm_62,sm_70,sm_72,sm_75,sm_80,sm_86,sm_87,sm_89,sm_90,sm_90a,compute_90
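Put together, the change in build() would look roughly like this (same flag style as the current PKGBUILD):

export CUDA_COMPUTE_CAPABILITIES=sm_50,sm_52,sm_53,sm_60,sm_61,sm_62,sm_70,sm_72,sm_75,sm_80,sm_86,sm_87,sm_89,sm_90,sm_90a,compute_90
# The hermetic CUDA rules read the list from this repo_env rather than TF_CUDA_COMPUTE_CAPABILITIES:
--bazel_options="--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES=${CUDA_COMPUTE_CAPABILITIES}" \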

daskol commented on 2025-03-20 22:39 (UTC)

@medaminezghal Thanks for your PKGBUILD. My build server was under maintenance for a while.

However, there is a build issue for 0.5.3 related to the transition to Bazel 7. It turns out that the Python rules do not handle symlink creation properly, which is why python -m build fails with a missing Lorem ipsum.txt file.

I'm looking for a solution. Manually downgrading to Bazel 6 will probably solve the issue.
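If downgrading turns out to be the way to go, the PKGBUILD change itself is small (a sketch, assuming the PKGBUILD pins Bazel through a _bazel_ver variable as in the one posted above; the sha256 for the Bazel 6 binary would of course need updating):

# Pin a Bazel 6.x release instead of 7.x; the source array already derives the download URL from this:
_bazel_ver=6.5.0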