summarylogtreecommitdiffstats
path: root/test.cpp
blob: c1b0149915c203f8c9974e721d45690b1830cb0d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <rocfft/rocfft.h>
#include <hip/hip_runtime.h>
#include <hip/hip_vector_types.h>
#include <vector>
#include <numeric>
#include <cmath>
#include <iostream>

int main()
{
    size_t size = 1024 * 1024;

    rocfft_setup();

    float2 *x;
    hipMalloc((void**)&x, sizeof *x * size);


    std::vector<float2> xin(size);
    for(auto &xx: xin){
        xx.x = 1.0f;
        xx.y = 0.0f;
    }
    hipMemcpy(x, xin.data(), sizeof *x * size, hipMemcpyHostToDevice);

    rocfft_plan plan = nullptr;
    size_t len = size;
    rocfft_plan_create(&plan, rocfft_placement_inplace,
        rocfft_transform_type_complex_forward, rocfft_precision_single,
        1, &len, 1, nullptr);
    size_t work_size = 0;
    rocfft_plan_get_work_buffer_size(plan, &work_size);
    void *work;
    rocfft_execution_info info = nullptr;
    if(work_size){
        rocfft_execution_info_create(&info);
        hipMalloc((void**)&work, work_size);
        rocfft_execution_info_set_work_buffer(info, work, work_size);
    }
    rocfft_execute(plan, (void**)&x, nullptr, info);

    std::vector<float2> xout(size);
    hipMemcpy(xout.data(), x, sizeof *x * size, hipMemcpyDeviceToHost);

    std::vector<float2> xref(size);
    for(auto &xx: xref){
        xx.x = 0.0f;
        xx.y = 0.0f;
    }
    xref[0].x = 1.0f * size;
    
    float tol = 0.001f;
    for(size_t i = 0; i < size; i++){
        if(std::abs(xref[i].x - xout[i].x) + std::abs(xref[i].y - xout[i].y) > tol){
            std::cout << "Element mismatch at index " << i << "\n";
            std::cout << "Expected: " << xref[i].x << " " << xref[i].y << "\n";
            std::cout << "Actual  : " << xout[i].x << " " << xout[i].y << "\n";
            return 1;
        }
    }

    std::cout << "TESTS PASSED!" << std::endl;

    hipFree(x);
    rocfft_plan_destroy(plan);
    rocfft_cleanup();
}