diff options
Diffstat (limited to 'cuda_call.patch')
-rw-r--r-- | cuda_call.patch | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/cuda_call.patch b/cuda_call.patch new file mode 100644 index 000000000000..7f1dae135808 --- /dev/null +++ b/cuda_call.patch @@ -0,0 +1,43 @@ +diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu +index 930e5e07b..e116b8c5f 100644 +--- a/src/common/random_generator.cu ++++ b/src/common/random_generator.cu +@@ -59,6 +59,17 @@ void RandGenerator<gpu, float>::Seed(mshadow::Stream<gpu> *s, uint32_t seed) { + s->Wait(); + } + ++template<> ++void RandGenerator<gpu, float>::AllocState(RandGenerator<gpu> *inst) { ++ CUDA_CALL(cudaMalloc(&inst->states_, ++ kNumRandomStates * sizeof(curandStatePhilox4_32_10_t))); ++} ++ ++template<> ++void RandGenerator<gpu, float>::FreeState(RandGenerator<gpu> *inst) { ++ CUDA_CALL(cudaFree(inst->states_)); ++} ++ + } // namespace random + } // namespace common + } // namespace mxnet +diff --git a/src/common/random_generator.h b/src/common/random_generator.h +index 5d78b616e..1c8ae01de 100644 +--- a/src/common/random_generator.h ++++ b/src/common/random_generator.h +@@ -150,14 +150,9 @@ class RandGenerator<gpu, DType> { + curandStatePhilox4_32_10_t state_; + }; // class RandGenerator<gpu, DType>::Impl + +- static void AllocState(RandGenerator<gpu, DType> *inst) { +- CUDA_CALL(cudaMalloc(&inst->states_, +- kNumRandomStates * sizeof(curandStatePhilox4_32_10_t))); +- } ++ static void AllocState(RandGenerator<gpu, DType> *inst); + +- static void FreeState(RandGenerator<gpu, DType> *inst) { +- CUDA_CALL(cudaFree(inst->states_)); +- } ++ static void FreeState(RandGenerator<gpu, DType> *inst); + + void Seed(mshadow::Stream<gpu> *s, uint32_t seed); + |