1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu
index 930e5e07b..e116b8c5f 100644
--- a/src/common/random_generator.cu
+++ b/src/common/random_generator.cu
@@ -59,6 +59,17 @@ void RandGenerator<gpu, float>::Seed(mshadow::Stream<gpu> *s, uint32_t seed) {
s->Wait();
}
+template<>
+void RandGenerator<gpu, float>::AllocState(RandGenerator<gpu> *inst) {
+ CUDA_CALL(cudaMalloc(&inst->states_,
+ kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
+}
+
+template<>
+void RandGenerator<gpu, float>::FreeState(RandGenerator<gpu> *inst) {
+ CUDA_CALL(cudaFree(inst->states_));
+}
+
} // namespace random
} // namespace common
} // namespace mxnet
diff --git a/src/common/random_generator.h b/src/common/random_generator.h
index 5d78b616e..1c8ae01de 100644
--- a/src/common/random_generator.h
+++ b/src/common/random_generator.h
@@ -150,14 +150,9 @@ class RandGenerator<gpu, DType> {
curandStatePhilox4_32_10_t state_;
}; // class RandGenerator<gpu, DType>::Impl
- static void AllocState(RandGenerator<gpu, DType> *inst) {
- CUDA_CALL(cudaMalloc(&inst->states_,
- kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
- }
+ static void AllocState(RandGenerator<gpu, DType> *inst);
- static void FreeState(RandGenerator<gpu, DType> *inst) {
- CUDA_CALL(cudaFree(inst->states_));
- }
+ static void FreeState(RandGenerator<gpu, DType> *inst);
void Seed(mshadow::Stream<gpu> *s, uint32_t seed);
|