1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
diff --git a/build_deps/toolchains/gpu/cub.BUILD b/build_deps/toolchains/gpu/cub.BUILD
index cdc9e4f3..159e5f33 100644
--- a/build_deps/toolchains/gpu/cub.BUILD
+++ b/build_deps/toolchains/gpu/cub.BUILD
@@ -18,7 +18,6 @@ filegroup(
cc_library(
name = "cub",
hdrs = if_cuda([":cub_header_files"]),
- include_prefix = "gpu",
deps = [
"@local_config_cuda//cuda:cuda_headers",
],
diff --git a/build_deps/toolchains/gpu/cuda/BUILD.tpl b/build_deps/toolchains/gpu/cuda/BUILD.tpl
index 1ac5643f..2761abac 100644
--- a/build_deps/toolchains/gpu/cuda/BUILD.tpl
+++ b/build_deps/toolchains/gpu/cuda/BUILD.tpl
@@ -175,6 +175,11 @@ cc_library(
],
)
+alias(
+ name = "cub_headers",
+ actual = "%{cub_actual}",
+)
+
cc_library(
name = "cupti_headers",
hdrs = [
diff --git a/build_deps/toolchains/gpu/cuda/BUILD.windows.tpl b/build_deps/toolchains/gpu/cuda/BUILD.windows.tpl
index 3ed4fd41..14554a38 100644
--- a/build_deps/toolchains/gpu/cuda/BUILD.windows.tpl
+++ b/build_deps/toolchains/gpu/cuda/BUILD.windows.tpl
@@ -134,6 +134,11 @@ cc_library(
],
)
+alias(
+ name = "cub_headers",
+ actual = "%{cub_actual}",
+)
+
cc_library(
name = "cupti_headers",
hdrs = [
diff --git a/build_deps/toolchains/gpu/cuda_configure.bzl b/build_deps/toolchains/gpu/cuda_configure.bzl
index ba38c6b5..b30c67ac 100644
--- a/build_deps/toolchains/gpu/cuda_configure.bzl
+++ b/build_deps/toolchains/gpu/cuda_configure.bzl
@@ -668,6 +668,7 @@ def _get_cuda_config(repository_ctx):
return struct(
cuda_toolkit_path = toolkit_path,
cuda_version = cuda_version,
+ cuda_version_major = cuda_major,
cudart_version = cudart_version,
cublas_version = cublas_version,
cusolver_version = cusolver_version,
@@ -725,6 +726,7 @@ def _create_dummy_repository(repository_ctx):
"%{cufft_lib}": lib_name("cufft", cpu_value),
"%{curand_lib}": lib_name("curand", cpu_value),
"%{cupti_lib}": lib_name("cupti", cpu_value),
+ "%{cub_actual}": ":cuda_headers",
"%{copy_rules}": "",
"%{cuda_headers}": "",
},
@@ -950,6 +952,10 @@ def _create_local_cuda_repository(repository_ctx):
},
)
+ cub_actual = "@cub_archive//:cub"
+ if int(cuda_config.cuda_version_major) >= 11:
+ cub_actual = ":cuda_headers"
+
_tpl(
repository_ctx,
"cuda:BUILD",
@@ -964,6 +970,7 @@ def _create_local_cuda_repository(repository_ctx):
"%{cufft_lib}": cuda_libs["cufft"].basename,
"%{curand_lib}": cuda_libs["curand"].basename,
"%{cupti_lib}": cuda_libs["cupti"].basename,
+ "%{cub_actual}": cub_actual,
"%{copy_rules}": "\n".join(copy_rules),
"%{cuda_headers}": (
'":cuda-include",\n' + ' ":cudnn-include",'
diff --git a/tensorflow_addons/custom_ops/layers/BUILD b/tensorflow_addons/custom_ops/layers/BUILD
index 44d87710..a039f25d 100644
--- a/tensorflow_addons/custom_ops/layers/BUILD
+++ b/tensorflow_addons/custom_ops/layers/BUILD
@@ -12,7 +12,7 @@ custom_op_library(
"cc/ops/correlation_cost_op.cc",
],
cuda_deps = [
- "@cub_archive//:cub",
+ "@local_config_cuda//cuda:cub_headers",
],
cuda_srcs = [
"cc/kernels/correlation_cost_op.h",
diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc
index 16b012a7..722f317d 100644
--- a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "gpu/cub/device/device_reduce.cuh"
+#include "cub/device/device_reduce.cuh"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
|