summarylogtreecommitdiffstats
path: root/cuda10deprecation.patch
diff options
context:
space:
mode:
Diffstat (limited to 'cuda10deprecation.patch')
-rw-r--r--cuda10deprecation.patch32
1 files changed, 32 insertions, 0 deletions
diff --git a/cuda10deprecation.patch b/cuda10deprecation.patch
new file mode 100644
index 000000000000..f09c13659b42
--- /dev/null
+++ b/cuda10deprecation.patch
@@ -0,0 +1,32 @@
+diff --git a/catboost/cuda/cuda_lib/cuda_base.h b/catboost/cuda/cuda_lib/cuda_base.h
+index 9c72b8ef9..b20f76137 100644
+--- a/catboost/cuda/cuda_lib/cuda_base.h
++++ b/catboost/cuda/cuda_lib/cuda_base.h
+@@ -126,7 +126,13 @@ namespace NCudaLib {
+ cudaPointerAttributes attributes;
+ CUDA_SAFE_CALL(cudaPointerGetAttributes(&attributes, (void*)(ptr)));
+ //TODO(noxoomo): currently don't distinguish pinned/non-pinned memory
++#ifndef CUDART_VERSION
++#error "CUDART_VERSION is not defined: include cuda_runtime_api.h"
++#elif (CUDART_VERSION >= 10000)
++ return attributes.type == cudaMemoryTypeHost ? EPtrType::CudaHost : EPtrType::CudaDevice;
++#else
+ return attributes.memoryType == cudaMemoryTypeHost ? EPtrType::CudaHost : EPtrType::CudaDevice;
++#endif
+ }
+
+ template <EPtrType From, EPtrType To>
+@@ -258,7 +264,13 @@ namespace NCudaLib {
+ inline int GetDeviceForPointer(const T* ptr) {
+ cudaPointerAttributes result;
+ CUDA_SAFE_CALL(cudaPointerGetAttributes(&result, (const void*)ptr));
++#ifndef CUDART_VERSION
++#error "CUDART_VERSION is not defined: include cuda_runtime_api.h"
++#elif (CUDART_VERSION >= 10000)
++ CB_ENSURE(result.type == cudaMemoryTypeDevice, "Error: this pointer is not GPU pointer");
++#else
+ CB_ENSURE(result.memoryType == cudaMemoryTypeDevice, "Error: this pointer is not GPU pointer");
++#endif
+ return result.device;
+ }
+