summarylogtreecommitdiffstats
path: root/use_cub_from_cuda.patch
blob: f2ac948c7183a59e10de5602d4fc3d1b7ac914f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
diff --git a/build_deps/toolchains/gpu/cub.BUILD b/build_deps/toolchains/gpu/cub.BUILD
index cdc9e4f3..159e5f33 100644
--- a/build_deps/toolchains/gpu/cub.BUILD
+++ b/build_deps/toolchains/gpu/cub.BUILD
@@ -18,7 +18,6 @@ filegroup(
 cc_library(
     name = "cub",
     hdrs = if_cuda([":cub_header_files"]),
-    include_prefix = "gpu",
     deps = [
         "@local_config_cuda//cuda:cuda_headers",
     ],
diff --git a/build_deps/toolchains/gpu/cuda/BUILD.tpl b/build_deps/toolchains/gpu/cuda/BUILD.tpl
index 1ac5643f..2761abac 100644
--- a/build_deps/toolchains/gpu/cuda/BUILD.tpl
+++ b/build_deps/toolchains/gpu/cuda/BUILD.tpl
@@ -175,6 +175,11 @@ cc_library(
     ],
 )
 
+alias(
+    name = "cub_headers",
+    actual = "%{cub_actual}",
+)
+
 cc_library(
     name = "cupti_headers",
     hdrs = [
diff --git a/build_deps/toolchains/gpu/cuda/BUILD.windows.tpl b/build_deps/toolchains/gpu/cuda/BUILD.windows.tpl
index 3ed4fd41..14554a38 100644
--- a/build_deps/toolchains/gpu/cuda/BUILD.windows.tpl
+++ b/build_deps/toolchains/gpu/cuda/BUILD.windows.tpl
@@ -134,6 +134,11 @@ cc_library(
     ],
 )
 
+alias(
+    name = "cub_headers",
+    actual = "%{cub_actual}",
+)
+
 cc_library(
     name = "cupti_headers",
     hdrs = [
diff --git a/build_deps/toolchains/gpu/cuda_configure.bzl b/build_deps/toolchains/gpu/cuda_configure.bzl
index ba38c6b5..b30c67ac 100644
--- a/build_deps/toolchains/gpu/cuda_configure.bzl
+++ b/build_deps/toolchains/gpu/cuda_configure.bzl
@@ -668,6 +668,7 @@ def _get_cuda_config(repository_ctx):
     return struct(
         cuda_toolkit_path = toolkit_path,
         cuda_version = cuda_version,
+        cuda_version_major = cuda_major,
         cudart_version = cudart_version,
         cublas_version = cublas_version,
         cusolver_version = cusolver_version,
@@ -725,6 +726,7 @@ def _create_dummy_repository(repository_ctx):
             "%{cufft_lib}": lib_name("cufft", cpu_value),
             "%{curand_lib}": lib_name("curand", cpu_value),
             "%{cupti_lib}": lib_name("cupti", cpu_value),
+            "%{cub_actual}": ":cuda_headers",
             "%{copy_rules}": "",
             "%{cuda_headers}": "",
         },
@@ -950,6 +952,10 @@ def _create_local_cuda_repository(repository_ctx):
         },
     )
 
+    cub_actual = "@cub_archive//:cub"
+    if int(cuda_config.cuda_version_major) >= 11:
+        cub_actual = ":cuda_headers"
+
     _tpl(
         repository_ctx,
         "cuda:BUILD",
@@ -964,6 +970,7 @@ def _create_local_cuda_repository(repository_ctx):
             "%{cufft_lib}": cuda_libs["cufft"].basename,
             "%{curand_lib}": cuda_libs["curand"].basename,
             "%{cupti_lib}": cuda_libs["cupti"].basename,
+            "%{cub_actual}": cub_actual,
             "%{copy_rules}": "\n".join(copy_rules),
             "%{cuda_headers}": (
                 '":cuda-include",\n' + '        ":cudnn-include",'
diff --git a/tensorflow_addons/custom_ops/layers/BUILD b/tensorflow_addons/custom_ops/layers/BUILD
index 44d87710..a039f25d 100644
--- a/tensorflow_addons/custom_ops/layers/BUILD
+++ b/tensorflow_addons/custom_ops/layers/BUILD
@@ -12,7 +12,7 @@ custom_op_library(
         "cc/ops/correlation_cost_op.cc",
     ],
     cuda_deps = [
-        "@cub_archive//:cub",
+        "@local_config_cuda//cuda:cub_headers",
     ],
     cuda_srcs = [
         "cc/kernels/correlation_cost_op.h",
diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc
index 16b012a7..722f317d 100644
--- a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 #define EIGEN_USE_GPU
 
-#include "gpu/cub/device/device_reduce.cuh"
+#include "cub/device/device_reduce.cuh"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"