summarylogtreecommitdiffstats
path: root/cuda_call.patch
diff options
context:
space:
mode:
Diffstat (limited to 'cuda_call.patch')
-rw-r--r--cuda_call.patch43
1 files changed, 43 insertions, 0 deletions
diff --git a/cuda_call.patch b/cuda_call.patch
new file mode 100644
index 000000000000..7f1dae135808
--- /dev/null
+++ b/cuda_call.patch
@@ -0,0 +1,43 @@
+diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu
+index 930e5e07b..e116b8c5f 100644
+--- a/src/common/random_generator.cu
++++ b/src/common/random_generator.cu
+@@ -59,6 +59,17 @@ void RandGenerator<gpu, float>::Seed(mshadow::Stream<gpu> *s, uint32_t seed) {
+ s->Wait();
+ }
+
++template<>
++void RandGenerator<gpu, float>::AllocState(RandGenerator<gpu> *inst) {
++ CUDA_CALL(cudaMalloc(&inst->states_,
++ kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
++}
++
++template<>
++void RandGenerator<gpu, float>::FreeState(RandGenerator<gpu> *inst) {
++ CUDA_CALL(cudaFree(inst->states_));
++}
++
+ } // namespace random
+ } // namespace common
+ } // namespace mxnet
+diff --git a/src/common/random_generator.h b/src/common/random_generator.h
+index 5d78b616e..1c8ae01de 100644
+--- a/src/common/random_generator.h
++++ b/src/common/random_generator.h
+@@ -150,14 +150,9 @@ class RandGenerator<gpu, DType> {
+ curandStatePhilox4_32_10_t state_;
+ }; // class RandGenerator<gpu, DType>::Impl
+
+- static void AllocState(RandGenerator<gpu, DType> *inst) {
+- CUDA_CALL(cudaMalloc(&inst->states_,
+- kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
+- }
++ static void AllocState(RandGenerator<gpu, DType> *inst);
+
+- static void FreeState(RandGenerator<gpu, DType> *inst) {
+- CUDA_CALL(cudaFree(inst->states_));
+- }
++ static void FreeState(RandGenerator<gpu, DType> *inst);
+
+ void Seed(mshadow::Stream<gpu> *s, uint32_t seed);
+