diff --git a/cmake/external/dnnl.cmake b/cmake/external/dnnl.cmake
index 175ad41b6f0..397c4d6abeb 100644
--- a/cmake/external/dnnl.cmake
+++ b/cmake/external/dnnl.cmake
@@ -2,16 +2,16 @@ include (ExternalProject)
 set(DNNL_URL https://github.com/oneapi-src/onednn.git)
 # If DNNL_TAG is updated, check if MKLML_VERSION and platform.cmake.patch need to be updated.
-set(DNNL_TAG v2.7.1)
+set(DNNL_TAG v3.0)
 
 if(WIN32)
   set(DNNL_SHARED_LIB dnnl.dll)
   set(DNNL_IMPORT_LIB dnnl.lib)
 else()
   if (APPLE)
-    set(DNNL_SHARED_LIB libdnnl.2.dylib)
+    set(DNNL_SHARED_LIB libdnnl.3.dylib)
  else()
-    set(DNNL_SHARED_LIB libdnnl.so.2)
+    set(DNNL_SHARED_LIB libdnnl.so.3)
   endif()
 endif()
diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc
index c147a0f4923..c6ee7e9f451 100644
--- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc
+++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc
@@ -345,7 +345,7 @@ Status DnnlExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
     for (size_t i = 0; i < context_num_outputs; i++) {
       auto output_name = subgraph_primitive->GetOrderedOutputs()[i];
       auto output_md = subgraph_primitive->GetOutputInfo(output_name);
-      auto output_shape = output_md.dims();
+      auto output_shape = output_md.get_dims();
       //if an output is a scaler, onednn internally uses tensor representation (eg, (1,1,...))
       //but allocating an output with no shape instead of the equivalent tensorshape to avoid shape mismatch
       if (subgraph_primitive->IsScalarOutput(output_name)) {
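Note (not part of the patch): nearly every file below contains the same mechanical change, because oneDNN v3.x renames the `dnnl::memory::desc` query methods and removes direct access to the underlying C struct. A minimal sketch of the renames, with placeholder names:

```cpp
#include "dnnl.hpp"

// Sketch of the v2.x -> v3.x memory::desc accessor renames used throughout
// this patch. `md` is any existing descriptor.
void accessor_rename_sketch(const dnnl::memory::desc& md) {
  dnnl::memory::dims dims = md.get_dims();          // v2.x: md.dims()
  dnnl::memory::data_type dt = md.get_data_type();  // v2.x: md.data_type()
  dnnl::memory::dims strides = md.get_strides();    // v2.x: md.data.format_desc.blocking.strides
  (void)dims; (void)dt; (void)strides;
}
```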
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_batchnorm.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_batchnorm.cc
index a91c5dfc8d5..0e8e9a2f7ad 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_batchnorm.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_batchnorm.cc
@@ -26,7 +26,7 @@ void DnnlBatchNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
 
   auto batchnorm_scale_mem = sp.GetMemory(node.Input(IN_SCALE));
   auto scale_md = batchnorm_scale_mem.get_desc();
-  auto scale_dims = scale_md.dims();
+  auto scale_dims = scale_md.get_dims();
 
   auto batchnorm_bias_mem = sp.GetMemory(node.Input(IN_B));
   auto bias_md = batchnorm_bias_mem.get_desc();
@@ -37,41 +37,30 @@ void DnnlBatchNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto batchnorm_var_mem = sp.GetMemory(node.Input(IN_VAR));
   auto var_md = batchnorm_var_mem.get_desc();
 
+  // Primitive desc info
+  auto dst_md = dnnl::memory::desc(src_md.get_dims(), src_md.get_data_type(), dnnl::memory::format_tag::any);
+  auto flags = dnnl::normalization_flags::use_scale
+               | dnnl::normalization_flags::use_shift
+               | dnnl::normalization_flags::use_global_stats;
-  std::vector<dnnl::memory::desc> src_mds;
-  src_mds.push_back(scale_md);
-  src_mds.push_back(bias_md);
-  const int axis = 0;
-
-  //To make the inputs compatible with OneDNN, we need to concatenate scale and bias into a single tensor of length 2XC
-  //Then, we create the batchnorm pd and feed in the inputs.
-  auto concat_pd = dnnl::concat::primitive_desc(axis, src_mds, dnnl_engine);
-
-  //If using GPU this will move the memory from the CPU to the GPU.
-  batchnorm_scale_mem = sp.GetMemoryAndReshape(node.Input(IN_SCALE), concat_pd.src_desc(), dnnl_engine);
-  batchnorm_bias_mem = sp.GetMemoryAndReshape(node.Input(IN_B), concat_pd.src_desc(), dnnl_engine);
-  batchnorm_mean_mem = sp.GetMemoryAndReshape(node.Input(IN_MEAN), mean_md, dnnl_engine);
-  batchnorm_var_mem = sp.GetMemoryAndReshape(node.Input(IN_VAR), var_md, dnnl_engine);
-  auto batchnorm_scale_shift_mem = dnnl::memory(concat_pd.dst_desc(), dnnl_engine);
-
-  auto batchnorm_desc = dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, src_md, epsilon,
-                                                                dnnl::normalization_flags::use_scale_shift | dnnl::normalization_flags::use_global_stats);
-  auto batchnorm_pd = dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, dnnl_engine);
+  auto batchnorm_pd =
+      dnnl::batch_normalization_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_inference,
+                                                        src_md, dst_md, epsilon, flags);
 
   // If using GPU this will move the memory from the CPU to the GPU.
   batchnorm_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), batchnorm_pd.src_desc(), dnnl_engine);
+  batchnorm_scale_mem = sp.GetMemoryAndReshape(node.Input(IN_SCALE), scale_md, dnnl_engine);
+  batchnorm_bias_mem = sp.GetMemoryAndReshape(node.Input(IN_B), bias_md, dnnl_engine);
+  batchnorm_mean_mem = sp.GetMemoryAndReshape(node.Input(IN_MEAN), mean_md, dnnl_engine);
+  batchnorm_var_mem = sp.GetMemoryAndReshape(node.Input(IN_VAR), var_md, dnnl_engine);
   auto batchnorm_dst_mem = dnnl::memory(batchnorm_pd.dst_desc(), dnnl_engine);
 
-  auto concat_op = dnnl::concat(concat_pd);
-  sp.AddPrimitive(concat_op, {{DNNL_ARG_MULTIPLE_SRC, batchnorm_scale_mem},
-                              {DNNL_ARG_MULTIPLE_SRC+1, batchnorm_bias_mem},
-                              {DNNL_ARG_DST, batchnorm_scale_shift_mem}});
-
   auto batchnorm_op = dnnl::batch_normalization_forward(batchnorm_pd);
   sp.AddPrimitive(batchnorm_op, {{DNNL_ARG_SRC, batchnorm_src_mem},
                                  {DNNL_ARG_MEAN, batchnorm_mean_mem},
                                  {DNNL_ARG_VARIANCE, batchnorm_var_mem},
-                                 {DNNL_ARG_SCALE_SHIFT, batchnorm_scale_shift_mem},
+                                 {DNNL_ARG_SCALE, batchnorm_scale_mem},
+                                 {DNNL_ARG_SHIFT, batchnorm_bias_mem},
                                  {DNNL_ARG_DST, batchnorm_dst_mem}});
 
   sp.SetMemory(node.Output(OUT_Y), batchnorm_dst_mem);
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_binary.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_binary.cc
index 6445aeaec8c..0d845ce2ebf 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_binary.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_binary.cc
@@ -19,8 +19,8 @@ void DnnlBinary::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto src_0_ori_md = binary_src0_mem.get_desc();
   auto src_1_ori_md = binary_src1_mem.get_desc();
 
-  auto src_0_dims = src_0_ori_md.dims();
-  auto src_1_dims = src_1_ori_md.dims();
+  auto src_0_dims = src_0_ori_md.get_dims();
+  auto src_1_dims = src_1_ori_md.get_dims();
   if (src_0_dims.size() != src_1_dims.size()) {
     while (src_0_dims.size() < src_1_dims.size()) {
       src_0_dims.insert(src_0_dims.begin(), 1);
@@ -42,8 +42,7 @@ void DnnlBinary::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
 
-  auto binary_d = dnnl::binary::desc(algo, src_0_md, src_1_md, dst_md);
-  auto binary_pd = dnnl::binary::primitive_desc(binary_d, eng);
+  auto binary_pd = dnnl::binary::primitive_desc(eng, algo, src_0_md, src_1_md, dst_md);
 
   auto binary_dst_mem = dnnl::memory(binary_pd.dst_desc(), eng);
   auto binary_prim = dnnl::binary(binary_pd);
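Note (not part of the patch): the BatchNormalization rewrite above drops the v2.x `use_scale_shift` flag, which forced the provider to concatenate scale and bias into a single 2xC tensor; v3.x takes separate `use_scale`/`use_shift` flags and separate `DNNL_ARG_SCALE`/`DNNL_ARG_SHIFT` arguments. A minimal sketch of the v3.x pattern, with engine, stream, and input memories assumed:

```cpp
#include "dnnl.hpp"

// Sketch: v3.x batch norm with separate scale and shift tensors.
void batchnorm_v3_sketch(const dnnl::engine& eng, dnnl::stream& strm,
                         const dnnl::memory::desc& src_md, float epsilon,
                         dnnl::memory& src, dnnl::memory& scale, dnnl::memory& shift,
                         dnnl::memory& mean, dnnl::memory& var) {
  auto dst_md = dnnl::memory::desc(src_md.get_dims(), src_md.get_data_type(),
                                   dnnl::memory::format_tag::any);
  auto flags = dnnl::normalization_flags::use_scale |
               dnnl::normalization_flags::use_shift |
               dnnl::normalization_flags::use_global_stats;
  auto pd = dnnl::batch_normalization_forward::primitive_desc(
      eng, dnnl::prop_kind::forward_inference, src_md, dst_md, epsilon, flags);
  auto dst = dnnl::memory(pd.dst_desc(), eng);
  dnnl::batch_normalization_forward(pd).execute(
      strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_MEAN, mean}, {DNNL_ARG_VARIANCE, var},
             {DNNL_ARG_SCALE, scale}, {DNNL_ARG_SHIFT, shift}, {DNNL_ARG_DST, dst}});
}
```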
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc
index 1a21d290e2b..9100b16377f 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc
@@ -19,7 +19,7 @@ void DnnlCast::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto src_mem = sp.GetMemory(node.Input(IN_INPUT));
   auto src_tag = node.Input(IN_INPUT).Format();
   auto src_md = src_mem.get_desc();
-  auto src_dims = src_md.dims();
+  auto src_dims = src_md.get_dims();
 
   // dst characteristics
   dnnl::memory::data_type dst_type;
@@ -71,7 +71,7 @@ void DnnlCast::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   }
 
   // Generate the dst memory descriptor
-  auto dst_md = dnnl::memory::desc(src_md.dims(), dst_type, dst_tag);
+  auto dst_md = dnnl::memory::desc(src_md.get_dims(), dst_type, dst_tag);
 
   // Create the reorder primitive descriptor.
   auto reorder_pd = dnnl::reorder::primitive_desc(dnnl_engine, src_md, dnnl_engine, dst_md);
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_concat.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_concat.cc
index fcc72621b41..5ca4f24eef1 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_concat.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_concat.cc
@@ -31,7 +31,7 @@ void DnnlConcat::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto axis = GetAxis(node, input_rank != -1 ? input_rank : 0);
 
   // Create primitive descriptor
-  auto concat_pd = dnnl::concat::primitive_desc(static_cast<int>(axis), src_mds, dnnl_engine);
+  auto concat_pd = dnnl::concat::primitive_desc(dnnl_engine, static_cast<int>(axis), src_mds);
 
   // Create primitive memory objects
   std::vector<dnnl::memory> concat_src_mems;
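Note (not part of the patch): the other recurring change is that oneDNN v3.x removes the per-primitive `*::desc` structs; `primitive_desc` constructors now take the engine first, the former desc arguments next, and an optional `primitive_attr` last. A minimal sketch with a binary Add, descriptors assumed:

```cpp
#include "dnnl.hpp"

// Sketch: v3.x primitive_desc construction, no intermediate ::desc object.
dnnl::binary make_add(const dnnl::engine& eng,
                      const dnnl::memory::desc& src0_md,
                      const dnnl::memory::desc& src1_md,
                      const dnnl::memory::desc& dst_md) {
  // v2.x: auto d  = dnnl::binary::desc(binary_add, src0_md, src1_md, dst_md);
  //       auto pd = dnnl::binary::primitive_desc(d, eng);
  auto pd = dnnl::binary::primitive_desc(eng, dnnl::algorithm::binary_add,
                                         src0_md, src1_md, dst_md);
  return dnnl::binary(pd);
}
```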
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.cc
index a076633ca8d..a9d2d3eb6f3 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.cc
@@ -21,13 +21,13 @@ void DnnlConv::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto conv_src_mem = sp.GetMemory(node.Input(IN_X));
   auto src_md = conv_src_mem.get_desc();
-  src_md.data.format_kind = dnnl_format_kind_t::dnnl_format_kind_any;
-  auto src_dims = conv_src_mem.get_desc().dims();
+  src_md = dnnl::memory::desc(src_md.get_dims(), src_md.get_data_type(), dnnl::memory::format_tag::any);
+  auto src_dims = conv_src_mem.get_desc().get_dims();
 
   auto conv_weights_mem = sp.GetMemory(node.Input(IN_W));
   auto weight_md = conv_weights_mem.get_desc();
-  weight_md.data.format_kind = dnnl_format_kind_t::dnnl_format_kind_any;
-  auto weight_dims_original = conv_weights_mem.get_desc().dims();
+  weight_md = dnnl::memory::desc(weight_md.get_dims(), weight_md.get_data_type(), dnnl::memory::format_tag::any);
+  auto weight_dims_original = conv_weights_mem.get_desc().get_dims();
   dnnl::memory::dims weight_dims = weight_dims_original;
 
   bool bias_exists = node.Input(IN_B).Exists();
@@ -97,27 +97,20 @@ void DnnlConv::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   dnnl::primitive_attr attr;
   if (has_relu) {
-    const float ops_scale = 1.f;
-    const float ops_alpha = 0.f;
-    const float ops_beta = 0.f;
     dnnl::post_ops ops;
-    ops.append_eltwise(ops_scale, dnnl::algorithm::eltwise_relu, ops_alpha, ops_beta);
+    ops.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f);
     attr.set_post_ops(ops);
   }
 
   dnnl::convolution_forward::primitive_desc conv_pd;
   if (bias_exists) {
-    auto conv_desc = dnnl::convolution_forward::desc(
-        prop_kind, dnnl::algorithm::convolution_direct,
-        src_md, weight_md, bias_md, dst_md,
-        strides, dilations, padding_left, padding_right);
-    conv_pd = dnnl::convolution_forward::primitive_desc(conv_desc, attr, dnnl_engine);
+    conv_pd = dnnl::convolution_forward::primitive_desc(dnnl_engine, prop_kind, dnnl::algorithm::convolution_direct,
+                                                        src_md, weight_md, bias_md, dst_md, strides, dilations,
+                                                        padding_left, padding_right, attr);
   } else {
-    auto conv_desc = dnnl::convolution_forward::desc(
-        prop_kind, dnnl::algorithm::convolution_direct,
-        src_md, weight_md, dst_md,
-        strides, dilations, padding_left, padding_right);
-    conv_pd = dnnl::convolution_forward::primitive_desc(conv_desc, attr, dnnl_engine);
+    conv_pd = dnnl::convolution_forward::primitive_desc(dnnl_engine, prop_kind, dnnl::algorithm::convolution_direct,
+                                                        src_md, weight_md, dst_md, strides, dilations, padding_left,
+                                                        padding_right, attr);
   }
 
   // If using GPU this will move the memory from the CPU to the GPU.
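Note (not part of the patch): `post_ops::append_eltwise` loses its leading scale argument in v3.x, which is why the fused Conv+ReLU above no longer passes `ops_scale`. A minimal sketch of the fused attribute, with descriptors and geometry assumed:

```cpp
#include "dnnl.hpp"

// Sketch: ReLU post-op fused into a convolution, v3.x style.
dnnl::convolution_forward::primitive_desc make_conv_relu(
    const dnnl::engine& eng, const dnnl::memory::desc& src_md,
    const dnnl::memory::desc& wei_md, const dnnl::memory::desc& dst_md,
    const dnnl::memory::dims& strides, const dnnl::memory::dims& dilations,
    const dnnl::memory::dims& pad_l, const dnnl::memory::dims& pad_r) {
  dnnl::post_ops ops;
  // v2.x: ops.append_eltwise(/*scale=*/1.f, eltwise_relu, 0.f, 0.f);
  ops.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f);
  dnnl::primitive_attr attr;
  attr.set_post_ops(ops);
  return dnnl::convolution_forward::primitive_desc(
      eng, dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct,
      src_md, wei_md, dst_md, strides, dilations, pad_l, pad_r, attr);
}
```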
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.cc
index 1208f206d7f..d8a245b5f7f 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.cc
@@ -49,15 +49,15 @@ void DnnlConvGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto dy_mem = sp.GetMemory(node.Input(IN_DY));
   auto dy_md = dy_mem.get_desc();
-  auto dy_dims = dy_mem.get_desc().dims();
+  auto dy_dims = dy_mem.get_desc().get_dims();
 
   auto x_mem = sp.GetMemory(node.Input(IN_X));
   auto x_md = x_mem.get_desc();
-  auto x_dims = x_mem.get_desc().dims();
+  auto x_dims = x_mem.get_desc().get_dims();
 
   auto w_mem = sp.GetMemory(node.Input(IN_W));
   auto w_md = w_mem.get_desc();
-  auto w_dims_original = w_mem.get_desc().dims();
+  auto w_dims_original = w_mem.get_desc().get_dims();
   auto w_dims = w_dims_original;
 
   bool dx_required = node.Output(OUT_DX).Exists();
@@ -122,37 +122,39 @@ void DnnlConvGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   // Reproduce the forward convolution pd.
   dnnl::convolution_forward::primitive_desc conv_forward_pd;
   if (db_required) {
-    auto conv_forward_desc = dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training,
-                                                             dnnl::algorithm::convolution_direct,
-                                                             fwd_x_md, w_md, fwd_b_md, fwd_y_md,
-                                                             strides, dilations, padding_left, padding_right);
-    conv_forward_pd = dnnl::convolution_forward::primitive_desc(conv_forward_desc, dnnl_engine);
+    conv_forward_pd = dnnl::convolution_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_training,
+                                                                dnnl::algorithm::convolution_direct,
+                                                                fwd_x_md, w_md, fwd_b_md, fwd_y_md,
+                                                                strides, dilations, padding_left, padding_right);
   } else {
-    auto conv_forward_desc = dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training,
-                                                             dnnl::algorithm::convolution_direct,
-                                                             fwd_x_md, w_md, fwd_y_md,
-                                                             strides, dilations, padding_left, padding_right);
-    conv_forward_pd = dnnl::convolution_forward::primitive_desc(conv_forward_desc, dnnl_engine);
+
+    conv_forward_pd = dnnl::convolution_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_training,
+                                                                dnnl::algorithm::convolution_direct,
+                                                                fwd_x_md, w_md, fwd_y_md,
+                                                                strides, dilations, padding_left, padding_right);
   }
 
   // Create the convolution backward data primitive desc
-  auto conv_backward_data_desc = dnnl::convolution_backward_data::desc(dnnl::algorithm::convolution_direct,
-                                                                       dx_md, w_md, dy_md,
-                                                                       strides, dilations, padding_left, padding_right);
-  auto conv_backward_data_pd = dnnl::convolution_backward_data::primitive_desc(conv_backward_data_desc, dnnl_engine, conv_forward_pd);
+  auto conv_backward_data_pd =
+      dnnl::convolution_backward_data::primitive_desc(dnnl_engine, dnnl::algorithm::convolution_direct,
+                                                      dx_md, w_md, dy_md, strides, dilations, padding_left,
+                                                      padding_right, conv_forward_pd);
 
   // Create the convolution backward weights primitve desc
   dnnl::convolution_backward_weights::primitive_desc conv_backward_weights_pd;
   if (db_required) {
-    auto conv_backward_weights_desc = dnnl::convolution_backward_weights::desc(dnnl::algorithm::convolution_direct,
-                                                                               x_md, dw_md, db_md, dy_md,
-                                                                               strides, dilations, padding_left, padding_right);
-    conv_backward_weights_pd = dnnl::convolution_backward_weights::primitive_desc(conv_backward_weights_desc, dnnl_engine, conv_forward_pd);
+
+    conv_backward_weights_pd =
+        dnnl::convolution_backward_weights::primitive_desc(dnnl_engine, dnnl::algorithm::convolution_direct,
+                                                           x_md, dw_md, db_md, dy_md,
+                                                           strides, dilations, padding_left,
+                                                           padding_right, conv_forward_pd);
   } else {
-    auto conv_backward_weights_desc = dnnl::convolution_backward_weights::desc(dnnl::algorithm::convolution_direct,
-                                                                               x_md, dw_md, dy_md,
-                                                                               strides, dilations, padding_left, padding_right);
-    conv_backward_weights_pd = dnnl::convolution_backward_weights::primitive_desc(conv_backward_weights_desc, dnnl_engine, conv_forward_pd);
+    conv_backward_weights_pd =
+        dnnl::convolution_backward_weights::primitive_desc(dnnl_engine, dnnl::algorithm::convolution_direct,
+                                                           x_md, dw_md, dy_md,
+                                                           strides, dilations, padding_left, padding_right,
+                                                           conv_forward_pd);
   }
 
   // check if memory needs to be moved to GPU
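Note (not part of the patch): backward primitive descriptors follow the same convention in v3.x: engine first, then the former desc arguments, with the forward pd hint moved to the end. A minimal sketch for the backward-data case above, descriptors assumed:

```cpp
#include "dnnl.hpp"

// Sketch: v3.x backward pd takes the engine first and the forward pd as a
// trailing hint, instead of wrapping a separate ::desc object.
dnnl::convolution_backward_data::primitive_desc make_bwd_data_pd(
    const dnnl::engine& eng, const dnnl::memory::desc& dx_md,
    const dnnl::memory::desc& w_md, const dnnl::memory::desc& dy_md,
    const dnnl::memory::dims& strides, const dnnl::memory::dims& dilations,
    const dnnl::memory::dims& pad_l, const dnnl::memory::dims& pad_r,
    const dnnl::convolution_forward::primitive_desc& fwd_pd) {
  return dnnl::convolution_backward_data::primitive_desc(
      eng, dnnl::algorithm::convolution_direct, dx_md, w_md, dy_md,
      strides, dilations, pad_l, pad_r, fwd_pd);
}
```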
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.cc
index cde20fdaca2..074df058806 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.cc
@@ -47,7 +47,7 @@ void DnnlDequantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode&
   // Get descs
   auto x_md = x_mem.get_desc();
   auto x_scale_md = x_scale_mem.get_desc();
-  auto x_dims = x_md.dims().size();
+  auto x_dims = x_md.get_dims().size();
 
   // Fix scale dims
   int64_t axis = GetAxis(node, x_dims);
@@ -65,11 +65,11 @@ void DnnlDequantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode&
   }
 
   // Create dst mem
-  auto dst_md = dnnl::memory::desc(x_md.dims(), node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
+  auto dst_md = dnnl::memory::desc(x_md.get_dims(), node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
   dnnl::memory dst_mem;
 
   // If zero point exists and we are NOT dequantizing int32, then substract zp from x and scale
-  if (isZeroPointUseful && (x_mem.get_desc().data_type() != dnnl::memory::data_type::s32)) {
+  if (isZeroPointUseful && (x_mem.get_desc().get_data_type() != dnnl::memory::data_type::s32)) {
     // Get Zero point
     auto x_zp_mem = sp.GetMemory(node.Input(IN_X_ZERO_POINT));
     // Get mds for operands
@@ -84,8 +84,6 @@ void DnnlDequantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode&
       Padd(&x_zp_md, static_cast<size_t>(axis) + 1, x_dims);
     }
 
-    // Create binary desc
-    auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_sub, x_md, x_zp_md, dst_md);
     // Add post op scale
     dnnl::primitive_attr binary_attr;
     {
@@ -94,7 +92,8 @@ void DnnlDequantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode&
       binary_attr.set_post_ops(binary_ops);
     }
     // Add post op to scale result
-    auto binary_pd = dnnl::binary::primitive_desc(binary_d, binary_attr, dnnl_engine);
+    auto binary_pd = dnnl::binary::primitive_desc(dnnl_engine, dnnl::algorithm::binary_sub,
+                                                  x_md, x_zp_md, dst_md, binary_attr);
     // Move to GPU if available
     x_zp_mem = sp.GetMemoryAndReshape(node.Input(IN_X_ZERO_POINT), x_zp_md, dnnl_engine);
     // Create primitive and set dst mem
@@ -108,9 +107,9 @@ void DnnlDequantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode&
 
   // If zp doesn't exists or we are dequantizing from int32, only need to scale
   } else {
-    // Create binary and primitive desc
-    auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_mul, x_md, x_scale_md, dst_md);
-    auto binary_pd = dnnl::binary::primitive_desc(binary_d, dnnl_engine);
+    // Create binary primitive desc
+    auto binary_pd = dnnl::binary::primitive_desc(dnnl_engine, dnnl::algorithm::binary_mul,
+                                                  x_md, x_scale_md, dst_md);
 
     // Create primitive
     dst_mem = dnnl::memory(binary_pd.dst_desc(), dnnl_engine);
@@ -133,8 +132,8 @@ bool DnnlDequantizeLinear::isZeroPointNonZero(dnnl::memory* zp_mem) {
   // Because zp will always be int8, uint8 or int32, this cast is always valid
   auto zp_data = static_cast<uint8_t*>(zp_mem->get_data_handle());
   // Adjust the iteration num
-  auto topline = zp_mem->get_desc().dims().size();
-  if (zp_mem->get_desc().data_type() == dnnl::memory::data_type::s32) {
+  auto topline = zp_mem->get_desc().get_dims().size();
+  if (zp_mem->get_desc().get_data_type() == dnnl::memory::data_type::s32) {
     topline *= 4;
   }
 
   // ZP is either a scalar or a 1-D vector so iterate over all the dimensions
@@ -150,7 +149,7 @@ bool DnnlDequantizeLinear::isZeroPointNonZero(dnnl::memory* zp_mem) {
 
 void DnnlDequantizeLinear::Padd(dnnl::memory::desc* target_md, size_t front_pad, size_t back_pad) {
   // Pads an input to broadcast the op correctly
-  auto target_dims = target_md->dims();
+  auto target_dims = target_md->get_dims();
 
   // Add front padding
   while (target_dims.size() < front_pad) {
@@ -185,8 +184,8 @@ int64_t DnnlDequantizeLinear::GetAxis(DnnlNode& node, size_t x_dims) {
 void DnnlDequantizeLinear::ValidateDims(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   // We only need to validate when zp is provided
   if (node.Input(IN_X_ZERO_POINT).Exists()) {
-    auto x_scale_dims = sp.GetMemory(node.Input(IN_X_SCALE)).get_desc().dims();
-    auto x_zp_dims = sp.GetMemory(node.Input(IN_X_ZERO_POINT)).get_desc().dims();
+    auto x_scale_dims = sp.GetMemory(node.Input(IN_X_SCALE)).get_desc().get_dims();
+    auto x_zp_dims = sp.GetMemory(node.Input(IN_X_ZERO_POINT)).get_desc().get_dims();
 
     if (x_zp_dims != x_scale_dims) {
       ORT_THROW("x_scale and x_zero_point dimensions does not match");
@@ -200,7 +199,7 @@ void DnnlDequantizeLinear::ValidateType(DnnlSubgraphPrimitive& sp, DnnlNode& nod
   auto x_md = sp.GetMemory(node.Input(IN_X)).get_desc();
   auto x_zp_md = sp.GetMemory(node.Input(IN_X_ZERO_POINT)).get_desc();
 
-  if (x_md.data_type() != x_zp_md.data_type()) {
+  if (x_md.get_data_type() != x_zp_md.get_data_type()) {
     ORT_THROW("x and x_zero_point have different datatypes");
   }
 }
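Note (not part of the patch): DequantizeLinear is still expressed as y = (x - x_zero_point) * x_scale, i.e. a `binary_sub` whose result is scaled by a `binary_mul` post-op; in v3.x the post-op's second input is supplied at execution time under `DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)`. A minimal sketch, memories assumed:

```cpp
#include "dnnl.hpp"

// Sketch: (x - zp) * scale via binary_sub with a binary_mul post-op.
void dequantize_sketch(const dnnl::engine& eng, dnnl::stream& strm,
                       const dnnl::memory::desc& x_md, const dnnl::memory::desc& zp_md,
                       const dnnl::memory::desc& scale_md, const dnnl::memory::desc& dst_md,
                       dnnl::memory& x, dnnl::memory& zp, dnnl::memory& scale) {
  dnnl::post_ops ops;
  ops.append_binary(dnnl::algorithm::binary_mul, scale_md);  // scale afterwards
  dnnl::primitive_attr attr;
  attr.set_post_ops(ops);
  auto pd = dnnl::binary::primitive_desc(eng, dnnl::algorithm::binary_sub,
                                         x_md, zp_md, dst_md, attr);
  auto dst = dnnl::memory(pd.dst_desc(), eng);
  dnnl::binary(pd).execute(
      strm, {{DNNL_ARG_SRC_0, x}, {DNNL_ARG_SRC_1, zp}, {DNNL_ARG_DST, dst},
             // the post-op input is bound at execute time:
             {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, scale}});
}
```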
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_dynamicquantizelinear.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dynamicquantizelinear.cc
index 1d24e863297..b62cd7cb628 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_dynamicquantizelinear.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dynamicquantizelinear.cc
@@ -4,6 +4,8 @@
 #include "dnnl_dynamicquantizelinear.h"
 #include "dnnl_subgraph.h"
 #include "dnnl_subgraph_primitive.h"
+#include "dnnl_util.h"
+
 
 namespace onnxruntime {
 namespace ort_dnnl {
@@ -23,7 +25,7 @@ void DnnlDynamicQuantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlN
   // Get src mem
   auto x_mem = sp.GetMemory(node.Input(IN_X));
   auto x_md = x_mem.get_desc();
-  auto x_size = x_md.dims().size();
+  auto x_size = x_md.get_dims().size();
   auto x_format = sp.GetDnnlFormat(x_size);
   x_mem = sp.GetMemoryAndReshape(node.Input(IN_X), x_md, eng);
@@ -31,10 +33,8 @@ void DnnlDynamicQuantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlN
   dnnl::memory::dims one_dim(x_size, 1);
 
   // Y_SCALE COMPUTATION
-  // Create descriptor for reduction max and min
-  auto y_scale_md = dnnl::memory::desc(one_dim, x_md.data_type(), x_format);
-  auto max_reduction_d = dnnl::reduction::desc(dnnl::algorithm::reduction_max, x_md, y_scale_md, 0.f, 0.f);
-  auto min_reduction_d = dnnl::reduction::desc(dnnl::algorithm::reduction_min, x_md, y_scale_md, 0.f, 0.f);
+  // Create descriptor for y_scale
+  auto y_scale_md = dnnl::memory::desc(one_dim, x_md.get_data_type(), x_format);
 
   // Fill memory with 0's, needed for min and max binary
   auto zero_mem = dnnl::memory(y_scale_md, eng);
@@ -50,7 +50,7 @@ void DnnlDynamicQuantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlN
     // y_scale = x_max - x_min
     calc_y_scale.append_binary(dnnl::algorithm::binary_sub, y_scale_md);
     // y_scale =/ 255
-    calc_y_scale.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f / 255.0f, 0.0f);
+    calc_y_scale.append_eltwise(dnnl::algorithm::eltwise_linear, 1.0f / 255.0f, 0.0f);
     max_reduction_attr.set_post_ops(calc_y_scale);
   }
@@ -63,8 +63,11 @@ void DnnlDynamicQuantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlN
   }
 
   // Create reduction primitive
-  auto max_reduction_prim = dnnl::reduction(dnnl::reduction::primitive_desc(max_reduction_d, max_reduction_attr, eng));
-  auto min_reduction_prim = dnnl::reduction(dnnl::reduction::primitive_desc(min_reduction_d, min_reduction_attr, eng));
+  auto max_reduction_prim = dnnl::reduction({eng, dnnl::algorithm::reduction_max,
+                                             x_md, y_scale_md, 0.f, 0.f, max_reduction_attr});
+  auto min_reduction_prim = dnnl::reduction({eng, dnnl::algorithm::reduction_min,
+                                             x_md, y_scale_md, 0.f, 0.f, min_reduction_attr});
+
 
   // Create y_scale and min memory
   auto y_scale_mem = dnnl::memory(y_scale_md, eng);
@@ -85,43 +88,48 @@ void DnnlDynamicQuantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlN
   // Y_ZERO_POINT COMPUTATION
   // Create memory and primitive descriptors
   auto y_zp_md = dnnl::memory::desc(one_dim, dnnl::memory::data_type::u8, x_format);
-  auto zp_prim_d = dnnl::binary::desc(dnnl::algorithm::binary_div, y_scale_md, y_scale_md, y_zp_md);
 
   // Add round and clip post ops
   dnnl::primitive_attr zp_prim_attr;
   {
-    zp_prim_attr.set_scales(DNNL_ARG_SRC_0, 0, {-1.0f});
     dnnl::post_ops div_saturate_round;
-    div_saturate_round.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
+    div_saturate_round.append_eltwise(dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
     zp_prim_attr.set_post_ops(div_saturate_round);
   }
+  // Set the value to scale DNNL_ARG_SRC_0 with mask 0
+  zp_prim_attr.set_scales_mask(DNNL_ARG_SRC_0, 0);
+  // Create the memory object related to the scale
+  auto scale_mem = dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, eng);
+  // Write the alpha value into the memory object
+  sp.WriteToDnnlMemory(scale_mem, {-1.0f});
 
   // Create primitives
-  auto zp_prim_pd = dnnl::binary::primitive_desc(zp_prim_d, zp_prim_attr, eng);
+  auto zp_prim_pd = dnnl::binary::primitive_desc(eng, dnnl::algorithm::binary_div,
+                                                 y_scale_md, y_scale_md, y_zp_md, zp_prim_attr);
   auto zp_prim = dnnl::binary(zp_prim_pd);
 
   // Create zp memory dst
   auto y_zp_mem = dnnl::memory(zp_prim_pd.dst_desc(), eng);
 
   // Calc zp
-  sp.AddPrimitive(zp_prim, {{DNNL_ARG_SRC_0, min_reduction_mem},
-                            {DNNL_ARG_SRC_1, y_scale_mem},
-                            {DNNL_ARG_DST, y_zp_mem}});
+  sp.AddPrimitive(zp_prim, {{DNNL_ARG_SRC_0, min_reduction_mem},
+                            {DNNL_ARG_SRC_1, y_scale_mem},
+                            {DNNL_ARG_DST, y_zp_mem},
+                            {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, scale_mem}});
 
   // Y COMPUTATION
   // Create y md and binary desc
-  auto y_md = dnnl::memory::desc(x_md.dims(), dnnl::memory::data_type::u8, x_format);
-  auto y_bin_d = dnnl::binary::desc(dnnl::algorithm::binary_div, x_mem.get_desc(), y_scale_mem.get_desc(), y_md);
+  auto y_md = dnnl::memory::desc(x_md.get_dims(), dnnl::memory::data_type::u8, x_format);
 
   // Add post ops
   dnnl::primitive_attr y_bin_attr;
   {
     dnnl::post_ops round_add;
-    round_add.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
+    round_add.append_eltwise(dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
     round_add.append_binary(dnnl::algorithm::binary_add, y_zp_mem.get_desc());
     y_bin_attr.set_post_ops(round_add);
   }
 
   // Create binary primitive with post ops
-  auto y_pd = dnnl::binary::primitive_desc(y_bin_d, y_bin_attr, eng);
+  auto y_pd = dnnl::binary::primitive_desc(eng, dnnl::algorithm::binary_div, x_mem.get_desc(), y_scale_mem.get_desc(), y_md, y_bin_attr);
   auto y_prim = dnnl::binary(y_pd);
 
   // Create y_dst mem
   auto y_mem = dnnl::memory(y_pd.dst_desc(), eng);
@@ -139,8 +147,8 @@ void DnnlDynamicQuantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlN
 
 //change md to targeted data type of cast op dst
 dnnl::memory::desc DnnlDynamicQuantizeLinear::ChangeMemoryDescDataType(dnnl::memory::desc md, dnnl::memory::data_type dt) {
-  auto dims = md.dims();
-  auto strides = md.data.format_desc.blocking.strides;
+  auto dims = md.get_dims();
+  auto strides = md.get_strides();
   dnnl::memory::dims strides_vec;
   for (size_t i = 0; i < dims.size(); i++) {
     strides_vec.push_back(strides[i]);
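Note (not part of the patch): this file shows the new runtime-scale idiom that replaces v2.x's `set_output_scales`/`set_scales` with inline values: declare the scale on the attr with `set_scales_mask`, then pass the actual value at execution time as an f32 memory under `DNNL_ARG_ATTR_SCALES | <arg>`. A minimal sketch (the provider uses its `WriteToDnnlMemory` helper; a plain `memcpy` is shown here, and the 0.5f value is just an example):

```cpp
#include <cstring>
#include "dnnl.hpp"

// Sketch: runtime scale on SRC for a matmul, v3.x style.
void runtime_scale_sketch(const dnnl::engine& eng, dnnl::stream& strm,
                          const dnnl::memory::desc& src_md,
                          const dnnl::memory::desc& wei_md,
                          const dnnl::memory::desc& dst_md,
                          dnnl::memory& src, dnnl::memory& wei) {
  dnnl::primitive_attr attr;
  attr.set_scales_mask(DNNL_ARG_SRC, 0);  // mask 0: one scale for the tensor
  auto pd = dnnl::matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr);

  // The scale value is supplied at execution time through a memory object.
  auto scale_mem = dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, eng);
  float alpha = 0.5f;  // assumed value
  std::memcpy(scale_mem.get_data_handle(), &alpha, sizeof(alpha));

  auto dst = dnnl::memory(pd.dst_desc(), eng);
  dnnl::matmul(pd).execute(strm, {{DNNL_ARG_SRC, src},
                                  {DNNL_ARG_WEIGHTS, wei},
                                  {DNNL_ARG_DST, dst},
                                  {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}});
}
```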
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc
index 4d825474d8b..a2c8a02f42f 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc
@@ -35,17 +35,25 @@ void DnnlElementwise::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
       }
       break;
     }
+    case dnnl::algorithm::eltwise_soft_relu: {
+      if (node.OpType() == "Softplus") {
+        requires_alpha = true;
+        alpha = 1.0f;
+      }
+      break;
+    }
     default:
       alpha = 0.0;
   }
 
+  // Generate a dst_md from the src data
+  auto dst_md = dnnl::memory::desc(src_md.get_dims(), src_md.get_data_type(), dnnl::memory::format_tag::any);
+
   dnnl::eltwise_forward::primitive_desc elementwise_pd;
   if (requires_alpha) {
-    auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md, alpha);
-    elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine);
+    elementwise_pd = dnnl::eltwise_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_inference, algo, src_md, dst_md, alpha);
  } else {
-    auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md);
-    elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine);
+    elementwise_pd = dnnl::eltwise_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_inference, algo, src_md, dst_md);
   }
 
   // If using GPU this will move the memory from the CPU to the GPU.
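Note (not part of the patch): v3.x `eltwise_forward` requires an explicit destination descriptor (v2.x derived it from the source), which is why this file and several below now build a `dst_md` with `format_tag::any`. A minimal sketch:

```cpp
#include "dnnl.hpp"

// Sketch: v3.x eltwise needs both src and dst descriptors.
dnnl::eltwise_forward::primitive_desc make_relu_pd(const dnnl::engine& eng,
                                                   const dnnl::memory::desc& src_md) {
  auto dst_md = dnnl::memory::desc(src_md.get_dims(), src_md.get_data_type(),
                                   dnnl::memory::format_tag::any);
  return dnnl::eltwise_forward::primitive_desc(
      eng, dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_relu,
      src_md, dst_md, /*alpha=*/0.f, /*beta=*/0.f);
}
```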
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_gelu.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gelu.cc
index 6e1b6fcd5a6..d0df371b488 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_gelu.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gelu.cc
@@ -29,8 +29,8 @@ void DnnlGelu::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
     auto src0_ori_md = src_mem.get_desc();
     auto src1_ori_md = bias_mem.get_desc();
 
-    auto src0_dims = src0_ori_md.dims();
-    auto src1_dims = src1_ori_md.dims();
+    auto src0_dims = src0_ori_md.get_dims();
+    auto src1_dims = src1_ori_md.get_dims();
     if (src0_dims.size() != src1_dims.size()) {
       while (src0_dims.size() < src1_dims.size()) {
         src0_dims.insert(src0_dims.begin(), 1);
@@ -53,13 +53,12 @@ void DnnlGelu::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
     dnnl::primitive_attr attr;
     dnnl::post_ops ops;
     dnnl::algorithm algo = dnnl_util::OrtOperatorToDnnlAlgorithm(node.OpType());
-    ops.append_eltwise(1.0f, algo, 1.0f, 1.0f);
+    ops.append_eltwise(algo, 1.0f, 1.0f);
     attr.set_post_ops(ops);
 
     auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
 
-    auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_md, src1_md, dst_md);
-    auto binary_pd = dnnl::binary::primitive_desc(binary_d, attr, dnnl_engine);
+    auto binary_pd = dnnl::binary::primitive_desc(dnnl_engine, dnnl::algorithm::binary_add, src0_md, src1_md, dst_md, attr);
     dst_mem = dnnl::memory(binary_pd.dst_desc(), dnnl_engine);
 
     auto binary_prim = dnnl::binary(binary_pd);
@@ -68,9 +67,12 @@ void DnnlGelu::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
                                   {DNNL_ARG_SRC_1, bias_mem},
                                   {DNNL_ARG_DST, dst_mem}});
   } else {
+    auto dst_md = dnnl::memory::desc(src_mem.get_desc().get_dims(),
+                                     node.Output(OUT_Y).Type(),
+                                     dnnl::memory::format_tag::any);
     dnnl::algorithm algo = dnnl_util::OrtOperatorToDnnlAlgorithm(node.OpType());
-    auto gelu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, gelu_src_mem.get_desc());
-    auto gelu_pd = dnnl::eltwise_forward::primitive_desc(gelu_desc, dnnl_engine);
+    auto gelu_pd = dnnl::eltwise_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_inference, algo,
+                                                         gelu_src_mem.get_desc(), dst_md);
 
     // If using GPU this will move the memory from the CPU to the GPU.
     gelu_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), gelu_pd.src_desc(), dnnl_engine);
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc
index 6178bbab85b..364ebdf5f22 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc
@@ -4,6 +4,7 @@
 #include "dnnl_gemm.h"
 #include "dnnl_subgraph.h"
 #include "dnnl_subgraph_primitive.h"
+#include "dnnl_util.h"
 
 namespace onnxruntime {
 namespace ort_dnnl {
@@ -56,8 +57,8 @@ OneDNN algorithm:
 void DnnlGemm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto eng = sp.GetEngine();
 
-  auto a_dims = sp.GetMemory(node.Input(IN_A)).get_desc().dims();
-  auto b_dims = sp.GetMemory(node.Input(IN_B)).get_desc().dims();
+  auto a_dims = sp.GetMemory(node.Input(IN_A)).get_desc().get_dims();
+  auto b_dims = sp.GetMemory(node.Input(IN_B)).get_desc().get_dims();
 
   bool input_c_exists = node.Input(IN_C).Exists();
 
@@ -92,14 +93,17 @@ void DnnlGemm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   dnnl::primitive_attr matmul_attr;
   // scale the output from MatMul to alpha
   float alpha = GetAlpha(node);
-  std::vector<float> alphaScale({alpha});
-  matmul_attr.set_output_scales(0, alphaScale);
+  // Set the value to scale DNNL_ARG_SRC with mask 0
+  matmul_attr.set_scales_mask(DNNL_ARG_SRC, 0);
+  // Create the memory object related to the scale
+  auto alpha_mem = dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, eng);
+  // Write the alpha value into the memory object
+  sp.WriteToDnnlMemory(alpha_mem, {alpha});
 
   auto matmul_dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), {N, 1});
-  auto matmul_d = dnnl::matmul::desc(a_md, b_md, matmul_dst_md);
 
   dnnl::matmul::primitive_desc matmul_pd;
-  matmul_pd = dnnl::matmul::primitive_desc(matmul_d, matmul_attr, eng);
+  matmul_pd = dnnl::matmul::primitive_desc(eng, a_md, b_md, matmul_dst_md, matmul_attr);
 
   auto matmul_a_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng, transA);
   auto matmul_b_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng, transB);
@@ -111,12 +115,14 @@ void DnnlGemm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   args.insert({DNNL_ARG_SRC, matmul_a_mem});
   args.insert({DNNL_ARG_WEIGHTS, matmul_b_mem});
   args.insert({DNNL_ARG_DST, gemm_dst_mem});
+  // Set alpha_mem to scale the output
+  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, alpha_mem});
 
   sp.AddPrimitive(matmul_op, args);
 
   if (input_c_exists) {
     auto c_original_md = sp.GetMemory(node.Input(IN_C)).get_desc();
-    auto c_dims = c_original_md.dims();
+    auto c_dims = c_original_md.get_dims();
     if (c_dims.size() != a_dims.size()) {
       while (c_dims.size() < a_dims.size()) {
         c_dims.insert(c_dims.begin(), 1);
@@ -127,14 +133,18 @@ void DnnlGemm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
 
     auto y_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
 
-    auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_add, matmul_pd.dst_desc(), c_md, y_md);
-
     // Scale input C by beta before adding it to the MatMul output.
     dnnl::primitive_attr binary_attr;
     float beta = GetBeta(node);
-    binary_attr.set_scales(DNNL_ARG_SRC_1, 0, {beta});
+    // Set the value to scale DNNL_ARG_SRC_1 with mask 0
+    binary_attr.set_scales_mask(DNNL_ARG_SRC_1, 0);
+    // Create the memory object related to the scale
+    auto beta_mem = dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, eng);
+    // Write the beta value into the memory object
+    sp.WriteToDnnlMemory(beta_mem, {beta});
 
-    auto binary_pd = dnnl::binary::primitive_desc(binary_d, binary_attr, eng);
+    auto binary_pd = dnnl::binary::primitive_desc(eng, dnnl::algorithm::binary_add,
+                                                  matmul_pd.dst_desc(), c_md, y_md, binary_attr);
 
     auto binary_c_mem = sp.GetMemoryAndReshape(node.Input(IN_C), binary_pd.src1_desc(), eng);
 
@@ -142,7 +152,8 @@ void DnnlGemm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
 
     sp.AddPrimitive(binary_op, {{DNNL_ARG_SRC_0, gemm_dst_mem},
                                 {DNNL_ARG_SRC_1, binary_c_mem},
-                                {DNNL_ARG_DST, gemm_dst_mem}});
+                                {DNNL_ARG_DST, gemm_dst_mem},
+                                {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, beta_mem}});
   }
   sp.SetMemory(node.Output(OUT_Y), gemm_dst_mem);
 }
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_layernorm.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_layernorm.cc
index 7d3d26bc972..1e21a955987 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_layernorm.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_layernorm.cc
@@ -94,7 +94,7 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
     src_mem = sp.GetMemoryAndReshape(node.Input(IN_INPUT), src_md, dnnl_engine);
 
     // Make dst desc, must be same as src
-    auto dst_md = dnnl::memory::desc(src_md.dims(), node.Output(OUT_OUTPUT).Type(), dnnl::memory::format_tag::any);
+    auto dst_md = dnnl::memory::desc(src_md.get_dims(), node.Output(OUT_OUTPUT).Type(), dnnl::memory::format_tag::any);
 
     // Add src + skip
     {
@@ -105,8 +105,7 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
       auto skip_mem = sp.GetMemoryAndReshape(node.Input(IN_SKIP), skip_md, dnnl_engine);
 
       // Create and add primitive
-      auto add_skip_d = dnnl::binary::desc(dnnl::algorithm::binary_add, src_md, skip_md, dst_md);
-      auto add_skip_pd = dnnl::binary::primitive_desc(add_skip_d, dnnl_engine);
+      auto add_skip_pd = dnnl::binary::primitive_desc(dnnl_engine, dnnl::algorithm::binary_add, src_md, skip_md, dst_md);
       auto add_skip = dnnl::binary(add_skip_pd);
       std::unordered_map<int, dnnl::memory> add_skip_mem_map({{DNNL_ARG_SRC_0, src_mem}, {DNNL_ARG_SRC_1, skip_mem}, {DNNL_ARG_DST, src_mem}});
       sp.AddPrimitive(add_skip, add_skip_mem_map);
@@ -121,9 +120,9 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
       // Move the bias to GPU if needed
       auto bias_mem = sp.GetMemoryAndReshape(node.Input(IN_SLN_BIAS), bias_md, dnnl_engine);
       // Get bias dims
-      auto bias_dims = bias_md.dims();
+      auto bias_dims = bias_md.get_dims();
       // Get src dims
-      auto src_dims = src_md.dims();
+      auto src_dims = src_md.get_dims();
 
       // To follow the spec means our bias will always have less dimensions that our input
       // so we add the extra dimensions, reshape it and let OneDNN broadcast the value
@@ -133,8 +132,7 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
       bias_md = bias_md.reshape(bias_dims);
 
       // Create and add primitive
-      auto add_bias_d = dnnl::binary::desc(dnnl::algorithm::binary_add, src_md, bias_md, dst_md);
-      auto add_bias_pd = dnnl::binary::primitive_desc(add_bias_d, dnnl_engine);
+      auto add_bias_pd = dnnl::binary::primitive_desc(dnnl_engine, dnnl::algorithm::binary_add, src_md, bias_md, dst_md);
       auto add_bias = dnnl::binary(add_bias_pd);
       std::unordered_map<int, dnnl::memory> add_bias_mem_map({{DNNL_ARG_SRC_0, src_mem}, {DNNL_ARG_SRC_1, bias_mem}, {DNNL_ARG_DST, src_mem}});
       sp.AddPrimitive(add_bias, add_bias_mem_map);
@@ -174,10 +172,8 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   // Get epsilon to avoid zero division
   float epsilon = GetEpsilon(node);
 
-  // Operation desciptor
-  auto lnorm_desc = dnnl::layer_normalization_forward::desc(prop_kind, src_md, epsilon, op_flags);
   // Primitive desciptor
-  auto lnorm_pd = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, dnnl_engine);
+  auto lnorm_pd = dnnl::layer_normalization_forward::primitive_desc(dnnl_engine, prop_kind, src_md, src_md, epsilon, op_flags);
   // Primitive
   auto lnorm_prim = dnnl::layer_normalization_forward(lnorm_pd);
 
@@ -190,8 +186,8 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
     if (node.Input(scale_pos).Type() != dnnl::memory::data_type::f32) {
       // casting to fp32 if input with other data type
       auto gamma_md = gamma_mem.get_desc();
-      auto dims = gamma_md.dims();
-      auto strides = gamma_md.data.format_desc.blocking.strides;
+      auto dims = gamma_md.get_dims();
+      auto strides = gamma_md.get_strides();
       dnnl::memory::dims gamma_strides_vec;
       for (size_t i = 0; i < dims.size(); i++) {
         gamma_strides_vec.push_back(strides[i]);
@@ -210,8 +206,8 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
     if (node.Input(shift_pos).Type() != dnnl::memory::data_type::f32) {
       // casting to fp32 if input with other data type
       auto beta_md = beta_mem.get_desc();
-      auto dims = beta_md.dims();
-      auto strides = beta_md.data.format_desc.blocking.strides;
+      auto dims = beta_md.get_dims();
+      auto strides = beta_md.get_strides();
       dnnl::memory::dims beta_strides_vec;
       for (size_t i = 0; i < dims.size(); i++) {
         beta_strides_vec.push_back(strides[i]);
@@ -249,7 +245,7 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
 
 void DnnlLayerNorm::ValidateDims(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   // Get input and evaluate
-  auto input_dims = sp.GetMemory(node.Input(IN_INPUT)).get_desc().dims();
+  auto input_dims = sp.GetMemory(node.Input(IN_INPUT)).get_desc().get_dims();
   auto input_dims_size = input_dims.size();
 
   // Check the inputs are supported by OneDNN, this is mandatory since sometimes
@@ -269,14 +265,14 @@ void DnnlLayerNorm::ValidateDims(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   }
 
   // Get skip and evaluate
-  auto skip_dims = sp.GetMemory(node.Input(IN_SKIP)).get_desc().dims();
+  auto skip_dims = sp.GetMemory(node.Input(IN_SKIP)).get_desc().get_dims();
   if (input_dims != skip_dims) {
     ORT_THROW("Input and skip dimmentions do not match");
   }
 
   // Check if bias was provided and evaluate
   if (node.Input(IN_SLN_BIAS).Exists()) {
-    auto bias_dims = sp.GetMemory(node.Input(IN_SLN_BIAS)).get_desc().dims();
+    auto bias_dims = sp.GetMemory(node.Input(IN_SLN_BIAS)).get_desc().get_dims();
     if (bias_dims.size() != 1) {
       ORT_THROW("Bias is expected to have 1 dimension, got ", bias_dims.size());
     }
@@ -297,7 +293,7 @@ void DnnlLayerNorm::ValidateDims(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   }
 
   // Get gamma and evaluate
-  auto gamma_dims = sp.GetMemory(node.Input(gamma_pos)).get_desc().dims();
+  auto gamma_dims = sp.GetMemory(node.Input(gamma_pos)).get_desc().get_dims();
   if (gamma_dims.size() != 1) {
     ORT_THROW("Gamma is expected to have 1 dimension, got ", gamma_dims.size());
   }
@@ -307,7 +303,7 @@ void DnnlLayerNorm::ValidateDims(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
 
   // Check if shift was provided and evaluate
   if (node.Input(shift_pos).Exists()) {
-    auto beta_dims = sp.GetMemory(node.Input(shift_pos)).get_desc().dims();
+    auto beta_dims = sp.GetMemory(node.Input(shift_pos)).get_desc().get_dims();
     if (beta_dims.size() != 1) {
       ORT_THROW("Beta is expected to have 1 dimension, got ", beta_dims.size());
     }
@@ -334,7 +330,7 @@ dnnl::memory DnnlLayerNorm::CastAndTransformMemory(DnnlSubgraphPrimitive& sp, dn
 
   // Make a new memory descriptor based on the source descriptor and given destination dataype and strides
   auto src_md = src_mem.get_desc();
-  dnnl::memory::desc dst_md = dnnl::memory::desc(src_md.dims(), dst_datatype, dst_strides);
+  dnnl::memory::desc dst_md = dnnl::memory::desc(src_md.get_dims(), dst_datatype, dst_strides);
   dst_mem = dnnl::memory(dst_md, eng);
 
   // Reorder source memory to destination memory as per the given dataype and strides
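Note (not part of the patch): the stride queries above no longer reach into the C struct (`md.data.format_desc.blocking.strides`); v3.x exposes a typed `get_strides()` accessor. A minimal sketch of the fp32-cast descriptor pattern used in this file:

```cpp
#include "dnnl.hpp"

// Sketch: rebuild a same-shape, same-stride descriptor with a new data type.
dnnl::memory::desc to_f32_desc(const dnnl::memory::desc& md) {
  // v2.x: auto strides = md.data.format_desc.blocking.strides;
  return dnnl::memory::desc(md.get_dims(), dnnl::memory::data_type::f32,
                            md.get_strides());
}
```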
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_lrn.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_lrn.cc
index c44e4772680..16f795127d7 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_lrn.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_lrn.cc
@@ -24,17 +24,35 @@ void DnnlLrn::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
 
   auto lrn_src_mem = sp.GetMemory(node.Input(IN_X));
   auto lrn_src_md = lrn_src_mem.get_desc();
+  // Create a dst_md from src_md
+  auto dst_md = dnnl::memory::desc(lrn_src_md.get_dims(), lrn_src_md.get_data_type(), dnnl::memory::format_tag::any);
 
-  auto lrn_desc = dnnl::lrn_forward::desc(dnnl::prop_kind::forward_scoring, dnnl::algorithm::lrn_across_channels, lrn_src_md, size, alpha, beta, bias);
-  auto lrn_pd = dnnl::lrn_forward::primitive_desc(lrn_desc, dnnl_engine);
+  // Define prop kind according to training status
+  dnnl::prop_kind prop_kind;
+#ifdef ENABLE_TRAINING
+  prop_kind = dnnl::prop_kind::forward_training;
+#else
+  prop_kind = dnnl::prop_kind::forward_inference;
+#endif  // ENABLE_TRAINING
+
+  auto lrn_pd = dnnl::lrn_forward::primitive_desc(dnnl_engine, prop_kind, dnnl::algorithm::lrn_across_channels,
+                                                  lrn_src_md, dst_md, size, alpha, beta, bias);
 
   // If using GPU this will move the memory from the CPU to the GPU.
   lrn_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), lrn_pd.src_desc(), dnnl_engine);
   auto lrn_dst_mem = dnnl::memory(lrn_pd.dst_desc(), dnnl_engine);
 
   auto lrn_op = dnnl::lrn_forward(lrn_pd);
+#ifdef ENABLE_TRAINING
+  auto workspace_mem = dnnl::memory(lrn_pd.workspace_desc(), dnnl_engine);
+
+  sp.AddPrimitive(lrn_op, {{DNNL_ARG_SRC, lrn_src_mem},
+                           {DNNL_ARG_WORKSPACE, workspace_mem},
+                           {DNNL_ARG_DST, lrn_dst_mem}});
+#else
   sp.AddPrimitive(lrn_op, {{DNNL_ARG_SRC, lrn_src_mem},
-                           {DNNL_ARG_DST, lrn_dst_mem}});
+                           {DNNL_ARG_DST, lrn_dst_mem}});
+#endif  // ENABLE_TRAINING
 
   sp.SetMemory(node.Output(OUT_Y), lrn_dst_mem);
 }
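Note (not part of the patch): with `prop_kind::forward_training`, LRN may produce a workspace that the corresponding backward primitive consumes, which is why the training build above allocates `lrn_pd.workspace_desc()` and passes it as `DNNL_ARG_WORKSPACE`. A minimal sketch, with placeholder LRN parameters:

```cpp
#include "dnnl.hpp"

// Sketch: forward-training LRN with a workspace, v3.x signature.
void lrn_training_sketch(const dnnl::engine& eng, dnnl::stream& strm,
                         const dnnl::memory::desc& src_md, dnnl::memory& src) {
  auto dst_md = dnnl::memory::desc(src_md.get_dims(), src_md.get_data_type(),
                                   dnnl::memory::format_tag::any);
  auto pd = dnnl::lrn_forward::primitive_desc(
      eng, dnnl::prop_kind::forward_training, dnnl::algorithm::lrn_across_channels,
      src_md, dst_md, /*local_size=*/5, /*alpha=*/1e-4f, /*beta=*/0.75f, /*k=*/1.f);
  auto dst = dnnl::memory(pd.dst_desc(), eng);
  auto ws = dnnl::memory(pd.workspace_desc(), eng);  // empty desc if unused
  dnnl::lrn_forward(pd).execute(strm, {{DNNL_ARG_SRC, src},
                                       {DNNL_ARG_WORKSPACE, ws},
                                       {DNNL_ARG_DST, dst}});
}
```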
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc
index 49b7094559b..8ac0d37f88a 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc
@@ -61,8 +61,8 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
     alpha = GetAlpha(node);
   }
 
-  auto src_dims = sp.GetMemory(node.Input(IN_A)).get_desc().dims();
-  auto weights_dims = sp.GetMemory(node.Input(IN_B)).get_desc().dims();
+  auto src_dims = sp.GetMemory(node.Input(IN_A)).get_desc().get_dims();
+  auto weights_dims = sp.GetMemory(node.Input(IN_B)).get_desc().get_dims();
 
   // If this is required for transposed inputs, then this will be done later on in the code.
@@ -190,7 +190,7 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
       // Handle Binary post ops including the input memory
       if (binary_ops.count(post_ops[i]) != 0) {
         auto ori_binary_md = sp.GetMemory(node.Input(IN_BINARY_0 + binary_count).Name()).get_desc();
-        auto ori_binary_dims = ori_binary_md.dims();
+        auto ori_binary_dims = ori_binary_md.get_dims();
         auto binary_mem_dims = ori_binary_dims;
         if (ori_binary_dims.size() != output_shape.size()) {
           if (ori_binary_dims.size() > output_shape.size()) {
@@ -225,25 +225,29 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
             post_op_alpha = GetFloatAttr(node, "alpha", /*default_alpha*/ 1.0f);
             break;
           }
+          case dnnl::algorithm::eltwise_soft_relu: {
+            if (post_ops[i] == "Softplus") {
+              post_op_alpha = 1.0f;
+            }
+            break;
+          }
           default:
             post_op_alpha = 0.0;
         }
-        ops.append_eltwise(1.0f, algo, post_op_alpha, 0.0f);
+        ops.append_eltwise(algo, post_op_alpha, 0.0f);
       }
     }
     attr.set_post_ops(ops);
   }
 
   if (is_fusedmatmul) {
-    // Set the scaling of output as a post op in the primitive attribute, taking the value from alpha attribute
-    std::vector<float> alphaScale({alpha});
-    attr.set_output_scales(0, alphaScale);
+    // Set the value to scale DNNL_ARG_SRC with mask 0
+    attr.set_scales_mask(DNNL_ARG_SRC, 0);
  }
 
   auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
 
-  auto matmul_d = dnnl::matmul::desc(src_md, weights_md, dst_md);
-  auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, eng);
+  auto matmul_pd = dnnl::matmul::primitive_desc(eng, src_md, weights_md, dst_md, attr);
 
   dnnl::memory matmul_src_mem, matmul_weights_mem;
   auto matmul_dst_mem = dnnl::memory(matmul_pd.dst_desc(), eng);
@@ -265,6 +269,15 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
                    {DNNL_ARG_WEIGHTS, matmul_weights_mem},
                    {DNNL_ARG_DST, matmul_dst_mem}});
 
+  if (is_fusedmatmul) {
+    // Create the memory object related to the scale
+    auto alpha_mem = dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, eng);
+    // Write the alpha value into the memory object
+    sp.WriteToDnnlMemory(alpha_mem, {alpha});
+    // Set alpha_mem to scale the output
+    mem_map.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, alpha_mem});
+  }
+
   // add to memory map with extra third input if fused with add
   if (has_postop_fusion) {
     // add to memory map for extra binary inputs
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul_integer.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul_integer.cc
index 7c92243f986..ffa146298e2 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul_integer.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul_integer.cc
@@ -38,8 +38,8 @@ void DnnlMatMulInteger::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& nod
     }
   }
 
-  auto src_dims = sp.GetMemory(node.Input(IN_A)).get_desc().dims();
-  auto weights_dims = sp.GetMemory(node.Input(IN_B)).get_desc().dims();
+  auto src_dims = sp.GetMemory(node.Input(IN_A)).get_desc().get_dims();
+  auto weights_dims = sp.GetMemory(node.Input(IN_B)).get_desc().get_dims();
 
   if (src_dims.size() != weights_dims.size()) {
     while (src_dims.size() < weights_dims.size()) {
@@ -70,11 +70,11 @@ void DnnlMatMulInteger::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& nod
   bool has_b_zero_point = node.Input(IN_B_ZERO_POINT).Name() != "";
 
   if (has_a_zero_point) {
-    matmul_attr.set_zero_points(DNNL_ARG_SRC, /* mask */ 0, {DNNL_RUNTIME_S32_VAL});
+    matmul_attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask */ 0);
   }
 
   if (has_b_zero_point) {
-    matmul_attr.set_zero_points(DNNL_ARG_WEIGHTS, /* mask */ 0, {DNNL_RUNTIME_S32_VAL});
+    matmul_attr.set_zero_points_mask(DNNL_ARG_WEIGHTS, /* mask */ 0);
   }
 
   /*
@@ -94,7 +94,7 @@ void DnnlMatMulInteger::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& nod
       // Handle Binary post ops including the input memory
       if (binary_ops.count(post_ops[i]) != 0) {
         auto ori_binary_md = sp.GetMemory(node.Input(IN_BINARY_0 + binary_count).Name()).get_desc();
-        auto ori_binary_dims = ori_binary_md.dims();
+        auto ori_binary_dims = ori_binary_md.get_dims();
         auto binary_mem_dims = ori_binary_dims;
         if (ori_binary_dims.size() != output_shape.size()) {
           if (ori_binary_dims.size() > output_shape.size()) {
@@ -129,17 +129,22 @@ void DnnlMatMulInteger::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& nod
             post_op_alpha = GetFloatAttr(node, "alpha", /*default_alpha*/ 1.0f);
             break;
           }
+          case dnnl::algorithm::eltwise_soft_relu: {
+            if (post_ops[i] == "Softplus") {
+              post_op_alpha = 1.0f;
+            }
+            break;
+          }
           default:
             post_op_alpha = 0.0;
         }
-        ops.append_eltwise(1.0f, algo, post_op_alpha, 0.0f);
+        ops.append_eltwise(algo, post_op_alpha, 0.0f);
       }
     }
     matmul_attr.set_post_ops(ops);
   }
 
-  auto matmul_d = dnnl::matmul::desc(src_md, weights_md, dst_md);
-  auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, matmul_attr, eng);
+  auto matmul_pd = dnnl::matmul::primitive_desc(eng, src_md, weights_md, dst_md, matmul_attr);
 
   auto matmul_src_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng);
   auto matmul_weights_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng);
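Note (not part of the patch): zero points move the same way as scales in v3.x: the v2.x `set_zero_points(arg, mask, {DNNL_RUNTIME_S32_VAL})` becomes `set_zero_points_mask(arg, mask)` on the attr, with the actual value passed at execute time as an s32 memory under `DNNL_ARG_ATTR_ZERO_POINTS | <arg>`. A minimal sketch, memories assumed:

```cpp
#include "dnnl.hpp"

// Sketch: runtime zero point on SRC for an int8 matmul, v3.x style.
void int8_matmul_sketch(const dnnl::engine& eng, dnnl::stream& strm,
                        const dnnl::memory::desc& src_md,
                        const dnnl::memory::desc& wei_md,
                        const dnnl::memory::desc& dst_md,
                        dnnl::memory& src, dnnl::memory& wei,
                        dnnl::memory& src_zp /* s32, one element */) {
  dnnl::primitive_attr attr;
  // v2.x: attr.set_zero_points(DNNL_ARG_SRC, 0, {DNNL_RUNTIME_S32_VAL});
  attr.set_zero_points_mask(DNNL_ARG_SRC, /*mask=*/0);
  auto pd = dnnl::matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr);
  auto dst = dnnl::memory(pd.dst_desc(), eng);
  dnnl::matmul(pd).execute(
      strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei}, {DNNL_ARG_DST, dst},
             {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp}});
}
```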
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc
index 341868a3c70..32b9c64a920 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc
@@ -22,9 +22,9 @@ void DnnlPool::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto pool_src_mem = sp.GetMemory(node.Input(IN_X));
 #endif  // ENABLE_TRAINING
   auto src_md = pool_src_mem.get_desc();
-  auto src_dims = pool_src_mem.get_desc().dims();
+  auto src_dims = pool_src_mem.get_desc().get_dims();
 
- #ifdef ENABLE_TRAINING
+#ifdef ENABLE_TRAINING
   auto prop_kind = dnnl::prop_kind::forward;
 #else
   auto prop_kind = dnnl::prop_kind::forward_inference;
@@ -43,20 +43,16 @@ void DnnlPool::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto strides = GetStrides(node, shape);
 
   auto dst_mem_dims = InferOutputDims(node, src_dims, kernel_shape, strides);
-  dnnl::memory::desc dst_md = dnnl::memory::desc(dst_mem_dims, node.Input(IN_X).Type(), dnnl::memory::format_tag::any);
+  dnnl::memory::desc dst_md = dnnl::memory::desc(dst_mem_dims, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
 
   auto padding = InferPadding(node, src_dims, kernel_shape, strides);
   auto padding_left = GetPaddingLeft(padding);
   auto padding_right = GetPaddingRight(padding);
+  auto dilation = dnnl::memory::dims(kernel_shape.size(), 0);
 
-
-  auto pool_desc = dnnl::pooling_forward::desc(prop_kind, algo,
-                                               src_md, dst_md,
-                                               strides, kernel_shape,
-                                               padding_left, padding_right);
-
-  auto pool_pd = dnnl::pooling_forward::primitive_desc(pool_desc, dnnl_engine);
+  auto pool_pd = dnnl::pooling_forward::primitive_desc(dnnl_engine, prop_kind, algo, src_md, dst_md, strides,
+                                                       kernel_shape, dilation, padding_left, padding_right);
 
 #ifndef ENABLE_TRAINING
   // If using GPU this will move the memory from the CPU to the GPU.
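Note (not part of the patch): v3.x folds the old `pooling_v2` interface into `pooling_forward`, so the primitive descriptor always takes a dilation vector; the forward code above passes all zeros (oneDNN's convention for no dilation). A minimal sketch:

```cpp
#include "dnnl.hpp"

// Sketch: v3.x pooling always carries a dilation argument.
dnnl::pooling_forward::primitive_desc make_maxpool_pd(
    const dnnl::engine& eng, const dnnl::memory::desc& src_md,
    const dnnl::memory::desc& dst_md, const dnnl::memory::dims& strides,
    const dnnl::memory::dims& kernel, const dnnl::memory::dims& pad_l,
    const dnnl::memory::dims& pad_r) {
  auto dilation = dnnl::memory::dims(kernel.size(), 0);  // 0 == no dilation
  return dnnl::pooling_forward::primitive_desc(
      eng, dnnl::prop_kind::forward_inference, dnnl::algorithm::pooling_max,
      src_md, dst_md, strides, kernel, dilation, pad_l, pad_r);
}
```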
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_poolgrad.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_poolgrad.cc
index 946d5a5543f..301de8ee3e1 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_poolgrad.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_poolgrad.cc
@@ -59,7 +59,7 @@ void DnnlPoolGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
 
   auto dy_mem = sp.GetMemory(node.Input(IN_DY));
   auto dy_md = dy_mem.get_desc();
-  auto dy_dims = dy_mem.get_desc().dims();
+  auto dy_dims = dy_mem.get_desc().get_dims();
 
   dnnl::memory indices_mem;
   dnnl::memory::desc indices_md;
@@ -69,7 +69,7 @@ void DnnlPoolGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   if (maxpoolgrad_optype) {
     indices_mem = sp.GetMemory(node.Input(IN_INDICES));
     indices_md = indices_mem.get_desc();
-    indices_dims = indices_mem.get_desc().dims();
+    indices_dims = indices_mem.get_desc().get_dims();
   }
 
   auto dx_dims = node.Output(OUT_DX).Dim();
@@ -92,15 +92,15 @@ void DnnlPoolGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
     }
   }
 
-  dnnl::pooling_forward::desc pool_forward_desc(dnnl::prop_kind::forward, algo,
-                                                fwd_dx_md, dy_md,
-                                                strides, kernel_shape,
-                                                padding_left, padding_right);
-  dnnl::pooling_forward::primitive_desc pool_forward_pd(pool_forward_desc, dnnl_engine);
+  // Dilatation of 1
+  auto dilatation = dnnl::memory::dims(kernel_shape.size(), 1);
 
-  dnnl::pooling_backward::desc pool_backword_desc(algo, dx_md, dy_md,
-                                                  strides, kernel_shape, padding_left, padding_right);
-  dnnl::pooling_backward::primitive_desc pool_backward_pd(pool_backword_desc, dnnl_engine, pool_forward_pd);
+
+  dnnl::pooling_forward::primitive_desc pool_forward_pd(dnnl_engine, dnnl::prop_kind::forward, algo, fwd_dx_md, dy_md,
+                                                        strides, kernel_shape, dilatation, padding_left, padding_right);
+
+  dnnl::pooling_backward::primitive_desc pool_backward_pd(dnnl_engine, algo, dx_md, dy_md, strides, kernel_shape,
+                                                          dilatation, padding_left, padding_right, pool_forward_pd);
 
   dnnl::pooling_backward pool_backward_op(pool_backward_pd);
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pow.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pow.cc
index 470f30e551f..ccc42ef6a77 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pow.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pow.cc
@@ -44,9 +44,11 @@ void DnnlPow::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
       ORT_THROW("Pow exponent data type not supported");
   }
 
+  auto dst_md = dnnl::memory::desc(src_md.get_dims(), src_md.get_data_type(), dnnl::memory::format_tag::any);
+
   // DNNL eltwise_pow is defined as alpha*x^beta. We don't use alpha so it is hard coded to 1.0
-  dnnl::eltwise_forward::desc elementwise_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_pow, src_md, 1.0, beta);
-  dnnl::eltwise_forward::primitive_desc elementwise_pd(elementwise_desc, dnnl_engine);
+  dnnl::eltwise_forward::primitive_desc elementwise_pd(dnnl_engine, dnnl::prop_kind::forward_inference,
+                                                       dnnl::algorithm::eltwise_pow, src_md, dst_md, 1.0, beta);
 
   // If using GPU this will move the memory from the CPU to the GPU.
   elementwise_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), elementwise_pd.src_desc(), dnnl_engine);
+                                  dnnl::memory::format_tag::any);
 }
 dnnl::memory::desc QKV_md;
 {
   //the output of int8 matmul is always 3 dims and consists of Q,K,V values
-  auto QKV_dims = input_md.dims();
-  QKV_dims[2] = weights_md.dims()[2];
+  auto QKV_dims = input_md.get_dims();
+  QKV_dims[2] = weights_md.get_dims()[2];
   //use format any for optimization
   if (isBF16Acc) {
     QKV_md = dnnl::memory::desc(QKV_dims, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::any);
@@ -151,8 +151,7 @@
     }
   }
-  auto matmul_d = dnnl::matmul::desc(input_md, weights_md, QKV_md);
-  auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, matmul_attr, eng);
+  auto matmul_pd = dnnl::matmul::primitive_desc(eng, input_md, weights_md, QKV_md, matmul_attr);
   // (input-input_zero_point)*(weight-weight_zero_point)
   auto matmul_prim = dnnl::matmul(matmul_pd);
@@ -189,19 +188,16 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
   auto total_scale_mem = ComputeTotalScale(sp, node);
   auto bias_md = sp.GetMemory(node.Input(BIAS)).get_desc();
-  bias_md = bias_md.reshape({1, 1, bias_md.dims()[0]});
+  bias_md = bias_md.reshape({1, 1, bias_md.get_dims()[0]});
   auto QKV_desc = QKV_mem.get_desc();
-  //always broadcast from bias to QKV
-  auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_add, QKV_desc, bias_md, QKV_desc);
-
   dnnl::primitive_attr binary_attr;
   //scale source 0, matmul output
   if (total_scale_mem) {
-    binary_attr.set_scales(DNNL_ARG_SRC_0, 0, {DNNL_RUNTIME_F32_VAL});
+    binary_attr.set_scales_mask(DNNL_ARG_SRC_0, 0);
   }
-  auto binary_pd = dnnl::binary::primitive_desc(binary_d, binary_attr, eng);
+  auto binary_pd = dnnl::binary::primitive_desc(eng, dnnl::algorithm::binary_add, QKV_desc, bias_md, QKV_desc, binary_attr);
   auto binary_prim = dnnl::binary(binary_pd);
 
   auto bias_mem = sp.GetMemoryAndReshape(node.Input(BIAS), binary_pd.src1_desc(), eng);
@@ -211,7 +207,7 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
                                        {DNNL_ARG_DST, QKV_mem}});
   if (total_scale_mem) {
-    binary_mem_map[DNNL_ARG_ATTR_INPUT_SCALES | DNNL_ARG_SRC_0] = total_scale_mem;
+    binary_mem_map[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0] = total_scale_mem;
   }
   sp.AddPrimitive(binary_prim, binary_mem_map);
@@ -219,10 +215,10 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
   //parse some dim information for permute and reshape
   //eg, 8,512,2304 = 8,512,(3,12,64)
-  auto batch_size = QKV_mem.get_desc().dims()[0];
-  auto sequence_length = QKV_mem.get_desc().dims()[1];
+  auto batch_size = QKV_mem.get_desc().get_dims()[0];
+  auto sequence_length = QKV_mem.get_desc().get_dims()[1];
   auto num_heads = GetNumHeads(node);
-  auto hidden_size = QKV_mem.get_desc().dims()[2] / 3;
+  auto hidden_size = QKV_mem.get_desc().get_dims()[2] / 3;
   auto head_size = hidden_size / num_heads;
 
   // Slice QKV into submemories
@@ -257,12 +253,16 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
   //need a reorder of data type from s32 to f32 to let mask to have the same data type as QK result
   if (has_mask_index) {
     auto mask_index_mem_desc = sp.GetMemory(node.Input(MASK_INDEX)).get_desc();
-
-    auto linear_d = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_linear, mask_index_mem_desc, 10000.0f, -10000.0f);
-    auto linear_pd = dnnl::eltwise_forward::primitive_desc(linear_d, eng);
+    auto linear_dst_mem = dnnl::memory::desc(mask_index_mem_desc.get_dims(),
+                                             mask_index_mem_desc.get_data_type(),
+                                             dnnl::memory::format_tag::any);
+    auto linear_pd = dnnl::eltwise_forward::primitive_desc(eng, dnnl::prop_kind::forward_inference,
+                                                           dnnl::algorithm::eltwise_linear,
+                                                           mask_index_mem_desc, linear_dst_mem,
+                                                           10000.0f, -10000.0f);
     auto mask_index_ori_mem = sp.GetMemoryAndReshape(node.Input(MASK_INDEX), linear_pd.src_desc(), eng);
-    assert(linear_pd.dst_desc().data_type() == dnnl::memory::data_type::s32);
+    assert(linear_pd.dst_desc().get_data_type() == dnnl::memory::data_type::s32);
     auto mask_index_mem_unbroadcasted_src = dnnl::memory(linear_pd.dst_desc(), eng);
 
     auto linear_prim = dnnl::eltwise_forward(linear_pd);
@@ -272,8 +272,8 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
     dnnl::memory mask_index_mem_unbroadcasted_dst;
     {
       auto mask_index_md_unbroadcasted = mask_index_mem_unbroadcasted_src.get_desc();
-      auto dims = mask_index_md_unbroadcasted.dims();
-      auto strides = mask_index_md_unbroadcasted.data.format_desc.blocking.strides;
+      auto dims = mask_index_md_unbroadcasted.get_dims();
+      auto strides = mask_index_md_unbroadcasted.get_strides();
       dnnl::memory::dims strides_vec;
       for (size_t i = 0; i < dims.size(); i++) {
         strides_vec.push_back(strides[i]);
       }
@@ -288,7 +288,7 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
     //unsqueeze the mem for broadcasting
-    auto mask_index_dims = mask_index_mem_unbroadcasted_dst.get_desc().dims();
+    auto mask_index_dims = mask_index_mem_unbroadcasted_dst.get_desc().get_dims();
     //not symmetric, simply broadcasting
     //eg 8,512 -> 8,1,1,512
     //eg 8,1,1,512 -> 8,12,512,512
@@ -297,9 +297,7 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
     auto mask_index_broadcasted_md = mask_index_mem_unbroadcasted_dst.get_desc().reshape(mask_index_dims);
     //set mask_index_mem
     mask_index_mem = dnnl::memory(mask_index_broadcasted_md, eng, nullptr);
-    dnnl::stream s(eng);
-    mask_index_mem.set_data_handle(mask_index_mem_unbroadcasted_dst.get_data_handle(), s);
-    s.wait();
+    mask_index_mem.set_data_handle(mask_index_mem_unbroadcasted_dst.get_data_handle());
   }
 
   dnnl::memory QK_mem;
@@ -308,8 +306,8 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
   {
     dnnl::primitive_attr QK_attr;
     {
-      auto scales = std::vector<float>({float(1 / std::sqrt(head_size))});
-      QK_attr.set_output_scales(0, scales);
+      // Set the scales mask for SRC; in v3.0 the scale value itself is supplied at execution time
+      QK_attr.set_scales_mask(DNNL_ARG_SRC, 0);
 
       if (mask_index_mem) {
         dnnl::post_ops add_bias;
@@ -326,26 +324,32 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
         QK_md = dnnl::memory::desc({batch_size, num_heads, sequence_length, sequence_length}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);
       }
     }
-    auto QK_d = dnnl::matmul::desc(Q_md, K_md, QK_md);
-    auto QK_pd = dnnl::matmul::primitive_desc(QK_d, QK_attr, eng);
+    auto QK_pd = dnnl::matmul::primitive_desc(eng, Q_md, K_md, QK_md, QK_attr);
     auto QK_prim = dnnl::matmul(QK_pd);
 
+    // Create the memory object related to the scale
+    auto out_scales_mem = dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, eng);
+    // Write the alpha value into the memory object
+    sp.WriteToDnnlMemory(out_scales_mem, std::vector<float>({float(1 / std::sqrt(head_size))}));
+
    QK_mem = dnnl::memory(QK_pd.dst_desc(), eng);
     {
       //QKV_mem is used as both input and weight but since matmul is defined on submemory, computation will be applied to correct submemory
       std::unordered_map<int, dnnl::memory> QK_mem_map({{DNNL_ARG_SRC, Q_mem},
                                                         {DNNL_ARG_WEIGHTS, K_mem},
-                                                        {DNNL_ARG_DST, QK_mem}});
+                                                        {DNNL_ARG_DST, QK_mem},
+                                                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, out_scales_mem}});
       if (mask_index_mem) {
         QK_mem_map[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] = mask_index_mem;
       }
-      sp.AddPrimitive(QK_prim, QK_mem_map);
+      sp.AddPrimitive(QK_prim, QK_mem_map, {DNNL_ARG_DST});
     }
 
     //apply softmax in place to produce attention prob
     {
-      auto softmax_desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, QK_mem.get_desc(), 3);
-      auto softmax_pd = dnnl::softmax_forward::primitive_desc(softmax_desc, eng);
+      auto softmax_pd = dnnl::softmax_forward::primitive_desc(eng, dnnl::prop_kind::forward_inference,
+                                                              dnnl::algorithm::softmax_accurate,
+                                                              QK_mem.get_desc(), QK_mem.get_desc(), 3);
       auto softmax_prim = dnnl::softmax_forward::primitive(softmax_pd);
 
       //QK = softmax(QK) in place
@@ -367,8 +371,7 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
       }
     }
-    auto Prob_V_d = dnnl::matmul::desc(QK_mem.get_desc(), V_md, QAttention_dst_md);
-    auto Prob_V_pd = dnnl::matmul::primitive_desc(Prob_V_d, eng);
+    auto Prob_V_pd = dnnl::matmul::primitive_desc(eng, QK_mem.get_desc(), V_md, QAttention_dst_md);
     auto Prob_V_prim = dnnl::matmul(Prob_V_pd);
 
     QAttention_dst_mem = dnnl::memory(Prob_V_pd.dst_desc(), eng);
@@ -424,7 +427,7 @@ dnnl::memory DnnlQAttention::CopySubMemory(DnnlSubgraphPrimitive& sp, dnnl::memo
   // Make destination memory object from source descriptor given sub memory dimension and offset
   auto src_md = src_mem.get_desc().submemory_desc(sub_mem_dims, sub_mem_offset);
-  dnnl::memory::desc dst_md = dnnl::memory::desc(src_md.dims(), src_md.data_type(), sp.GetDnnlFormat(src_md.dims().size()));
+  dnnl::memory::desc dst_md = dnnl::memory::desc(src_md.get_dims(), src_md.get_data_type(), sp.GetDnnlFormat(src_md.get_dims().size()));
   dst_mem = dnnl::memory(dst_md, eng);
 
   // Copy submemory from source to destination given dimensions and offset
@@ -446,7 +449,7 @@ dnnl::memory DnnlQAttention::CastMemory(DnnlSubgraphPrimitive& sp, dnnl::memory&
   // Make a new memory descriptor based on the source descriptor and given destination datatype
   auto src_md = src_mem.get_desc();
-  dnnl::memory::desc dst_md = dnnl::memory::desc(src_md.dims(), dst_datatype, sp.GetDnnlFormat(src_md.dims().size()));
+  dnnl::memory::desc dst_md = dnnl::memory::desc(src_md.get_dims(), dst_datatype, sp.GetDnnlFormat(src_md.get_dims().size()));
   dst_mem = dnnl::memory(dst_md, eng);
 
   // Reorder source memory to destination memory as per the given datatype
@@ -468,7 +471,7 @@ dnnl::memory DnnlQAttention::CastAndTransformMemory(DnnlSubgraphPrimitive& sp, d
   // Make a new memory descriptor based on the source descriptor and given destination datatype and strides
   auto src_md = src_mem.get_desc();
-  dnnl::memory::desc dst_md = dnnl::memory::desc(src_md.dims(), dst_datatype, dst_strides);
+  dnnl::memory::desc dst_md = dnnl::memory::desc(src_md.get_dims(), dst_datatype, dst_strides);
   dst_mem = dnnl::memory(dst_md, eng);
 
   // Reorder source memory to destination memory as per the given datatype and strides
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h
index 82879047679..d1cea23fca2 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h
@@ -5,7 +5,6 @@
 #include
 #include "dnnl_subgraph.h"
 #include "dnnl_subgraph_primitive.h"
-#include "dnnl_util.h"
 namespace onnxruntime {
 namespace ort_dnnl {
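The QAttention hunks above carry the core of the quantization migration: v2.x set_output_scales/set_zero_points calls with DNNL_RUNTIME_* placeholders become set_scales_mask/set_zero_points_mask on the attr, and the actual values are bound at execution time as one-element memory objects under DNNL_ARG_ATTR_SCALES / DNNL_ARG_ATTR_ZERO_POINTS. A standalone sketch of that pattern (shapes and the scale value are illustrative, and the input memories are left uninitialized for brevity):

#include <dnnl.hpp>
#include <cmath>
#include <vector>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  dnnl::memory::desc a_md({2, 4}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
  dnnl::memory::desc b_md({4, 3}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
  dnnl::memory::desc c_md({2, 3}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);

  // v3.0: declare that SRC has a single common scale (mask 0)...
  dnnl::primitive_attr attr;
  attr.set_scales_mask(DNNL_ARG_SRC, 0);

  auto pd = dnnl::matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
  auto matmul = dnnl::matmul(pd);

  dnnl::memory a_mem(a_md, eng), b_mem(b_md, eng), c_mem(c_md, eng);

  // ...and pass the value itself as a one-element f32 memory at execution time.
  dnnl::memory scale_mem({{1}, dnnl::memory::data_type::f32, {1}}, eng);
  *static_cast<float*>(scale_mem.get_data_handle()) = 1.0f / std::sqrt(64.0f);

  matmul.execute(strm, {{DNNL_ARG_SRC, a_mem},
                        {DNNL_ARG_WEIGHTS, b_mem},
                        {DNNL_ARG_DST, c_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}});
  strm.wait();
  return 0;
}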
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_reduce.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_reduce.cc
index 1b06724e26d..cd8901a5043 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_reduce.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_reduce.cc
@@ -68,7 +68,7 @@ void DnnlReduce::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   } else {
     if (node.Input(IN_AXES).Exists()) {
       auto axes_mem = sp.GetMemory(node.Input(IN_AXES));
-      dnnl::memory::dims axes_dims = axes_mem.get_desc().dims();
+      dnnl::memory::dims axes_dims = axes_mem.get_desc().get_dims();
       int64_t* p_axes_data = (int64_t*)axes_mem.get_data_handle();
       axes = std::vector<int64_t>(p_axes_data, p_axes_data + axes_dims[0]);
     }
@@ -93,7 +93,7 @@ void DnnlReduce::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   //We need to calculate output tensor shape
   //First we initialize it with input shape and then we modify it based on the attribute values
   //This is because the DNNL primitive functionality is determined by the input and output shapes.
-  auto src_dims = src_md.dims();
+  auto src_dims = src_md.get_dims();
   auto ndim = src_dims.size();
 
   // convert negative axis values to the positive axis
@@ -120,13 +120,13 @@ void DnnlReduce::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto dst_shape = TensorShape(src_dims.data(), ndim);
   dnnl::memory::dims dst_dims_mkl(dst_shape.GetDims().begin(), dst_shape.GetDims().end());
-  auto dst_md = dnnl::memory::desc({dst_dims_mkl}, src_md.data_type(), dnnl::memory::format_tag::any);
+  auto dst_md = dnnl::memory::desc({dst_dims_mkl}, src_md.get_data_type(), dnnl::memory::format_tag::any);
 
   // Check to see if the destination shape and source shape are the same.
   bool src_and_dst_dims_equal = true;
-  if (src_md.dims().size() == dst_md.dims().size()) {
-    for (size_t i = 0; i < src_md.dims().size(); ++i) {
-      if (src_md.dims()[i] != dst_md.dims()[i]) {
+  if (src_md.get_dims().size() == dst_md.get_dims().size()) {
+    for (size_t i = 0; i < src_md.get_dims().size(); ++i) {
+      if (src_md.get_dims()[i] != dst_md.get_dims()[i]) {
         src_and_dst_dims_equal = false;
         break;
       }
@@ -164,22 +164,25 @@ void DnnlReduce::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   dnnl::primitive_attr dnnl_primitive_attr;
   if ((reduce_op == ReduceLogSum || reduce_op == ReduceLogSumExp ) && !src_and_dst_dims_equal) {
     dnnl::post_ops eltwise_post_op;
-    eltwise_post_op.append_eltwise(1.0f, dnnl::algorithm::eltwise_log, 1.0f, 1.0f);
+    eltwise_post_op.append_eltwise(dnnl::algorithm::eltwise_log, 1.0f, 1.0f);
     dnnl_primitive_attr.set_post_ops(eltwise_post_op);
   }
 
   if (reduce_op == ReduceLogSumExp) {
     if (!src_and_dst_dims_equal) {
-      auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_exp, src_md);
-      auto elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine);
+      auto elementwise_pd = dnnl::eltwise_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_inference,
+                                                                  dnnl::algorithm::eltwise_exp, src_md,
+                                                                  dnnl::memory::desc(src_md.get_dims(),
+                                                                                     src_md.get_data_type(),
+                                                                                     dnnl::memory::format_tag::any));
       auto elementwise_dst_mem = dnnl::memory(elementwise_pd.dst_desc(), dnnl_engine);
       auto elemenwise_primitive = dnnl::eltwise_forward(elementwise_pd);
       sp.AddPrimitive(elemenwise_primitive, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, elementwise_dst_mem}});
 
-      auto reduce_desc = dnnl::reduction::desc(algo, src_md, dst_md, 0.f, 0.f);
-      auto reduce_pd = dnnl::reduction::primitive_desc(reduce_desc, dnnl_primitive_attr, dnnl_engine);
+      auto reduce_pd = dnnl::reduction::primitive_desc(dnnl_engine, algo, src_md, dst_md, 0.f, 0.f,
+                                                       dnnl_primitive_attr);
 
       reduce_dst_mem = dnnl::memory(reduce_pd.dst_desc(), dnnl_engine);
@@ -190,8 +193,11 @@ void DnnlReduce::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
       reduce_dst_mem = src_mem;
     }
   } else if(reduce_op == ReduceSumSquare) {
-    auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_square, src_md);
-    auto elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine);
+    auto elementwise_pd = dnnl::eltwise_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_inference,
+                                                                dnnl::algorithm::eltwise_square, src_md,
+                                                                dnnl::memory::desc(src_md.get_dims(),
+                                                                                   src_md.get_data_type(),
+                                                                                   dnnl::memory::format_tag::any));
 
     auto elementwise_dst_mem = dnnl::memory(elementwise_pd.dst_desc(), dnnl_engine);
@@ -199,8 +205,7 @@ void DnnlReduce::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
     sp.AddPrimitive(elemenwise_primitive, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, elementwise_dst_mem}});
 
     if (!src_and_dst_dims_equal) {
-      auto reduce_desc = dnnl::reduction::desc(algo, src_md, dst_md, 0.f, 0.f);
-      auto reduce_pd = dnnl::reduction::primitive_desc(reduce_desc, dnnl_engine);
+      auto reduce_pd = dnnl::reduction::primitive_desc(dnnl_engine, algo, src_md, dst_md, 0.f, 0.f);
 
       reduce_dst_mem = dnnl::memory(reduce_pd.dst_desc(), dnnl_engine);
@@ -220,8 +225,8 @@ void DnnlReduce::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
       p_val = 2.0f;
     }
-    auto reduce_desc = dnnl::reduction::desc(algo, src_md, dst_md, p_val, 0.f);
-    auto reduce_pd = dnnl::reduction::primitive_desc(reduce_desc, dnnl_primitive_attr, dnnl_engine);
+    auto reduce_pd = dnnl::reduction::primitive_desc(dnnl_engine, algo, src_md, dst_md, p_val, 0.f,
+                                                     dnnl_primitive_attr);
 
     // If using GPU this will move the memory from the CPU to the GPU.
     reduce_src_mem = sp.GetMemoryAndReshape(node.Input(IN_DATA), reduce_pd.src_desc(), dnnl_engine);
@@ -232,8 +237,11 @@ void DnnlReduce::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
                                    {DNNL_ARG_DST, reduce_dst_mem}});
   } else {
     if (reduce_op == ReduceLogSum) {
-      auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_log, src_md);
-      auto elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine);
+      auto elementwise_pd = dnnl::eltwise_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_inference,
+                                                                  dnnl::algorithm::eltwise_log, src_md,
+                                                                  dnnl::memory::desc(src_md.get_dims(),
+                                                                                     src_md.get_data_type(),
+                                                                                     dnnl::memory::format_tag::any));
 
       reduce_dst_mem = dnnl::memory(elementwise_pd.dst_desc(), dnnl_engine);
@@ -274,7 +282,7 @@ void DnnlReduce::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
       if ((j < axes.size() && axes[j] == static_cast<int64_t>(i) && src_dims[i] == 0) ||
           (axes.size() == 0 && src_dims[i] == 0)) {
         if (!keepdims) {
-          auto dims = src_md.dims();
+          auto dims = src_md.get_dims();
          ORT_ENFORCE(keepdims,
                      "Can't reduce on dim with value of 0 if 'keepdims' is false. "
input_shape:", diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relugrad.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_relugrad.cc index 62da9cb3d89..a542a7d67b8 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relugrad.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_relugrad.cc @@ -18,13 +18,20 @@ void DnnlReluGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { auto relu_bwd_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), src_mem.get_desc(), eng); auto relu_bwd_diff_dst_mem = sp.GetMemoryAndReshape(node.Input(IN_dY), diff_dst_mem.get_desc(), eng); - //create hints on the fly - auto hints_d = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward, dnnl::algorithm::eltwise_relu, relu_bwd_src_mem.get_desc(), 0.0, 0.0); - auto hints_pd = dnnl::eltwise_forward::primitive_desc(hints_d, eng); - - auto relu_bwd_d = dnnl::eltwise_backward::desc(dnnl::algorithm::eltwise_relu, relu_bwd_diff_dst_mem.get_desc(), relu_bwd_src_mem.get_desc(), 0.0, 0.0); + // Generate the dst_md + auto dst_md = dnnl::memory::desc(src_mem.get_desc().get_dims(), + node.Output(OUT_dX).Type(), + dnnl::memory::format_tag::any); - auto relu_bwd_pd = dnnl::eltwise_backward::primitive_desc(relu_bwd_d, eng, hints_pd); + //create hints on the fly + auto hints_pd = dnnl::eltwise_forward::primitive_desc(eng, dnnl::prop_kind::forward, dnnl::algorithm::eltwise_relu, + relu_bwd_src_mem.get_desc(), dst_md, 0.0, 0.0); + + auto relu_bwd_pd = dnnl::eltwise_backward::primitive_desc(eng, dnnl::algorithm::eltwise_relu, + relu_bwd_diff_dst_mem.get_desc(), + relu_bwd_src_mem.get_desc(), + src_mem.get_desc(), + 0.0, 0.0, hints_pd); auto relu_bwd_diff_src_mem = dnnl::memory(relu_bwd_pd.diff_src_desc(), eng); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_reshape.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_reshape.cc index 16090e86cf3..1e4ca5ddd4e 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_reshape.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_reshape.cc @@ -15,10 +15,10 @@ void DnnlReshape::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { // the input shape assumes OrtFormat so we get the memory in OrtFormat. 
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_reshape.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_reshape.cc
index 16090e86cf3..1e4ca5ddd4e 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_reshape.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_reshape.cc
@@ -15,10 +15,10 @@ void DnnlReshape::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   // the input shape assumes OrtFormat so we get the memory in OrtFormat.
   auto data_mem = sp.GetMemoryInOrtFormat(node.Input(IN_DATA), dnnl_engine);
-  dnnl::memory::dims data_dims = data_mem.get_desc().dims();
+  dnnl::memory::dims data_dims = data_mem.get_desc().get_dims();
 
   auto shape_mem = sp.GetMemory(node.Input(IN_SHAPE));
-  dnnl::memory::dims shape_dims = shape_mem.get_desc().dims();
+  dnnl::memory::dims shape_dims = shape_mem.get_desc().get_dims();
   int64_t* shape_data = (int64_t*)shape_mem.get_data_handle();
 
   // Reshape helper will take input data_dims shape and the reshape_shape and replace the -1 and 0s with the calculated
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_softmax.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_softmax.cc
index fbb0754a3fe..c44abd913e8 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_softmax.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_softmax.cc
@@ -23,11 +23,18 @@ void DnnlSoftmax::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto softmax_src_mem = sp.GetMemory(node.Input(IN_X));
   auto softmax_src_md = softmax_src_mem.get_desc();
 
-  if (axis < 0)
-    axis = softmax_src_md.dims().size() + axis;
+  if (axis < 0) {
+    axis = softmax_src_md.get_dims().size() + axis;
+  }
+
+  // Generate the dst_md
+  auto dst_md = dnnl::memory::desc(softmax_src_md.get_dims(),
+                                   node.Output(OUT_Y).Type(),
+                                   dnnl::memory::format_tag::any);
-  auto softmax_desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, softmax_src_md, (int) axis);
-  auto softmax_pd = dnnl::softmax_forward::primitive_desc(softmax_desc, dnnl_engine);
+  auto softmax_pd = dnnl::softmax_forward::primitive_desc(dnnl_engine, dnnl::prop_kind::forward_training,
+                                                          dnnl::algorithm::softmax_accurate, softmax_src_md, dst_md,
+                                                          static_cast<int>(axis));
 
   // If using GPU this will move the memory from the CPU to the GPU.
   softmax_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), softmax_pd.src_desc(), dnnl_engine);
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_softmaxgrad.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_softmaxgrad.cc
index f033b665776..930d7fe843b 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_softmaxgrad.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_softmaxgrad.cc
@@ -18,18 +18,27 @@ void DnnlSoftmaxGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
   auto softmax_bwd_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), src_mem.get_desc(), eng);
   auto softmax_bwd_diff_dst_mem = sp.GetMemoryAndReshape(node.Input(IN_dY), diff_dst_mem.get_desc(), eng);
 
-  auto axis = ReadAxis(node);
+  int axis;
+  {
+    auto axis64 = ReadAxis(node);
+    if (axis64 < 0)
+      axis64 = src_mem.get_desc().get_dims().size() + axis64;
-  if (axis < 0)
-    axis = src_mem.get_desc().dims().size() + axis;
+    axis = static_cast<int>(axis64);
+  }
 
-  //create hints on the fly
-  auto hints_d = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, softmax_bwd_src_mem.get_desc(), (int) axis);
-  auto hints_pd = dnnl::softmax_forward::primitive_desc(hints_d, eng);
+  auto fws_dst_md = dnnl::memory::desc(diff_dst_mem.get_desc().get_dims(),
+                                       diff_dst_mem.get_desc().get_data_type(),
+                                       dnnl::memory::format_tag::any);
-  auto softmax_bwd_d = dnnl::softmax_backward::desc(softmax_bwd_diff_dst_mem.get_desc(), softmax_bwd_src_mem.get_desc(), (int) axis);
+  //create hints on the fly
+  auto hints_pd = dnnl::softmax_forward::primitive_desc(eng, dnnl::prop_kind::forward_training,
+                                                        dnnl::algorithm::softmax_accurate,
+                                                        softmax_bwd_src_mem.get_desc(), fws_dst_md, axis);
-  auto softmax_bwd_pd = dnnl::softmax_backward::primitive_desc(softmax_bwd_d, eng, hints_pd);
+  auto softmax_bwd_pd = dnnl::softmax_backward::primitive_desc(eng, dnnl::algorithm::softmax_accurate,
+                                                               fws_dst_md, softmax_bwd_diff_dst_mem.get_desc(),
+                                                               softmax_bwd_src_mem.get_desc(), axis, hints_pd);
 
   auto softmax_bwd_diff_src_mem = dnnl::memory(softmax_bwd_pd.diff_src_desc(), eng);
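The two softmax files show one more v3.0 wrinkle: the old softmax/logsoftmax primitive split becomes an explicit algorithm argument (softmax_accurate vs softmax_log), and a destination descriptor is required. A standalone sketch with illustrative shapes:

#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);

  dnnl::memory::desc src_md({8, 128}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::ab);
  dnnl::memory::desc dst_md({8, 128}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::any);

  // v3.0: algorithm (softmax_accurate or softmax_log) and dst_md are explicit;
  // the softmax axis remains the last positional argument.
  auto softmax_pd = dnnl::softmax_forward::primitive_desc(
      eng, dnnl::prop_kind::forward_inference, dnnl::algorithm::softmax_accurate,
      src_md, dst_md, /*axis=*/1);

  auto softmax = dnnl::softmax_forward(softmax_pd);
  (void)softmax;
  return 0;
}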
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_squeeze.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_squeeze.cc
index f9c2fe9b6bf..024dbb1f779 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_squeeze.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_squeeze.cc
@@ -15,14 +15,14 @@ void DnnlSqueeze::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   // the input shape assumes OrtFormat so we get the memory in OrtFormat.
   auto data_mem = sp.GetMemoryInOrtFormat(node.Input(IN_DATA), dnnl_engine);
-  dnnl::memory::dims data_dims = data_mem.get_desc().dims();
+  dnnl::memory::dims data_dims = data_mem.get_desc().get_dims();
 
   std::vector<int64_t> axes_data;
   // ONNX Squeeze version 13+ the axes is an input tensor
   // ONNX Squeeze before version 13 axes comes from an Attribute.
   if (node.Input(IN_AXES).Exists()) {
     auto axes_mem = sp.GetMemory(node.Input(IN_AXES));
-    dnnl::memory::dims axes_dims = axes_mem.get_desc().dims();
+    dnnl::memory::dims axes_dims = axes_mem.get_desc().get_dims();
     int64_t* p_axes_data = (int64_t*)axes_mem.get_data_handle();
     axes_data = std::vector<int64_t>(p_axes_data, p_axes_data + axes_dims[0]);
   } else {
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc
index 0854bb29e5b..ce747daf623 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc
@@ -78,8 +78,8 @@ inline bool Contains(const Map& map, const Key& key) {
 #if DNNL_TENSOR_PRINT_MEMORY
 void DnnlSubgraphPrimitive::PrintMemory(const dnnl::memory& mem) {
   auto md = mem.get_desc();
-  auto dt = md.data_type();
-  auto dims = md.dims();
+  auto dt = md.get_data_type();
+  auto dims = md.get_dims();
   if (Product(dims) > DNNL_TENSOR_PRINT_MEMORY_MAX_TENSOR_ELEMENTS) {
     printf("tensor too long ignore printing \n");
     return;
@@ -87,7 +87,7 @@ void DnnlSubgraphPrimitive::PrintMemory(const dnnl::memory& mem) {
   dnnl::memory to_mem;
   if (!IsMemoryInExpectedOrtFormat(md)|| mem.get_engine().get_kind() != dnnl::engine::kind::cpu) {
     printf("\n print memory reorder started \n");
-    dnnl::memory::desc to_md = dnnl::memory::desc(md.dims(), md.data_type(), GetDnnlFormat(md.dims().size()));
+    dnnl::memory::desc to_md = dnnl::memory::desc(md.get_dims(), md.get_data_type(), GetDnnlFormat(md.get_dims().size()));
     to_mem = dnnl::memory(to_md, GetCPUEngine());
     auto stream = dnnl::stream(mem.get_engine());
     dnnl::reorder(mem, to_mem).execute(stream, {{DNNL_ARG_FROM, mem}, {DNNL_ARG_TO, to_mem}});
@@ -411,7 +411,7 @@ void DnnlSubgraphPrimitive::AddOutputs() {
     auto dnnl_tensor_name = tensor->Name();
     auto engine = GetCPUEngine();
     auto output_mem_dnnl = GetMemory(dnnl_tensor_name);
-    auto output_md = dnnl::memory::desc(output_mem_dnnl.get_desc().dims(), dnnl_data_type, GetDnnlFormat(output_mem_dnnl.get_desc().dims().size()));
+    auto output_md = dnnl::memory::desc(output_mem_dnnl.get_desc().get_dims(), dnnl_data_type, GetDnnlFormat(output_mem_dnnl.get_desc().get_dims().size()));
     // if output already in correct memory format, just place it to outputs instead of reorder
     bool copy_output = outputs_are_always_copied_.find(dnnl_tensor_name) != outputs_are_always_copied_.end();
     if (output_mem_dnnl.get_desc() == output_md && output_mem_dnnl.get_engine() == engine && !copy_output) {
@@ -557,9 +557,9 @@ dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(const DnnlTensor& tensor
   auto mem_to = dnnl::memory(mem_desc, eng);
 
   // if it is a reshape, ensure reorder is possible by making the same dims
-  if (mem_from.get_desc().dims() != mem_to.get_desc().dims() || transpose) {
-    auto mem_from_dims = mem_from.get_desc().dims();
-    auto mem_to_dims = mem_to.get_desc().dims();
+  if (mem_from.get_desc().get_dims() != mem_to.get_desc().get_dims() || transpose) {
+    auto mem_from_dims = mem_from.get_desc().get_dims();
+    auto mem_to_dims = mem_to.get_desc().get_dims();
     if (Product(mem_from_dims) != Product(mem_to_dims)) {
       LOGS_DEFAULT(ERROR) << tensor.Name() << ", Dims From: " << mem_from_dims << ", To: " << mem_to_dims;
       throw std::invalid_argument("not a valid reshape, inconsistent dim product");
     }
@@ -571,14 +571,12 @@ dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(const DnnlTensor& tensor
       //TODO: expand to arbitrary permutation or transpose on given 2 dims for higher dimensional tensors
      mem_from_reshape_md = mem_from_reshape_md.permute_axes({1, 0});
     }
-    mem_from_reshape_md = mem_from_reshape_md.reshape(mem_desc.dims());
+    mem_from_reshape_md = mem_from_reshape_md.reshape(mem_desc.get_dims());
     auto mem_from_reshape = dnnl::memory(mem_from_reshape_md, mem_from.get_engine(), nullptr);
     if (is_constant) {  // if constant, do reshape now
       LOGS_DEFAULT(INFO) << "reshaped now";
       //use the stream as a hint to make sure data handle gets set
-      dnnl::stream s{eng};
-      mem_from_reshape.set_data_handle(mem_from.get_data_handle(),s);
-      s.wait();
+      mem_from_reshape.set_data_handle(mem_from.get_data_handle());
     } else {
       AddReshape(mem_from, mem_from_reshape);
     }
@@ -614,7 +612,7 @@ dnnl::memory DnnlSubgraphPrimitive::GetMemoryInOrtFormat(const DnnlTensor& tensor, const dnnl::engine& eng) {
   auto from_mem = GetMemory(tensor);
   auto from_desc = from_mem.get_desc();
-  auto from_dims = from_desc.dims();
+  auto from_dims = from_desc.get_dims();
   if (!IsMemoryInExpectedOrtFormat(from_desc)) {
     dnnl::memory::desc to_md = dnnl::memory::desc(from_dims, tensor.Type(), GetDnnlFormat(from_dims.size()));
     dnnl::memory to_mem = dnnl::memory(to_md, eng);
@@ -628,18 +626,18 @@ dnnl::memory DnnlSubgraphPrimitive::GetMemoryInOrtFormat(const DnnlTenso
 }
 
 bool DnnlSubgraphPrimitive::IsMemoryInExpectedOrtFormat(const dnnl::memory::desc& desc) const {
-  if (desc.data.format_kind != dnnl_blocked) {
+  if (desc.get_format_kind() != dnnl::memory::format_kind::blocked) {
     return false;
   }
-  if (desc.data.format_desc.blocking.inner_nblks != 0) {
+  if (desc.get_inner_nblks() != 0) {
     return false;
   }
-  auto strides = desc.data.format_desc.blocking.strides;
+  auto strides = desc.get_strides();
   // if a data format is dnnl_format::abcd... the stride will go from largest to smallest
   // if for example we have a shape {2,3,4} we expect a stride of {12, 4, 1} if it were
   // of dnnl_format::abc if instead the stride were {12, 1, 4} that would be dnnl_format::acb
   // which does not match what is expected from Onnxruntime.
-  for (size_t i = 1; i < desc.dims().size(); ++i) {
+  for (size_t i = 1; i < desc.get_dims().size(); ++i) {
     if (strides[i - 1] < strides[i]) {
       return false;
     }
@@ -666,23 +664,20 @@ onnxruntime::common::Status DnnlSubgraphPrimitive::Predict(const std::unordered_
   for (auto& input : inputs) {
     if (Contains(inputs_, input.first)) {
-      inputs_.at(input.first).set_data_handle(input.second.buffer, stream);
-      stream.wait();
+      inputs_.at(input.first).set_data_handle(input.second.buffer);
     }
   }
 
   for (auto& output : outputs) {
     if (Contains(outputs_, output.first)) {
-      outputs_.at(output.first).set_data_handle(output.second.buffer, stream);
-      stream.wait();
+      outputs_.at(output.first).set_data_handle(output.second.buffer);
     }
   }
 
   // reshapes (eg, unsqueeze)
   // it is safe to set data handle because all external data handles have been set and onednn managed memory data handles will not change
   for (auto& reshape_pair : reshapes_) {
-    reshape_pair.second.set_data_handle(reshape_pair.first.get_data_handle(),stream);
-    stream.wait();
+    reshape_pair.second.set_data_handle(reshape_pair.first.get_data_handle());
   }
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h
index b8e9079d029..cf9c8514a2f 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h
@@ -76,6 +76,25 @@ class DnnlSubgraphPrimitive {
   dnnl::memory GetMemoryInOrtFormat(const DnnlTensor& tensor, const dnnl::engine& eng);
   bool IsMemoryInExpectedOrtFormat(const dnnl::memory::desc& desc) const;
 
+  template <typename T>
+  void WriteToDnnlMemory(dnnl::memory& mem, std::vector<T> values) {
+    if (mem.get_engine().get_kind() == dnnl::engine::kind::gpu) {
+      // Create a CPU memory
+      auto cpu_memory = dnnl::memory(mem.get_desc(), GetCPUEngine());
+      // Copy data from the vector into the CPU memory data handle
+      std::copy(values.begin(), values.end(), static_cast<T*>(cpu_memory.get_data_handle()));
+      // Use reorder to copy data from CPU to GPU
+      dnnl::stream s{mem.get_engine()};
+      // The reorder fills mem (on the GPU) with the staged CPU values
+      dnnl::reorder(cpu_memory, mem).execute(s, cpu_memory, mem);
+      // wait for reorder to complete
+      s.wait();
+    } else {
+      // Copy data from the vector into the memory data handle
+      std::copy(values.begin(), values.end(), static_cast<T*>(mem.get_data_handle()));
+    }
+  }
+
  private:
  std::string shape_key_;
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_sum.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_sum.cc
index 8832c2ea5b4..2d692f47c18 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_sum.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_sum.cc
@@ -25,10 +25,10 @@ void DnnlSum::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
     scales.push_back(1.0f);
   }
 
-  auto dst_dims = srcs_pd[0].dims();
+  auto dst_dims = srcs_pd[0].get_dims();
   auto dst_md = dnnl::memory::desc({dst_dims}, node.Input(IN_DATA_0).Type(), dnnl::memory::format_tag::any);
 
-  auto sum_pd = dnnl::sum::primitive_desc(dst_md, scales, srcs_pd, dnnl_engine);
+  auto sum_pd = dnnl::sum::primitive_desc(dnnl_engine, dst_md, scales, srcs_pd);
 
   for (size_t i = 0; i < src_mems.size(); ++i) {
     src_mems[i] = sp.GetMemoryAndReshape(node.Input(static_cast<int>(IN_DATA_0 + i)), sum_pd.src_desc(), dnnl_engine);
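The Predict() changes reflect that memory::set_data_handle lost its stream parameter in v3.0, and the new WriteToDnnlMemory helper exists because runtime attr values (such as the QK scale earlier in this patch) now have to be staged into dnnl::memory objects, including on GPU engines where a reorder is needed. A hypothetical call site, assuming a DnnlSubgraphPrimitive sp and an engine eng are in scope (the 0.125f value is illustrative):

// Stage a single runtime scale value for a primitive executed by this subgraph.
auto scale_md = dnnl::memory::desc({1}, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::a);
auto scale_mem = dnnl::memory(scale_md, eng);
sp.WriteToDnnlMemory(scale_mem, std::vector<float>{0.125f});
// scale_mem can now be bound at execution time, e.g. as
// {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}.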
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_transpose.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_transpose.cc
index 2f161e4ebda..a6952ab5fa8 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_transpose.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_transpose.cc
@@ -31,7 +31,7 @@ void DnnlTranspose::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   auto dnnl_engine = sp.GetEngine();
 
   auto data_mem = sp.GetMemory(node.Input(IN_DATA));
-  auto data_dims = data_mem.get_desc().dims();
+  auto data_dims = data_mem.get_desc().get_dims();
   auto ndata_dims = data_dims.size();
 
   auto perm = GetPerm(node);
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_unsqueeze.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_unsqueeze.cc
index 9532686028a..88cd212101d 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_unsqueeze.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_unsqueeze.cc
@@ -22,7 +22,7 @@ void DnnlUnsqueeze::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   // To counter this data_dims is left empty if the input is from a scalar.
   dnnl::memory::dims data_dims;
   if (!data_is_scalar) {
-    data_dims = data_mem.get_desc().dims();
+    data_dims = data_mem.get_desc().get_dims();
   }
 
   std::vector<int64_t> axes_data;
@@ -30,7 +30,7 @@ void DnnlUnsqueeze::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
   // ONNX Unsqueeze before version 13 axes comes from an Attribute.
   if (node.Input(IN_AXES).Exists()) {
     auto axes_mem = sp.GetMemory(node.Input(IN_AXES));
-    dnnl::memory::dims axes_dims = axes_mem.get_desc().dims();
+    dnnl::memory::dims axes_dims = axes_mem.get_desc().get_dims();
     int64_t* p_axes_data = (int64_t*)axes_mem.get_data_handle();
     axes_data = std::vector<int64_t>(p_axes_data, p_axes_data + axes_dims[0]);
   } else {
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_util.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_util.cc
index 0279f4f7430..db9329e8b1f 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_util.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_util.cc
@@ -40,12 +40,14 @@ bool GetGPUInfo(GPUInfo gpu_info) {
       gpuRuntimeFound = true;
       // attempt to make a dnnl::matmul::desc. If we are able to successfully make a bf16 matmul::desc
       // assume the GPU supports all BF16 operations.
+      dnnl::primitive_attr attr;
+      attr.set_scales_mask(DNNL_ARG_SRC, 0);
+      attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask */ 0);
       auto src0_md = dnnl::memory::desc({1,1}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::ab);
       auto src1_md = dnnl::memory::desc({1,1}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::ab);
       auto dst_md = dnnl::memory::desc({1,1}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::ab);
-      auto matmul_d = dnnl::matmul::desc(src0_md, src1_md, dst_md);
       try {
-        auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, gpu_engine);
+        auto matmul_pd = dnnl::matmul::primitive_desc(gpu_engine, src0_md, src1_md, dst_md, attr);
         gpuBF16Supported = true;
       } catch(const dnnl::error& e) {
         if (e.status == dnnl_unimplemented) {
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_util.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_util.h
index 3dbb913f0d5..c8a96597c65 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_util.h
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_util.h
@@ -10,8 +10,11 @@ namespace onnxruntime {
 namespace ort_dnnl {
 namespace dnnl_util {
 bool IsGPURuntimeAvalible();
+
 bool IsBF16Supported();
+
 dnnl::algorithm OrtOperatorToDnnlAlgorithm(std::string op);
+
 }  // namespace dnnl_util
 }  // namespace ort_dnnl
 }  // namespace onnxruntime
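The dnnl_util.cc hunk is the v3.0 form of capability probing: build a throwaway primitive_desc and treat dnnl_unimplemented as "not supported". A standalone sketch of the same pattern, using a CPU engine so it runs anywhere (the diff probes a GPU engine and adds quantization attrs to the probe):

#include <dnnl.hpp>
#include <cstdio>

// Returns true if the engine can construct a bf16 matmul primitive_desc.
// dnnl_unimplemented from the constructor signals the capability is missing.
static bool SupportsBF16Matmul(const dnnl::engine& eng) {
  auto md = dnnl::memory::desc({1, 1}, dnnl::memory::data_type::bf16,
                               dnnl::memory::format_tag::ab);
  try {
    auto pd = dnnl::matmul::primitive_desc(eng, md, md, md);
    (void)pd;
    return true;
  } catch (const dnnl::error& e) {
    if (e.status == dnnl_unimplemented) return false;
    throw;  // anything else is a real error
  }
}

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  std::printf("bf16 matmul supported: %s\n",
              SupportsBF16Matmul(eng) ? "yes" : "no");
  return 0;
}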