From c6769e20bf6096d5828e2590def2b25edb3189d6 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 17 Aug 2020 14:12:02 -0700 Subject: [PATCH] Use CUB from the CUDA Toolkit starting with version 11.0. PiperOrigin-RevId: 327096097 Change-Id: I444ec3ac3348f76728c931a4bb4aa1b7cbe1b673 --- tensorflow/core/kernels/BUILD | 8 ++--- tensorflow/core/kernels/gpu_prim.h | 26 +++++++------- tensorflow/core/util/BUILD | 2 +- third_party/cub.BUILD | 1 - third_party/cub.pr170.patch | 48 ------------------------- third_party/gpus/cuda/BUILD.tpl | 6 ++++ third_party/gpus/cuda/BUILD.windows.tpl | 5 +++ third_party/gpus/cuda_configure.bzl | 7 ++++ 8 files changed, 36 insertions(+), 67 deletions(-) delete mode 100644 third_party/cub.pr170.patch diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 88958cdaa9878..19dc5c73252a8 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -490,7 +490,7 @@ cc_library( name = "gpu_prim_hdrs", hdrs = ["gpu_prim.h"], deps = if_cuda([ - "@cub_archive//:cub", + "@local_config_cuda//cuda:cub_headers", ]) + if_rocm([ "@local_config_rocm//rocm:rocprim", ]), @@ -3896,7 +3896,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", ] + if_cuda([ - "@cub_archive//:cub", + "@local_config_cuda//cuda:cub_headers", "@local_config_cuda//cuda:cudnn_header", ]) + if_rocm([ "@local_config_rocm//rocm:rocprim", @@ -3986,7 +3986,7 @@ tf_kernel_library( ] + if_cuda_or_rocm([ ":reduction_ops", ]) + if_cuda([ - "@cub_archive//:cub", + "@local_config_cuda//cuda:cub_headers", "//tensorflow/core:stream_executor", "//tensorflow/stream_executor/cuda:cuda_stream", ]) + if_rocm([ @@ -4708,7 +4708,7 @@ tf_kernel_library( ] + if_cuda_or_rocm([ ":reduction_ops", ]) + if_cuda([ - "@cub_archive//:cub", + "@local_config_cuda//cuda:cub_headers", ]) + if_rocm([ "@local_config_rocm//rocm:rocprim", ]), diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h index 
82fcb21e0ac04..33c5df1ae2371 100644 --- a/tensorflow/core/kernels/gpu_prim.h +++ b/tensorflow/core/kernels/gpu_prim.h @@ -15,19 +15,19 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ #if GOOGLE_CUDA -#include "third_party/cub/block/block_load.cuh" -#include "third_party/cub/block/block_scan.cuh" -#include "third_party/cub/block/block_store.cuh" -#include "third_party/cub/device/device_histogram.cuh" -#include "third_party/cub/device/device_radix_sort.cuh" -#include "third_party/cub/device/device_reduce.cuh" -#include "third_party/cub/device/device_segmented_radix_sort.cuh" -#include "third_party/cub/device/device_segmented_reduce.cuh" -#include "third_party/cub/device/device_select.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" -#include "third_party/cub/thread/thread_operators.cuh" -#include "third_party/cub/warp/warp_reduce.cuh" +#include "cub/block/block_load.cuh" +#include "cub/block/block_scan.cuh" +#include "cub/block/block_store.cuh" +#include "cub/device/device_histogram.cuh" +#include "cub/device/device_radix_sort.cuh" +#include "cub/device/device_reduce.cuh" +#include "cub/device/device_segmented_radix_sort.cuh" +#include "cub/device/device_segmented_reduce.cuh" +#include "cub/device/device_select.cuh" +#include "cub/iterator/counting_input_iterator.cuh" +#include "cub/iterator/transform_input_iterator.cuh" +#include "cub/thread/thread_operators.cuh" +#include "cub/warp/warp_reduce.cuh" #include "third_party/gpus/cuda/include/cusparse.h" namespace gpuprim = ::cub; diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 4d2ff9a805811..241e382a650ba 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -626,7 +626,7 @@ tf_kernel_library( "//tensorflow/core:lib", ] + if_cuda([ "//tensorflow/stream_executor/cuda:cusparse_lib", - "@cub_archive//:cub", + 
"@local_config_cuda//cuda:cub_headers", ]) + if_rocm([ "@local_config_rocm//rocm:hipsparse", ]), diff --git a/third_party/cub.BUILD b/third_party/cub.BUILD index a04347b21eefb..29159c9dad3d3 100644 --- a/third_party/cub.BUILD +++ b/third_party/cub.BUILD @@ -20,7 +20,6 @@ filegroup( cc_library( name = "cub", hdrs = if_cuda([":cub_header_files"]), - include_prefix = "third_party", deps = [ "@local_config_cuda//cuda:cuda_headers", ], diff --git a/third_party/cub.pr170.patch b/third_party/cub.pr170.patch deleted file mode 100644 index 5b7432e885867..0000000000000 --- a/third_party/cub.pr170.patch +++ /dev/null @@ -1,48 +0,0 @@ -From fd6e7a61a16a17fa155cbd717de0c79001af71e6 Mon Sep 17 00:00:00 2001 -From: Artem Belevich -Date: Mon, 23 Sep 2019 11:18:56 -0700 -Subject: [PATCH] Fix CUDA version detection in CUB - -This fixes the problem with CUB using deprecated shfl/vote instructions when CUB -is compiled with clang (e.g. some TensorFlow builds). ---- - cub/util_arch.cuh | 3 ++- - cub/util_type.cuh | 4 ++-- - 2 files changed, 4 insertions(+), 3 deletions(-) - -diff --git a/cub/util_arch.cuh b/cub/util_arch.cuh -index 87c5ea2fb..9ad9d1cbb 100644 ---- a/cub/util_arch.cuh -+++ b/cub/util_arch.cuh -@@ -44,7 +44,8 @@ namespace cub { - - #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document - --#if (__CUDACC_VER_MAJOR__ >= 9) && !defined(CUB_USE_COOPERATIVE_GROUPS) -+#if !defined(CUB_USE_COOPERATIVE_GROUPS) && \ -+ (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) - #define CUB_USE_COOPERATIVE_GROUPS - #endif - -diff --git a/cub/util_type.cuh b/cub/util_type.cuh -index 0ba41e1ed..b2433d735 100644 ---- a/cub/util_type.cuh -+++ b/cub/util_type.cuh -@@ -37,7 +37,7 @@ - #include - #include - --#if (__CUDACC_VER_MAJOR__ >= 9) -+#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) - #include - #endif - -@@ -1063,7 +1063,7 @@ struct FpLimits - }; - - --#if (__CUDACC_VER_MAJOR__ >= 9) -+#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) - template <> - struct FpLimits<__half> 
- { diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index e5833e7cdbbc2..a4a21abc36769 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -176,6 +176,11 @@ cc_library( ], ) +alias( + name = "cub_headers", + actual = "%{cub_actual}" +) + cuda_header_library( name = "cupti_headers", hdrs = [":cuda-extras"], @@ -224,3 +229,4 @@ py_library( ) %{copy_rules} + diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl index 55a9ec3d1ab10..cabfac28fc357 100644 --- a/third_party/gpus/cuda/BUILD.windows.tpl +++ b/third_party/gpus/cuda/BUILD.windows.tpl @@ -171,6 +171,11 @@ cc_library( ], ) +alias( + name = "cub_headers", + actual = "%{cub_actual}" +) + cuda_header_library( name = "cupti_headers", hdrs = [":cuda-extras"], diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 70bb91159de1a..ea33963fe19fb 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -692,6 +692,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script): return struct( cuda_toolkit_path = toolkit_path, cuda_version = cuda_version, + cuda_version_major = cuda_major, cublas_version = cublas_version, cusolver_version = cusolver_version, curand_version = curand_version, @@ -776,6 +777,7 @@ def _create_dummy_repository(repository_ctx): "%{curand_lib}": lib_name("curand", cpu_value), "%{cupti_lib}": lib_name("cupti", cpu_value), "%{cusparse_lib}": lib_name("cusparse", cpu_value), + "%{cub_actual}": ":cuda_headers", "%{copy_rules}": """ filegroup(name="cuda-include") filegroup(name="cublas-include") @@ -1122,6 +1124,10 @@ def _create_local_cuda_repository(repository_ctx): }, ) + cub_actual = "@cub_archive//:cub" + if int(cuda_config.cuda_version_major) >= 11: + cub_actual = ":cuda_headers" + repository_ctx.template( "cuda/BUILD", tpl_paths["cuda:BUILD"], @@ -1137,6 +1143,7 @@ def 
_create_local_cuda_repository(repository_ctx): "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]), "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]), "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]), + "%{cub_actual}": cub_actual, "%{copy_rules}": "\n".join(copy_rules), }, )