summarylogtreecommitdiffstats
path: root/linux-compile.patch
blob: 30bbba91d9cd8adfa1448e6c06df8274f35c2cff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
diff --git a/vs_mxnet/vsMXNet.cpp b/vs_mxnet/vsMXNet.cpp
index 6190ff2..f9e1184 100644
--- a/vs_mxnet/vsMXNet.cpp
+++ b/vs_mxnet/vsMXNet.cpp
@@ -3,20 +3,10 @@
 #include <algorithm>
 #include <vector>
 
-#include <VapourSynth/VapourSynth.h>
-#include <VapourSynth/VSHelper.h>
+#include <VapourSynth.h>
+#include <VSHelper.h>
 
-#include "MXDll.h"
-
-#ifdef _MSC_VER
-#if defined (_WINDEF_) && defined(min) && defined(max)
-#undef min
-#undef max
-#endif
-#ifndef NOMINMAX
-#define NOMINMAX
-#endif
-#endif
+#include <mxnet/c_predict_api.h>
 
 // no int8 and uint16
 inline int VSFormatToMXDtype(const VSFormat *format)
@@ -81,18 +71,16 @@ std::vector<char> ReadFile(const std::string &file_path)
     return buf;
 }
 
-MXNet mx("libmxnet.dll");
-
 inline int mxForward(mxnetData * VS_RESTRICT d)
 {
     int ch = d->vi.format->numPlanes;
     auto imageSize = d->patch_h * d->patch_w * ch;
 
-    if (mx.MXPredSetInput(d->hPred, "data", (float *)d->srcBuffer, imageSize) != 0) {
+    if (MXPredSetInput(d->hPred, "data", (float *)d->srcBuffer, imageSize) != 0) {
         return 2;
     }
 
-    if (mx.MXPredForward(d->hPred) != 0) {
+    if (MXPredForward(d->hPred) != 0) {
         return 2;
     }
 
@@ -102,7 +90,7 @@ inline int mxForward(mxnetData * VS_RESTRICT d)
     uint32_t shape_len = 0;
 
     // Get Output Result
-    if (mx.MXPredGetOutputShape(d->hPred, output_index, &shape, &shape_len) != 0) {
+    if (MXPredGetOutputShape(d->hPred, output_index, &shape, &shape_len) != 0) {
         return 2;
     }
 
@@ -113,7 +101,7 @@ inline int mxForward(mxnetData * VS_RESTRICT d)
         return 1;
     }
 
-    if (mx.MXPredGetOutput(d->hPred, output_index, (float *)d->dstBuffer, outputSize) != 0) {
+    if (MXPredGetOutput(d->hPred, output_index, (float *)d->dstBuffer, outputSize) != 0) {
         return 2;
     }
 
@@ -208,7 +196,7 @@ static const VSFrameRef *VS_CC mxGetFrame(int n, int activationReason, void **in
                 err = "mxnet: input and target shapes do not match";
             else if (error == 2) {
                 err = "mxnet: failed to process: ";
-                err += mx.MXGetLastError();
+                err += MXGetLastError();
             }
             else if (error == 3)
                 err = "mxnet: not support clip format";
@@ -231,7 +219,7 @@ static void VS_CC mxFree(void *instanceData, VSCore *core, const VSAPI *vsapi)
     mxnetData *d = static_cast<mxnetData *>(instanceData);
     vsapi->freeNode(d->node);
 
-    mx.MXPredFree(d->hPred);
+    MXPredFree(d->hPred);
 
     vs_aligned_free(d->srcBuffer);
     vs_aligned_free(d->dstBuffer);
@@ -439,19 +427,11 @@ static void VS_CC mxCreate(const VSMap *in, VSMap *out, void *userData, VSCore *
 
         d.hPred = nullptr;
 
-        if (!mx.IsInit()) {
-            mx.LoadDll(nullptr);
-        }
-
-        if (!mx.IsInit()) {
-            throw std::string{ "Cannot load MXNet. Please check MXNet installation." };
-        }
-
         const char *arg_dtype_names[] = { "data" };
         int arg_dtype[1] = { input_dtype };
 
         // Create Predictor
-        if (mx.MXPredCreateEx(
+        if (MXPredCreateEx(
             json_data.data(), param_data.data(), 
             static_cast<int>(param_data.size()),
             dev_type, dev_id,
@@ -459,11 +439,11 @@ static void VS_CC mxCreate(const VSMap *in, VSMap *out, void *userData, VSCore *
             input_keys, input_shape_indptr, input_shape_data,
             1, arg_dtype_names, arg_dtype,
             &d.hPred) != 0) {
-            throw std::string{ "Create MXNet Predictor failed: "} + mx.MXGetLastError();
+            throw std::string{ "Create MXNet Predictor failed: "} + MXGetLastError();
         }
 
         if (d.hPred == nullptr) {
-            throw std::string{ "Invalid MXNet Predictor:" } + mx.MXGetLastError() + " Please Try to Upgrade MXNet.";
+            throw std::string{ "Invalid MXNet Predictor:" } + MXGetLastError() + " Please Try to Upgrade MXNet.";
         }
     } catch (const std::string & error) {
         vsapi->setError(out, ("mxnet: " + error).c_str());