1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
|
# Maintainer: Daniel Bershatsky <bepshatsky@yandex.ru>
pkgname=python-jaxlib-cuda
pkgver=0.4.28
pkgrel=1
pkgdesc='XLA library for JAX'
arch=('x86_64')
url='https://github.com/google/jax/'
license=('Apache')
groups=('jax')
depends=('cuda' 'cudnn' 'nccl' 'openssl' 'python-absl' 'python-flatbuffers'
'python-ml-dtypes' 'python-numpy' 'python-scipy')
makedepends=('gcc12' 'pybind11' 'python-build' 'python-installer'
'python-setuptools' 'python-wheel')
conflicts=('python-jaxlib')
provides=("python-jaxlib=$pkgver")
xla_commit=12eee889e1f2ad41e27d7b0e970cb92d282d3ec5
source=("jaxlib-${pkgver}.tar.gz::https://github.com/google/jax/archive/refs/tags/jaxlib-v${pkgver}.tar.gz"
'bazelrc.user'
'https://github.com/bazelbuild/bazel/releases/download/6.1.2/bazel-6.1.2-linux-x86_64')
sha256sums=('4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b'
'07da4c3594dad382ee02748b860c629ffa083ba37ad22a892291bdc72efbac5e'
'e89747d63443e225b140d7d37ded952dacea73aaed896bca01ccd745827c6289')
prepare() {
# Allow any bazel version
echo "*" > jax-jaxlib-v${pkgver}/.bazelversion
# Add specific bazel's options.
cp bazelrc.user jax-jaxlib-v${pkgver}/.bazelrc.user
# TODO(@daskol): Prepare spcific bazel version ad hoc.
chmod +x bazel-6.1.2-linux-x86_64
}
build() {
# These environment variables influence the behavior of the configure call below.
export PYTHON_BIN_PATH=/usr/bin/python
export USE_DEFAULT_PYTHON_LIB_PATH=1
export TF_NEED_JEMALLOC=1
# See https://github.com/openxla/xla/blob/main/configure.py
# There is some vague reason why we should exclude pybind11. It exists in
# two version in source tree: abseil one and bazel one.
export TF_SYSTEM_LIBS="curl,cython,gif,icu,libjpeg_turbo,lmdb,nasm,png,zlib"
export TF_SET_ANDROID_WORKSPACE=0
export TF_DOWNLOAD_CLANG=0
export TF_NCCL_VERSION=$(pkg-config nccl --modversion | grep -Po '\d+\.\d+')
export TF_IGNORE_MAX_BAZEL_VERSION=1
export NCCL_INSTALL_PATH=/usr
# Does tensorflow really need the compiler overridden in 5 places? Yes.
export CC=gcc
export CXX=g++
# For some reason, simlink is not resolved.
export GCC_HOST_COMPILER_PATH=/usr/bin/gcc
export HOST_C_COMPILER=/usr/bin/${CC}
export HOST_CXX_COMPILER=/usr/bin/${CXX}
export TF_CUDA_CLANG=0 # Clang currently disabled because it's not compatible at the moment.
export CLANG_CUDA_COMPILER_PATH=/usr/bin/clang
export TF_CUDA_PATHS=/opt/cuda,/usr/lib,/usr
export TF_CUDA_VERSION=$(/opt/cuda/bin/nvcc --version | sed -n 's/^.*release \(.*\),.*/\1/p')
export TF_CUDNN_VERSION=$(sed -n 's/^#define CUDNN_MAJOR\s*\(.*\).*/\1/p' /usr/include/cudnn_version.h)
# https://github.com/tensorflow/tensorflow/blob/1ba2eb7b313c0c5001ee1683a3ec4fbae01105fd/third_party/gpus/cuda_configure.bzl#L411-L446
# according to the above, we should be specifying CUDA compute capabilities as 'sm_XX' or 'compute_XX' from now on
# add latest PTX for future compatibility
# Valid values can be discovered from nvcc --help
export TF_CUDA_COMPUTE_CAPABILITIES=sm_80,sm_86,sm_87,sm_89,sm_90,compute_90
# Override default version.
export JAXLIB_RELEASE=$pkgver
cd $srcdir/jax-jaxlib-v$pkgver
../bazel-6.1.2-linux-x86_64 --output_user_root=$srcdir/bazel \
run --action_env=JAXLIB_RELEASE --verbose_failures=true \
//jaxlib/tools:build_wheel -- \
--cpu x86_64 --output_path=$PWD/dist --jaxlib_git_hash ''
}
package() {
cd $srcdir/jax-jaxlib-v$pkgver
python -m installer \
--compile-bytecode 1 \
--destdir $pkgdir \
$srcdir/jax-jaxlib-v$pkgver/dist/jaxlib-$pkgver-*.whl
}
|