summarylogtreecommitdiffstats
path: root/PKGBUILD
blob: 28eb77d5715ac46e35fb886001acb14210d2c2f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Maintainer: wuxxin <wuxxin@gmail.com>
# Based on https://github.com/google/jax/blob/main/build/rocm/README.md
# Based on python-jax, python-jax-opt-cuda-git; original contributors:
# Contributor: Daniel Bershatsky <bepshatsky@yandex.ru>
# --- Package identity ---------------------------------------------------------
pkgname='python-jax-rocm'
pkgver=0.3.16
pkgrel=2
pkgdesc='Differentiate, compile, and transform Numpy code (with ROCM)'
arch=('x86_64')
url='https://github.com/google/jax'
license=('Apache')

# Top-level directory of the unpacked JAX tarball (tag "jax-v${pkgver}").
_srcname="jax-jax-v${pkgver}"
# Pinned ROCm tensorflow-upstream commit
# (develop-upstream as of 2022-08-10 17:10 CEST).
_tfid="343a9e91963de6dd83e0f7470a641dca365d821f"
# Top-level directory of the unpacked tensorflow-upstream tarball.
_tfname="tensorflow-upstream-${_tfid}"

# --- Dependencies -------------------------------------------------------------
# Runtime dependencies; these are also available while building.
depends=(
    'absl-py'
    'miopen'
    'python-etils'
    'python-numpy'
    'python-opt_einsum'
    'python-scipy'
    'python-six'
    'python'
    'rccl'
    'rocm-hip-runtime'
)
# Build-only dependencies. 'miopen' and 'rccl' are intentionally not repeated
# here: everything in depends=() is already installed for the build.
makedepends=(
    'bazel'
    'git'
    'python-pip'
    'python-wheel'
    'rocm-hip-sdk'
)

# --- Sources ------------------------------------------------------------------
# Entries in sha512sums=() correspond one-to-one, in order, to source=().
source=(
    "${pkgname}-${pkgver}.tar.gz::${url}/archive/refs/tags/jax-v${pkgver}.tar.gz"
    "${_tfname}.tar.gz::https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/archive/${_tfid}.tar.gz"
    "fix-rocblas-include.patch"
)
sha512sums=(
    'de2b16466009cfa56c46d44bbf58d2f8f293d6dbb6a0fdbc1591cd92fcf0eacbd119bd4d9c6ea4099a0443514a58e11e7c8fd8e94b73fa121d0f0a45a6849a53'
    '04c9ece4cb782f52925b1e7ee18ccc916a65dde051554b106164d371b3b7b96037218e635c235c8f38d088066421c0c5a4a7e201ed435b3e69c128f5ac20f0ac'
    '36596fd586cbdac990466a53cc0683de759b2f0646ed08edf04e88e3ee4de1a7381cf1b3a4784aa9a240e2ad894d55e8728226b95071539f774bc5d9b790b5fc'
)

# --- Package relations --------------------------------------------------------
# This package ships both jax and jaxlib, so it replaces the stock packages.
conflicts=('python-jax' 'python-jaxlib')
provides=('python-jax' 'python-jaxlib')

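# Smoke test: run manually after installing the package to check that JAX
# imports, lists its devices, and can evaluate a simple expression:
# python -c "import jax; print(jax.devices(),jax.devices()[0].device_kind); x=jax.numpy.array([1.2,3.4,5.6]); y=jax.numpy.exp(x); print(y)"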

# Prepare the unpacked sources before building.
# Globals (read): srcdir, _srcname, _tfname
# Side effects:   overwrites ${_srcname}/.bazelversion, patches the
#                 ${_tfname} tree in place, leaves the cwd at ${srcdir}.
# Returns:        non-zero if a directory is missing or the patch fails.
prepare() {
  # loosen acceptable bazel version: replace the JAX tree's pin with a 5.x
  # wildcard. The path is quoted and anchored at ${srcdir} so it neither
  # depends on the current directory nor breaks on whitespace.
  printf '%s\n' '5.*.*' > "${srcdir}/${_srcname}/.bazelversion" || return

  # -p1 paths in the patch are relative to the TensorFlow tree's root.
  cd "${srcdir}/${_tfname}" || return
  patch -Np1 -i "${srcdir}/fix-rocblas-include.patch" || return
  # Presumably the sed equivalent of the patch above; kept for reference:
  # for a in tensorflow/core/util/gpu_solvers.h tensorflow/stream_executor/rocm/rocblas_wrapper.h tensorflow/stream_executor/rocm/rocm_blas.h ; do sed -i -E 's#^(.include "rocm/include/rocblas)(.h")#\1/rocblas\2#g' $a; done
  cd "${srcdir}" || return
}

# Build jaxlib with ROCm support against the patched TensorFlow tree.
# Globals (read):    srcdir, _srcname, _tfname
# Globals (written): TF_ROCM_AMDGPU_TARGETS (defaulted if unset/empty, exported)
# Returns:           non-zero if the source dir is missing or the build fails.
build() {
  cd "${srcdir}/${_srcname}" || return

  # Honour a caller-supplied GPU target list; otherwise use the default set.
  : "${TF_ROCM_AMDGPU_TARGETS:=gfx900,gfx906,gfx908,gfx90a,gfx1030}"
  export TF_ROCM_AMDGPU_TARGETS

  # Both options are quoted so each stays a single argument even when
  # ${srcdir} contains whitespace. The bazel override points org_tensorflow
  # at the local tensorflow-upstream checkout.
  python build/build.py --enable_rocm \
    "--bazel_options=--override_repository=org_tensorflow=${srcdir}/${_tfname}" \
    "--rocm_amdgpu_targets=${TF_ROCM_AMDGPU_TARGETS}"
}

# Install the built jaxlib wheel and the jax sources into the package root.
# Globals (read): srcdir, _srcname, pkgdir
# Returns:        non-zero if the source dir or the wheel is missing, or if
#                 either pip invocation fails.
package() {
  cd "${srcdir}/${_srcname}" || return

  # Fail with a clear message instead of handing pip a literal 'dist/*.whl'.
  local wheels=(dist/*.whl)
  if [[ ! -e "${wheels[0]}" ]]; then
    printf 'package: no wheel found under %s/dist\n' "${PWD}" >&2
    return 1
  fi

  # --ignore-installed: always write files into ${pkgdir}, even when the same
  #                     distribution is already installed on the build host.
  # --isolated:         ignore user configuration and PIP_* environment vars.
  local pip_args=(
    install --isolated --ignore-installed --root="${pkgdir}" --no-deps
  )

  # installs jaxlib (includes XLA)
  pip "${pip_args[@]}" "${wheels[@]}" || return
  # installs jax
  pip "${pip_args[@]}" . || return
}