summarylogtreecommitdiffstats
path: root/cuda_call.patch
blob: 7f1dae135808f8e5ae3b4ec606d90aea77ee4d2c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu
index 930e5e07b..e116b8c5f 100644
--- a/src/common/random_generator.cu
+++ b/src/common/random_generator.cu
@@ -59,6 +59,17 @@ void RandGenerator<gpu, float>::Seed(mshadow::Stream<gpu> *s, uint32_t seed) {
   s->Wait();
 }
 
+template<>
+void RandGenerator<gpu, float>::AllocState(RandGenerator<gpu> *inst) {
+  CUDA_CALL(cudaMalloc(&inst->states_,
+                       kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
+}
+
+template<>
+void RandGenerator<gpu, float>::FreeState(RandGenerator<gpu> *inst) {
+  CUDA_CALL(cudaFree(inst->states_));
+}
+
 }  // namespace random
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/common/random_generator.h b/src/common/random_generator.h
index 5d78b616e..1c8ae01de 100644
--- a/src/common/random_generator.h
+++ b/src/common/random_generator.h
@@ -150,14 +150,9 @@ class RandGenerator<gpu, DType> {
     curandStatePhilox4_32_10_t state_;
   };  // class RandGenerator<gpu, DType>::Impl
 
-  static void AllocState(RandGenerator<gpu, DType> *inst) {
-    CUDA_CALL(cudaMalloc(&inst->states_,
-                         kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
-  }
+  static void AllocState(RandGenerator<gpu, DType> *inst);
 
-  static void FreeState(RandGenerator<gpu, DType> *inst) {
-    CUDA_CALL(cudaFree(inst->states_));
-  }
+  static void FreeState(RandGenerator<gpu, DType> *inst);
 
   void Seed(mshadow::Stream<gpu> *s, uint32_t seed);