diff --git a/vs_mxnet/vsMXNet.cpp b/vs_mxnet/vsMXNet.cpp index 6190ff2..f9e1184 100644 --- a/vs_mxnet/vsMXNet.cpp +++ b/vs_mxnet/vsMXNet.cpp @@ -3,20 +3,10 @@ #include #include -#include -#include +#include +#include -#include "MXDll.h" - -#ifdef _MSC_VER -#if defined (_WINDEF_) && defined(min) && defined(max) -#undef min -#undef max -#endif -#ifndef NOMINMAX -#define NOMINMAX -#endif -#endif +#include // no int8 and uint16 inline int VSFormatToMXDtype(const VSFormat *format) @@ -81,18 +71,16 @@ std::vector ReadFile(const std::string &file_path) return buf; } -MXNet mx("libmxnet.dll"); - inline int mxForward(mxnetData * VS_RESTRICT d) { int ch = d->vi.format->numPlanes; auto imageSize = d->patch_h * d->patch_w * ch; - if (mx.MXPredSetInput(d->hPred, "data", (float *)d->srcBuffer, imageSize) != 0) { + if (MXPredSetInput(d->hPred, "data", (float *)d->srcBuffer, imageSize) != 0) { return 2; } - if (mx.MXPredForward(d->hPred) != 0) { + if (MXPredForward(d->hPred) != 0) { return 2; } @@ -102,7 +90,7 @@ inline int mxForward(mxnetData * VS_RESTRICT d) uint32_t shape_len = 0; // Get Output Result - if (mx.MXPredGetOutputShape(d->hPred, output_index, &shape, &shape_len) != 0) { + if (MXPredGetOutputShape(d->hPred, output_index, &shape, &shape_len) != 0) { return 2; } @@ -113,7 +101,7 @@ inline int mxForward(mxnetData * VS_RESTRICT d) return 1; } - if (mx.MXPredGetOutput(d->hPred, output_index, (float *)d->dstBuffer, outputSize) != 0) { + if (MXPredGetOutput(d->hPred, output_index, (float *)d->dstBuffer, outputSize) != 0) { return 2; } @@ -208,7 +196,7 @@ static const VSFrameRef *VS_CC mxGetFrame(int n, int activationReason, void **in err = "mxnet: input and target shapes do not match"; else if (error == 2) { err = "mxnet: failed to process: "; - err += mx.MXGetLastError(); + err += MXGetLastError(); } else if (error == 3) err = "mxnet: not support clip format"; @@ -231,7 +219,7 @@ static void VS_CC mxFree(void *instanceData, VSCore *core, const VSAPI *vsapi) mxnetData *d = static_cast(instanceData); vsapi->freeNode(d->node); - mx.MXPredFree(d->hPred); + MXPredFree(d->hPred); vs_aligned_free(d->srcBuffer); vs_aligned_free(d->dstBuffer); @@ -439,19 +427,11 @@ static void VS_CC mxCreate(const VSMap *in, VSMap *out, void *userData, VSCore * d.hPred = nullptr; - if (!mx.IsInit()) { - mx.LoadDll(nullptr); - } - - if (!mx.IsInit()) { - throw std::string{ "Cannot load MXNet. Please check MXNet installation." }; - } - const char *arg_dtype_names[] = { "data" }; int arg_dtype[1] = { input_dtype }; // Create Predictor - if (mx.MXPredCreateEx( + if (MXPredCreateEx( json_data.data(), param_data.data(), static_cast(param_data.size()), dev_type, dev_id, @@ -459,11 +439,11 @@ static void VS_CC mxCreate(const VSMap *in, VSMap *out, void *userData, VSCore * input_keys, input_shape_indptr, input_shape_data, 1, arg_dtype_names, arg_dtype, &d.hPred) != 0) { - throw std::string{ "Create MXNet Predictor failed: "} + mx.MXGetLastError(); + throw std::string{ "Create MXNet Predictor failed: "} + MXGetLastError(); } if (d.hPred == nullptr) { - throw std::string{ "Invalid MXNet Predictor:" } + mx.MXGetLastError() + " Please Try to Upgrade MXNet."; + throw std::string{ "Invalid MXNet Predictor:" } + MXGetLastError() + " Please Try to Upgrade MXNet."; } } catch (const std::string & error) { vsapi->setError(out, ("mxnet: " + error).c_str());