summarylogtreecommitdiffstats
path: root/PKGBUILD
blob: 5241be7d94ba07b0494c51cba83128bd652b9a6f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Maintainer: Daniel Bershatsky <bepshatsky@yandex.ru>

# Package identity and release metadata.
pkgname=python-jaxlib
pkgver=0.9.0
pkgrel=4
pkgdesc="XLA library for JAX"
url="https://github.com/jax-ml/jax"
license=("Apache-2.0")
arch=("x86_64")
groups=("jax")

# Packages required at run time.
depends=("python-ml-dtypes" "python-numpy" "python-scipy")

# Packages required only while building.
makedepends=("clang20" "python-build" "python-installer" "python-setuptools"
             "python-wheel" "xxd")

# Pinned inputs: the Bazel release fetched as a prebuilt binary and the XLA
# commit whose source archive is fetched next to the JAX sources.
_bazel_ver=7.7.0
_xla_sha="bb760b047bdbfeff962f0366ad5cc782c98657e0"

# Remote archives come first, then the prebuilt Bazel binary, then the two
# files shipped alongside this PKGBUILD.
source=(
    "jax-${pkgver}.tar.gz::${url}/archive/refs/tags/jax-v${pkgver}.tar.gz"
    "jax-xla-${pkgver}.tar.gz::https://github.com/openxla/xla/archive/${_xla_sha}.tar.gz"
    "bazel-${_bazel_ver}-linux-x86_64::https://github.com/bazelbuild/bazel/releases/download/${_bazel_ver}/bazel-${_bazel_ver}-linux-x86_64"
    "bazelrc.sh"
    "xla.diff"
)

# The Bazel download is a bare executable, so it is excluded from extraction.
noextract=("bazel-${_bazel_ver}-linux-x86_64")

# One entry per element of source=(), in the same order. The two local files
# are not checksummed (SKIP).
sha256sums=(
    "8525c72ac7ea01851297df5b25ca4622c65299c265c87dfe78420bb29e7b1bb3"
    "665ec74f3ca69905ac21dbf075ccaeed97755d8e125e068a4297b5c6c35d3a5c"
    "fe7e799cbc9140f986b063e06800a3d4c790525075c877d00a7112669824acbf"
    "SKIP"
    "SKIP"
)

# Prepare the source tree for the build.
#
# Globals read: srcdir, _bazel_ver, _xla_sha.
# Relative paths are resolved against the current directory, which makepkg
# sets to $srcdir before calling this function.
prepare() {
    local bazel_file="bazel-${_bazel_ver}-linux-x86_64"

    # Scratch directories next to the sources.
    # NOTE(review): presumably these are the locations referenced by the
    # generated bazelrc -- confirm against bazelrc.sh.
    mkdir -p {base,cache,dist,repo,sandbox}

    # The versioned Bazel entry in the current directory is a symlink; point
    # $srcdir/bazel at the same target and make the binary executable.
    # Every path is quoted so that a $srcdir containing whitespace works.
    ln -sf "$(readlink "${bazel_file}")" "${srcdir}/bazel"
    chmod +x "${srcdir}/${bazel_file}"

    # Render the bazelrc template with an environment that contains only
    # srcdir, so no other variable of the build environment is substituted.
    env -i srcdir="${srcdir}" envsubst < bazelrc.sh > bazelrc

    # Apply the local patch on top of the pinned XLA checkout.
    cd "${srcdir}/xla-${_xla_sha}"
    patch -p 1 -i ../xla.diff
}

# Build the jaxlib wheel with the downloaded Bazel binary.
#
# Globals read: pkgver, srcdir, _bazel_ver, _xla_sha, CFLAGS, CXXFLAGS.
# Globals exported: JAXLIB_RELEASE, CFLAGS, CXXFLAGS.
build() {
    local fortify_l3='-Wp,-D_FORTIFY_SOURCE=3'
    local fortify_l2='-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=2'
    local bazel="../bazel-${_bazel_ver}-linux-x86_64"
    local -a bazel_args=(
        --bazelrc=../bazelrc
        build
        # Use the XLA tree unpacked under $srcdir for the `xla` repository.
        "--override_repository=xla=${srcdir}/xla-${_xla_sha}"
        //jaxlib/tools:jaxlib_wheel
    )

    cd "jax-jax-v${pkgver}"

    # Stamp the build with the package version instead of the default one.
    export JAXLIB_RELEASE="${pkgver}"

    # Arch Linux has used source fortification level 3 since 2023, which is
    # too restrictive for building complex projects like this one. Drop the
    # level-3 flag and fall back to level 2.
    #
    # [rfc0017]: https://rfc.archlinux.page/0017-increase-fortification-level/
    CFLAGS="${CFLAGS/"$fortify_l3"/} ${fortify_l2}"
    CXXFLAGS="${CXXFLAGS/"$fortify_l3"/} ${fortify_l2}"
    export CFLAGS CXXFLAGS

    "${bazel}" "${bazel_args[@]}"
}

# Install the license file and the built jaxlib wheel into $pkgdir.
#
# Globals read: pkgver, pkgname, pkgdir, _bazel_ver.
# Returns: non-zero if Bazel cannot report its output directory; otherwise
#          the exit status of the wheel installer.
package() {
    local bazel_bin wheel_abi wheel

    cd "jax-jax-v${pkgver}"
    install -Dm644 LICENSE "${pkgdir}/usr/share/licenses/${pkgname}/LICENSE"

    # Ask Bazel for its output directory. stderr is discarded so only the
    # path is captured; because that also hides the error message, stop here
    # on failure instead of continuing with an empty path prefix.
    bazel_bin=$("../bazel-${_bazel_ver}-linux-x86_64" --bazelrc=../bazelrc \
                info bazel-bin 2> /dev/null) || return

    # NOTE(review): the interpreter/platform tag is hard-coded; it has to
    # match the wheel file name the build actually produces -- confirm
    # whenever the targeted Python version changes.
    wheel_abi=cp314-cp314-manylinux_2_27_x86_64
    wheel="jaxlib-${pkgver}.dev0+selfbuilt-${wheel_abi}.whl"

    # $pkgdir is quoted so a packaging directory containing whitespace is
    # passed to the installer as a single argument.
    python -m installer --compile-bytecode=1 --destdir="${pkgdir}" \
        "${bazel_bin}/jaxlib/tools/dist/${wheel}"
}