summarylogtreecommitdiffstats
path: root/cuda11.1.patch
diff options
context:
space:
mode:
Diffstat (limited to 'cuda11.1.patch')
-rw-r--r--cuda11.1.patch136
1 files changed, 136 insertions, 0 deletions
diff --git a/cuda11.1.patch b/cuda11.1.patch
new file mode 100644
index 000000000000..e4e9dd769a36
--- /dev/null
+++ b/cuda11.1.patch
@@ -0,0 +1,136 @@
+From 4a64bbe4ff9fb03a948ee76f7349cfdb9e9b7528 Mon Sep 17 00:00:00 2001
+From: Nathan Luehr <nluehr@nvidia.com>
+Date: Thu, 13 Aug 2020 09:46:43 -0700
+Subject: [PATCH 1/2] Fix cudart 11.1 soname
+
+---
+ third_party/gpus/cuda_configure.bzl | 12 ++++++++++--
+ 1 file changed, 10 insertions(+), 2 deletions(-)
+
+diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
+index ea33963fe19fb..3e6bdc9d8eb22 100644
+--- a/third_party/gpus/cuda_configure.bzl
++++ b/third_party/gpus/cuda_configure.bzl
+@@ -534,14 +534,14 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
+ "cudart",
+ cpu_value,
+ cuda_config.config["cuda_library_dir"],
+- cuda_config.cuda_version,
++ cuda_config.cudart_version,
+ static = False,
+ ),
+ "cudart_static": _check_cuda_lib_params(
+ "cudart_static",
+ cpu_value,
+ cuda_config.config["cuda_library_dir"],
+- cuda_config.cuda_version,
++ cuda_config.cudart_version,
+ static = True,
+ ),
+ "cublas": _check_cuda_lib_params(
+@@ -651,6 +651,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
+ cuda_toolkit_path: The CUDA toolkit installation directory.
+ cudnn_install_basedir: The cuDNN installation directory.
+ cuda_version: The version of CUDA on the system.
++ cudart_version: The CUDA runtime version on the system.
+ cudnn_version: The version of cuDNN on the system.
+ compute_capabilities: A list of the system's CUDA compute capabilities.
+ cpu_value: The name of the host operating system.
+@@ -668,6 +669,10 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
+ cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]
+
+ if int(cuda_major) >= 11:
++ if int(cuda_major) == 11:
++ cudart_version = "64_110" if is_windows else "11.0"
++ else:
++ cudart_version = ("64_%s" if is_windows else "%s") % cuda_major
+ cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0]
+ cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0]
+ curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0]
+@@ -677,12 +682,14 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
+ # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
+ # It changed from 'x.y' to just 'x' in CUDA 10.1.
+ cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
++ cudart_version = cuda_version
+ cublas_version = cuda_lib_version
+ cusolver_version = cuda_lib_version
+ curand_version = cuda_lib_version
+ cufft_version = cuda_lib_version
+ cusparse_version = cuda_lib_version
+ else:
++ cudart_version = cuda_version
+ cublas_version = cuda_version
+ cusolver_version = cuda_version
+ curand_version = cuda_version
+@@ -693,6 +700,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
+ cuda_toolkit_path = toolkit_path,
+ cuda_version = cuda_version,
++ cudart_version = cudart_version,
+ cublas_version = cublas_version,
+ cusolver_version = cusolver_version,
+ curand_version = curand_version,
+
+From 2642e93e6cbb7a3a1e916abf1ab8e18fa2735237 Mon Sep 17 00:00:00 2001
+From: Nathan Luehr <nluehr@nvidia.com>
+Date: Fri, 14 Aug 2020 13:21:58 -0700
+Subject: [PATCH 2/2] Use correct cudart soname in GetDsoHandle
+
+---
+ tensorflow/stream_executor/platform/default/dso_loader.cc | 3 ++-
+ third_party/gpus/cuda/cuda_config.h.tpl | 1 +
+ third_party/gpus/cuda_configure.bzl | 2 ++
+ 3 files changed, 5 insertions(+), 1 deletion(-)
+
+diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc
+index 84293b7767a20..a78c738f32c2a 100644
+--- a/tensorflow/stream_executor/platform/default/dso_loader.cc
++++ b/tensorflow/stream_executor/platform/default/dso_loader.cc
+@@ -31,6 +31,7 @@ namespace internal {
+
+ namespace {
+ string GetCudaVersion() { return TF_CUDA_VERSION; }
++string GetCudaRtVersion() { return TF_CUDART_VERSION; }
+ string GetCudnnVersion() { return TF_CUDNN_VERSION; }
+ string GetCublasVersion() { return TF_CUBLAS_VERSION; }
+ string GetCusolverVersion() { return TF_CUSOLVER_VERSION; }
+@@ -77,7 +78,7 @@ port::StatusOr<void*> GetCudaDriverDsoHandle() {
+ }
+
+ port::StatusOr<void*> GetCudaRuntimeDsoHandle() {
+- return GetDsoHandle("cudart", GetCudaVersion());
++ return GetDsoHandle("cudart", GetCudaRtVersion());
+ }
+
+ port::StatusOr<void*> GetCublasDsoHandle() {
+diff --git a/third_party/gpus/cuda/cuda_config.h.tpl b/third_party/gpus/cuda/cuda_config.h.tpl
+index b59889938b1a9..ab26686ccb8b2 100644
+--- a/third_party/gpus/cuda/cuda_config.h.tpl
++++ b/third_party/gpus/cuda/cuda_config.h.tpl
+@@ -17,6 +17,7 @@ limitations under the License.
+ #define CUDA_CUDA_CONFIG_H_
+
+ #define TF_CUDA_VERSION "%{cuda_version}"
++#define TF_CUDART_VERSION "%{cudart_version}"
+ #define TF_CUBLAS_VERSION "%{cublas_version}"
+ #define TF_CUSOLVER_VERSION "%{cusolver_version}"
+ #define TF_CURAND_VERSION "%{curand_version}"
+diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
+index 3e6bdc9d8eb22..f85a53b1593b4 100644
+--- a/third_party/gpus/cuda_configure.bzl
++++ b/third_party/gpus/cuda_configure.bzl
+@@ -824,6 +824,7 @@ filegroup(name="cudnn-include")
+ "cuda:cuda_config.h",
+ {
+ "%{cuda_version}": "",
++ "%{cudart_version}": "",
+ "%{cublas_version}": "",
+ "%{cusolver_version}": "",
+ "%{curand_version}": "",
+@@ -1289,6 +1290,7 @@ def _create_local_cuda_repository(repository_ctx):
+ tpl_paths["cuda:cuda_config.h"],
+ {
+ "%{cuda_version}": cuda_config.cuda_version,
++ "%{cudart_version}": cuda_config.cudart_version,
+ "%{cublas_version}": cuda_config.cublas_version,
+ "%{cusolver_version}": cuda_config.cusolver_version,
+ "%{curand_version}": cuda_config.curand_version,