diff --git a/src/graph/backend/dnnl/dnnl_op_def.hpp b/src/graph/backend/dnnl/dnnl_op_def.hpp index afdd51a37a9..07df5aa646a 100644 --- a/src/graph/backend/dnnl/dnnl_op_def.hpp +++ b/src/graph/backend/dnnl/dnnl_op_def.hpp @@ -701,6 +701,7 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_binary, 1, .set_num_outputs(2) .set_input(0, "a") .set_input(1, "b") + .set_input(2, "cond") .set_output(0, "output") .set_output(1, "scratchpad") // Attributes inherited from front binary ops (Add, Multiply, @@ -713,6 +714,7 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_binary, 1, {"NXC", "NCX"}) // New added attributes .set_attr(op_attr::is_bias_add, false, attribute_kind::b, false) + .set_attr(op_attr::is_select, false, attribute_kind::b, false) .set_attr(op_attr::fusion_info_key, false, attribute_kind::i, (int64_t)-1) .set_attr(op_attr::alg_kind, true, attribute_kind::i) diff --git a/src/graph/backend/dnnl/dnnl_shape_infer.cpp b/src/graph/backend/dnnl/dnnl_shape_infer.cpp index 781fe979190..9b2c4ea0923 100644 --- a/src/graph/backend/dnnl/dnnl_shape_infer.cpp +++ b/src/graph/backend/dnnl/dnnl_shape_infer.cpp @@ -489,12 +489,15 @@ status_t infer_dnnl_binary_output_shape(op_t *n, std::vector &outputs) { const bool is_bias_add = n->has_attr(op_attr::is_bias_add) && n->get_attr(op_attr::is_bias_add); - - auto ret = is_bias_add - ? infer_bias_add_output_shape(n, inputs, outputs) - : infer_elemwise_arithmetic_output_shape(n, inputs, outputs); - - return ret; + const bool is_select = n->has_attr(op_attr::is_select) + && n->get_attr(op_attr::is_select); + if (is_select) { + return infer_select_output_shape(n, inputs, outputs); + } else if (is_bias_add) { + return infer_bias_add_output_shape(n, inputs, outputs); + } else { + return infer_elemwise_arithmetic_output_shape(n, inputs, outputs); + } } } // namespace dnnl_impl diff --git a/src/graph/backend/dnnl/internal_attrs.hpp b/src/graph/backend/dnnl/internal_attrs.hpp index 93c6f3e4e99..c00f1ffd67b 100644 --- a/src/graph/backend/dnnl/internal_attrs.hpp +++ b/src/graph/backend/dnnl/internal_attrs.hpp @@ -45,6 +45,7 @@ const op_attr_t with_runtime_dst_zps = 0x1000c; const op_attr_t is_bias_add = 0x1000d; const op_attr_t with_sum = 0x1000e; const op_attr_t keep_dst_layout = 0x1000f; +const op_attr_t is_select = 0x10010; // int64_t const op_attr_t alg_kind = 0x10100; diff --git a/src/graph/backend/dnnl/kernels/sdp_decomp.cpp b/src/graph/backend/dnnl/kernels/sdp_decomp.cpp index ffde91622e0..1bc0e69123c 100644 --- a/src/graph/backend/dnnl/kernels/sdp_decomp.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_decomp.cpp @@ -60,6 +60,7 @@ status_t sdp_decomp_kernel_t::compile_impl( pass_pipeline_t pipeline = pass_pipeline_t(vis); pass_pipeline_t select_pipeline = pass_pipeline_t(vis); BACKEND_DNNL_ADD_PASS(pipeline, lower_down); + BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_multiple_binary_ops); BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa); // Fusion and canonicalization passes begin if (quantized) { @@ -391,4 +392,4 @@ template struct sdp_decomp_kernel_t; } // namespace dnnl_impl } // namespace graph } // namespace impl -} // namespace dnnl +} // namespace dnnl \ No newline at end of file diff --git a/src/graph/backend/dnnl/op_executable.cpp b/src/graph/backend/dnnl/op_executable.cpp index 94e9d43b3aa..16d154c5447 100644 --- a/src/graph/backend/dnnl/op_executable.cpp +++ b/src/graph/backend/dnnl/op_executable.cpp @@ -1249,8 +1249,16 @@ binary_executable_t::desc_t binary_executable_t::create_desc( op->get_attr(op_attr::alg_kind)); dnnl::binary::primitive_desc pd; - pd = 
dnnl::binary::primitive_desc( - p_engine, algo, src0, src1, dst, prm_attr); + if (op->has_attr(op_attr::is_select) + && op->get_attr(op_attr::is_select)) { + auto src2 = make_dnnl_memory_desc( + op->get_input_value(2)->get_logical_tensor()); + pd = dnnl::binary::primitive_desc( + p_engine, algo, src0, src1, src2, dst, prm_attr); + } else { + pd = dnnl::binary::primitive_desc( + p_engine, algo, src0, src1, dst, prm_attr); + } pd_cache.insert({op.get(), pd}); @@ -1874,12 +1882,17 @@ arg_indices_t matmul_executable_t::get_arg_indices( arg_indices_t binary_executable_t::get_arg_indices( const op_t *op, fusion_info_mgr_t &mgr) { arg_indices_t arg_indices; + const bool is_select = op->has_attr(op_attr::is_select) + ? op->get_attr(op_attr::is_select) + : false; // add input args size_t index = 0; arg_indices.insert({DNNL_ARG_SRC_0, indices_t {input, index++}}); arg_indices.insert({DNNL_ARG_SRC_1, indices_t {input, index++}}); - + if (is_select) { + arg_indices.insert({DNNL_ARG_SRC_2, indices_t {input, index++}}); + } get_arg_indices_for_post_ops(op, mgr, arg_indices, index); // add output args diff --git a/src/graph/backend/dnnl/passes/lower.cpp b/src/graph/backend/dnnl/passes/lower.cpp index dca5d32909e..52956519c76 100644 --- a/src/graph/backend/dnnl/passes/lower.cpp +++ b/src/graph/backend/dnnl/passes/lower.cpp @@ -651,122 +651,195 @@ static status_t dynamic_dequant_handler( static status_t select_handler( const std::shared_ptr &op, subgraph_rewriter_t &rewriter) { - auto in_vals = op->get_input_values(); - auto out_vals = op->get_output_values(); - assertm(in_vals.size() == 3 && out_vals.size() == 1, - "select should have three inputs and a output"); - auto cond = in_vals[0]; - auto src0 = in_vals[1]; - auto src1 = in_vals[2]; - cond->set_data_type(dnnl::impl::data_type::u8); - - //TODO: This reorder can be removed once eltwise_clip support int8 input - op_ptr type_cast = std::make_shared(op_kind::dnnl_reorder); - type_cast->set_attr(op_attr::change_layout, false); - - op_ptr clip = std::make_shared(op_kind::dnnl_eltwise); - clip->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::eltwise_clip)); - clip->set_attr(op_attr::alpha, 0.f); - clip->set_attr(op_attr::beta, 1.f); - - // After reorder and clip. The cond value is 0 or 1. 
- // Then output = src0.*cond+src1.*(cond*-1 + 1) - op_ptr mul1 = std::make_shared(op_kind::dnnl_binary); - mul1->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::binary_mul)); - mul1->merge_attributes(op->get_attributes()); - - op_ptr mul2 = std::make_shared(op_kind::dnnl_binary); - mul2->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::binary_mul)); - mul2->merge_attributes(op->get_attributes()); - - op_ptr linear = std::make_shared(op_kind::dnnl_eltwise); - linear->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::eltwise_linear)); - const float alpha_value = -1.0f, beta_value = 1.0f; - linear->set_attr(op_attr::alpha, alpha_value); - linear->set_attr(op_attr::beta, beta_value); - - op_ptr add = std::make_shared(op_kind::dnnl_binary); - add->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::binary_add)); - - // reconnect - cond->remove_consumer(*op, 0); - src0->remove_consumer(*op, 1); - src1->remove_consumer(*op, 2); - - // first reorder and clip - cond->add_consumer(*type_cast, 0); - type_cast->add_input(cond); - logical_tensor_t float_cond = empty_logical_tensor_with_default_id(); - auto float_cond_val - = std::make_shared(*type_cast, 0, float_cond, true); - float_cond_val->set_data_type(dnnl::impl::data_type::f32); - type_cast->add_output(float_cond_val); - insert_empty_scratchpad(type_cast); - - float_cond_val->add_consumer(*clip, 0); - clip->add_input(float_cond_val); - logical_tensor_t clip_cond = empty_logical_tensor_with_default_id(); - auto clip_cond_val = std::make_shared(*clip, 0, clip_cond, true); - clip_cond_val->set_data_type( - float_cond_val->get_logical_tensor().data_type); - clip->add_output(clip_cond_val); - insert_empty_scratchpad(clip); - - // first multiply - src0->add_consumer(*mul1, 0); - clip_cond_val->add_consumer(*mul1, 1); - mul1->add_input(src0); - mul1->add_input(clip_cond_val); - - logical_tensor_t src0_cond = empty_logical_tensor_with_default_id(); - auto src0_val = std::make_shared(*mul1, 0, src0_cond, true); - src0_val->set_data_type(src0->get_logical_tensor().data_type); - mul1->add_output(src0_val); - insert_empty_scratchpad(mul1); - - //cond.*{-1} + 1 - clip_cond_val->add_consumer(*linear, 0); - linear->add_input(clip_cond_val); - - logical_tensor_t cond_inv = empty_logical_tensor_with_default_id(); - auto cond_inv_val = std::make_shared(*linear, 0, cond_inv, true); - cond_inv_val->set_data_type(clip_cond_val->get_logical_tensor().data_type); - linear->add_output(cond_inv_val); - insert_empty_scratchpad(linear); - - //src1.*(cond_inv) - - src1->add_consumer(*mul2, 0); - cond_inv_val->add_consumer(*mul2, 1); - mul2->add_input(src1); - mul2->add_input(cond_inv_val); - - logical_tensor_t src1_cond = empty_logical_tensor_with_default_id(); - auto src1_val = std::make_shared(*mul2, 0, src1_cond, true); - src1_val->set_data_type(src1->get_logical_tensor().data_type); - mul2->add_output(src1_val); - insert_empty_scratchpad(mul2); - - src0_val->add_consumer(*add, 0); - src1_val->add_consumer(*add, 1); - add->add_input(src0_val); - add->add_input(src1_val); - add->add_output(out_vals[0]); - insert_empty_scratchpad(add); - - // add new ops and delete select op - rewriter.to_insert(type_cast); - rewriter.to_insert(clip); - rewriter.to_insert(mul1); - rewriter.to_insert(linear); - rewriter.to_insert(mul2); - rewriter.to_insert(add); - rewriter.to_remove(op); + // For now, as primitive doesn't support broadcast for cond input, we use + // binary select primitive for non-broadcast case only. 
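+    // require_broadcast_func below returns true when the cond input (in_vals[0]) +    // and the first data input (in_vals[1]) differ in rank or in any dimension; +    // only in that case do we fall back to the multi-op decomposition below.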
+ auto require_broadcast_func = [](const std::shared_ptr &op) { + auto in_vals = op->get_input_values(); + + const dims input0_dims + = logical_tensor_wrapper_t(in_vals[0]->get_logical_tensor()) + .vdims(); + const dims input1_dims + = logical_tensor_wrapper_t(in_vals[1]->get_logical_tensor()) + .vdims(); + + const size_t input0_ndims + = logical_tensor_wrapper_t(in_vals[0]->get_logical_tensor()) + .ndims(); + const size_t input1_ndims + = logical_tensor_wrapper_t(in_vals[1]->get_logical_tensor()) + .ndims(); + + if (!(input0_ndims == input1_ndims)) { return true; } + const size_t min_ndims = std::min(input0_ndims, input1_ndims); + + // For bianry-select primitive, the select_other_input should have the + // same shape as cond input, this is the requirement of primitive + // creation. + for (size_t i = 0; i < min_ndims; i++) { + if (!(input0_dims[i] == input1_dims[i])) { return true; } + } + return false; + }; + const bool require_broadcast = require_broadcast_func(op); + if (!require_broadcast) { + auto in_vals = op->get_input_values(); + auto out_vals = op->get_output_values(); + assertm(in_vals.size() == 3 && out_vals.size() == 1, + "select should have three inputs and a output"); + auto cond = in_vals[0]; + auto src0 = in_vals[1]; + auto src1 = in_vals[2]; + // For the binary select operation, the conditional input tensor can + // only be of `s8` data type. + cond->set_data_type(dnnl::impl::data_type::s8); + + op_ptr new_op = std::make_shared(op_kind::dnnl_binary); + new_op->set_attr(op_attr::is_select, true); + new_op->set_attr(op_attr::alg_kind, + static_cast(get_binary_alg_map().at(op->get_kind()))); + new_op->merge_attributes(op->get_attributes()); + + // reconnect + cond->remove_consumer(*op, 0); + src0->remove_consumer(*op, 1); + src1->remove_consumer(*op, 2); + + // binary select primitive places the condition input tensor as the + // third input tensor. + src0->add_consumer(*new_op, 0); + src1->add_consumer(*new_op, 1); + cond->add_consumer(*new_op, 2); + + new_op->add_input(src0); + new_op->add_input(src1); + new_op->add_input(cond); + new_op->add_output(out_vals[0]); + + insert_empty_scratchpad(new_op); + rewriter.to_insert(new_op); + rewriter.to_remove(op); + } else { + auto in_vals = op->get_input_values(); + auto out_vals = op->get_output_values(); + assertm(in_vals.size() == 3 && out_vals.size() == 1, + "select should have three inputs and a output"); + auto cond = in_vals[0]; + auto src0 = in_vals[1]; + auto src1 = in_vals[2]; + cond->set_data_type(dnnl::impl::data_type::u8); + + //TODO: This reorder can be removed once eltwise_clip support int8 input + op_ptr type_cast = std::make_shared(op_kind::dnnl_reorder); + type_cast->set_attr(op_attr::change_layout, false); + + op_ptr clip = std::make_shared(op_kind::dnnl_eltwise); + clip->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::eltwise_clip)); + clip->set_attr(op_attr::alpha, 0.f); + clip->set_attr(op_attr::beta, 1.f); + + // After reorder and clip. The cond value is 0 or 1. 
+ // Then output = src0.*cond+src1.*(cond*-1 + 1) + op_ptr mul1 = std::make_shared(op_kind::dnnl_binary); + mul1->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::binary_mul)); + mul1->merge_attributes(op->get_attributes()); + + op_ptr mul2 = std::make_shared(op_kind::dnnl_binary); + mul2->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::binary_mul)); + mul2->merge_attributes(op->get_attributes()); + + op_ptr linear = std::make_shared(op_kind::dnnl_eltwise); + linear->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::eltwise_linear)); + const float alpha_value = -1.0f, beta_value = 1.0f; + linear->set_attr(op_attr::alpha, alpha_value); + linear->set_attr(op_attr::beta, beta_value); + + op_ptr add = std::make_shared(op_kind::dnnl_binary); + add->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::binary_add)); + + // reconnect + cond->remove_consumer(*op, 0); + src0->remove_consumer(*op, 1); + src1->remove_consumer(*op, 2); + + // first reorder and clip + cond->add_consumer(*type_cast, 0); + type_cast->add_input(cond); + logical_tensor_t float_cond = empty_logical_tensor_with_default_id(); + auto float_cond_val + = std::make_shared(*type_cast, 0, float_cond, true); + float_cond_val->set_data_type(dnnl::impl::data_type::f32); + type_cast->add_output(float_cond_val); + insert_empty_scratchpad(type_cast); + + float_cond_val->add_consumer(*clip, 0); + clip->add_input(float_cond_val); + logical_tensor_t clip_cond = empty_logical_tensor_with_default_id(); + auto clip_cond_val + = std::make_shared(*clip, 0, clip_cond, true); + clip_cond_val->set_data_type( + float_cond_val->get_logical_tensor().data_type); + clip->add_output(clip_cond_val); + insert_empty_scratchpad(clip); + + // first multiply + src0->add_consumer(*mul1, 0); + clip_cond_val->add_consumer(*mul1, 1); + mul1->add_input(src0); + mul1->add_input(clip_cond_val); + + logical_tensor_t src0_cond = empty_logical_tensor_with_default_id(); + auto src0_val = std::make_shared(*mul1, 0, src0_cond, true); + src0_val->set_data_type(src0->get_logical_tensor().data_type); + mul1->add_output(src0_val); + insert_empty_scratchpad(mul1); + + //cond.*{-1} + 1 + clip_cond_val->add_consumer(*linear, 0); + linear->add_input(clip_cond_val); + + logical_tensor_t cond_inv = empty_logical_tensor_with_default_id(); + auto cond_inv_val + = std::make_shared(*linear, 0, cond_inv, true); + cond_inv_val->set_data_type( + clip_cond_val->get_logical_tensor().data_type); + linear->add_output(cond_inv_val); + insert_empty_scratchpad(linear); + + //src1.*(cond_inv) + + src1->add_consumer(*mul2, 0); + cond_inv_val->add_consumer(*mul2, 1); + mul2->add_input(src1); + mul2->add_input(cond_inv_val); + + logical_tensor_t src1_cond = empty_logical_tensor_with_default_id(); + auto src1_val = std::make_shared(*mul2, 0, src1_cond, true); + src1_val->set_data_type(src1->get_logical_tensor().data_type); + mul2->add_output(src1_val); + insert_empty_scratchpad(mul2); + + src0_val->add_consumer(*add, 0); + src1_val->add_consumer(*add, 1); + add->add_input(src0_val); + add->add_input(src1_val); + add->add_output(out_vals[0]); + insert_empty_scratchpad(add); + + // add new ops and delete select op + rewriter.to_insert(type_cast); + rewriter.to_insert(clip); + rewriter.to_insert(mul1); + rewriter.to_insert(linear); + rewriter.to_insert(mul2); + rewriter.to_insert(add); + rewriter.to_remove(op); + } return status::success; } @@ -895,7 +968,8 @@ status_t lower_down(std::shared_ptr &sg) { auto kind = cur_op->get_kind(); if (!handler_table.count(kind)) { 
assertm(false, - "All spec ops should be lowered to internal ops, except " + "All spec ops should be lowered to internal ops, " + "except " "for some utility ops like End, Wildcard"); return status::invalid_graph_op; } diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp index a16106babad..04fedecbb5b 100644 --- a/src/graph/backend/dnnl/passes/transform.cpp +++ b/src/graph/backend/dnnl/passes/transform.cpp @@ -171,8 +171,8 @@ status_t replace_quant_data_with_binary_post_op( auto algo = (quant_data_op->get_kind() == op_kind::dnnl_mul_scales) ? dnnl::algorithm::binary_mul : quant_data_op->get_kind() == op_kind::dnnl_add_zps - ? dnnl::algorithm::binary_add - : dnnl::algorithm::binary_sub; + ? dnnl::algorithm::binary_add + : dnnl::algorithm::binary_sub; op_ptr bin_op = std::make_shared(op_kind::dnnl_binary); bin_op->set_attr( op_attr::alg_kind, static_cast(algo)); @@ -2266,7 +2266,8 @@ status_t binary_canonicalization(std::shared_ptr &sg) { int32_t src1_ndims = src1_lt.ndims; int32_t target_ndims = std::max(src0_ndims, src1_ndims); std::vector in_ndims {src0_ndims, src1_ndims}; - for (size_t i = 0; i < cur_op->num_inputs(); ++i) { + std::vector input_indices = {0, 1}; + for (auto i : input_indices) { if (in_ndims[i] == target_ndims) { continue; } std::vector axes(target_ndims - in_ndims[i]); @@ -2297,6 +2298,139 @@ status_t binary_canonicalization(std::shared_ptr &sg) { return infer_shape(sg); } +status_t decompose_select_to_multiple_binary_ops( + std::shared_ptr &sg) { + subgraph_rewriter_t rewriter(sg); + for (auto &op : sg->get_ops()) { + if (op->get_kind() != op_kind::dnnl_binary) continue; + if (!(op->has_attr(op_attr::is_select) + && op->get_attr(op_attr::is_select))) + continue; + op->set_attr(op_attr::is_select, false); + + auto in_vals = op->get_input_values(); + auto out_vals = op->get_output_values(); + + auto src0 = in_vals[0]; + auto src1 = in_vals[1]; + auto cond = in_vals[2]; + cond->set_data_type(dnnl::impl::data_type::u8); + + //TODO: This reorder can be removed once eltwise_clip support int8 input + op_ptr type_cast = std::make_shared(op_kind::dnnl_reorder); + type_cast->set_attr(op_attr::change_layout, false); + + op_ptr clip = std::make_shared(op_kind::dnnl_eltwise); + clip->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::eltwise_clip)); + clip->set_attr(op_attr::alpha, 0.f); + clip->set_attr(op_attr::beta, 1.f); + + // After reorder and clip. The cond value is 0 or 1. 
+ // Then output = src0.*cond+src1.*(cond*-1 + 1) + op_ptr mul1 = std::make_shared(op_kind::dnnl_binary); + mul1->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::binary_mul)); + mul1->merge_attributes(op->get_attributes()); + + op_ptr mul2 = std::make_shared(op_kind::dnnl_binary); + mul2->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::binary_mul)); + mul2->merge_attributes(op->get_attributes()); + + op_ptr linear = std::make_shared(op_kind::dnnl_eltwise); + linear->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::eltwise_linear)); + const float alpha_value = -1.0f, beta_value = 1.0f; + linear->set_attr(op_attr::alpha, alpha_value); + linear->set_attr(op_attr::beta, beta_value); + + op_ptr add = std::make_shared(op_kind::dnnl_binary); + add->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::binary_add)); + + // reconnect + src0->remove_consumer(*op, 0); + src1->remove_consumer(*op, 1); + cond->remove_consumer(*op, 2); + + // first reorder and clip + cond->add_consumer(*type_cast, 0); + type_cast->add_input(cond); + logical_tensor_t float_cond = empty_logical_tensor_with_default_id(); + auto float_cond_val + = std::make_shared(*type_cast, 0, float_cond, true); + float_cond_val->set_data_type(dnnl::impl::data_type::f32); + type_cast->add_output(float_cond_val); + insert_empty_scratchpad(type_cast); + + float_cond_val->add_consumer(*clip, 0); + clip->add_input(float_cond_val); + logical_tensor_t clip_cond = empty_logical_tensor_with_default_id(); + auto clip_cond_val + = std::make_shared(*clip, 0, clip_cond, true); + clip_cond_val->set_data_type( + float_cond_val->get_logical_tensor().data_type); + clip->add_output(clip_cond_val); + insert_empty_scratchpad(clip); + + // first multiply + src0->add_consumer(*mul1, 0); + clip_cond_val->add_consumer(*mul1, 1); + mul1->add_input(src0); + mul1->add_input(clip_cond_val); + + logical_tensor_t src0_cond = empty_logical_tensor_with_default_id(); + auto src0_val = std::make_shared(*mul1, 0, src0_cond, true); + src0_val->set_data_type(src0->get_logical_tensor().data_type); + mul1->add_output(src0_val); + insert_empty_scratchpad(mul1); + + //cond.*{-1} + 1 + clip_cond_val->add_consumer(*linear, 0); + linear->add_input(clip_cond_val); + + logical_tensor_t cond_inv = empty_logical_tensor_with_default_id(); + auto cond_inv_val + = std::make_shared(*linear, 0, cond_inv, true); + cond_inv_val->set_data_type( + clip_cond_val->get_logical_tensor().data_type); + linear->add_output(cond_inv_val); + insert_empty_scratchpad(linear); + + //src1.*(cond_inv) + + src1->add_consumer(*mul2, 0); + cond_inv_val->add_consumer(*mul2, 1); + mul2->add_input(src1); + mul2->add_input(cond_inv_val); + + logical_tensor_t src1_cond = empty_logical_tensor_with_default_id(); + auto src1_val = std::make_shared(*mul2, 0, src1_cond, true); + src1_val->set_data_type(src1->get_logical_tensor().data_type); + mul2->add_output(src1_val); + insert_empty_scratchpad(mul2); + + src0_val->add_consumer(*add, 0); + src1_val->add_consumer(*add, 1); + add->add_input(src0_val); + add->add_input(src1_val); + add->add_output(out_vals[0]); + insert_empty_scratchpad(add); + + // add new ops and delete select op + rewriter.to_insert(type_cast); + rewriter.to_insert(clip); + rewriter.to_insert(mul1); + rewriter.to_insert(linear); + rewriter.to_insert(mul2); + rewriter.to_insert(add); + rewriter.to_remove(op); + } + rewriter.run(); + return infer_shape(sg); +} + status_t binary_broadcast_swap(std::shared_ptr &sg) { subgraph_rewriter_t rewriter(sg); diff --git 
a/src/graph/backend/dnnl/passes/transform.hpp b/src/graph/backend/dnnl/passes/transform.hpp index 4378c527e20..910581bdcc4 100644 --- a/src/graph/backend/dnnl/passes/transform.hpp +++ b/src/graph/backend/dnnl/passes/transform.hpp @@ -119,6 +119,15 @@ status_t fuse_to_dnnl_sum(std::shared_ptr &sg); // make the input shape meet the requirement of oneDNN binary primitive status_t binary_canonicalization(std::shared_ptr &sg); +// For now, we support two impl paths for the select op: one uses the binary +// primitive with the select algorithm, the other decomposes select into +// multiple binary ops (the "legacy impl"). The choice between them is already +// made in the previous lowering pass; however, some kernels (such as the +// decomposed SDPA kernel) always need the legacy impl, so this pass decomposes +// the select binary op back into multiple binary ops. +status_t decompose_select_to_multiple_binary_ops( + std::shared_ptr &sg); + // This pass is used to swap two inputs to broadcast src1 which is optimized in // oneDNN binary primitive. Notice that this should be applied after // binary_canonicalization and infer_shape diff --git a/src/graph/backend/dnnl/passes/utils.cpp b/src/graph/backend/dnnl/passes/utils.cpp index 571f9ea603a..955f851154b 100644 --- a/src/graph/backend/dnnl/passes/utils.cpp +++ b/src/graph/backend/dnnl/passes/utils.cpp @@ -249,7 +249,8 @@ const std::map &get_binary_alg_map() { {graph::op_kind::Minimum, dnnl::algorithm::binary_min}, {graph::op_kind::Maximum, dnnl::algorithm::binary_max}, {graph::op_kind::Subtract, dnnl::algorithm::binary_sub}, - {graph::op_kind::BiasAdd, dnnl::algorithm::binary_add}}; + {graph::op_kind::BiasAdd, dnnl::algorithm::binary_add}, + {graph::op_kind::Select, dnnl::algorithm::binary_select}}; return binary_alg_map; } diff --git a/tests/gtests/graph/unit/backend/dnnl/test_large_partition.cpp b/tests/gtests/graph/unit/backend/dnnl/test_large_partition.cpp index 3348d655b5f..e1a7f4403a3 100644 --- a/tests/gtests/graph/unit/backend/dnnl/test_large_partition.cpp +++ b/tests/gtests/graph/unit/backend/dnnl/test_large_partition.cpp @@ -593,6 +593,63 @@ TEST(test_large_partition_execute, Int8DistilBertMha) { strm->wait(); } +TEST(test_large_partition_execute, Int8DistilBertMhaWithoutBroadcast) { + graph::engine_t *eng = get_engine(); + graph::stream_t *strm = get_stream(); + + graph::graph_t g(eng->kind()); + utils::construct_select_int8_MHA_without_broadcast(&g); + g.finalize(); + + ASSERT_EQ(g.get_ops().size(), 13U); + + graph::pass::pass_base_ptr apass = get_pass("int8_sdp_fusion"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + graph::partition_t p; + p.init(part); + + auto partition_inputs = p.get_inputs(); + auto partition_outputs = p.get_outputs(); + ASSERT_EQ(partition_inputs.size(), 6U); + ASSERT_EQ(partition_outputs.size(), 1U); + + std::vector inputs, outputs; + for (auto < : partition_inputs) { + inputs.emplace_back(<); + } + for (auto < : partition_outputs) { + // set output to be strided + lt = utils::logical_tensor_init( + lt.id, lt.data_type, graph::layout_type::strided); + outputs.emplace_back(<); + } + + graph::compiled_partition_t cp(p); + ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), graph::status::success); + + std::vector inputs_ts, outputs_ts; + + for (auto < : inputs) { + inputs_ts.emplace_back(*lt, eng); + inputs_ts.back().fill(); + } + + for (auto < : outputs) { + graph::logical_tensor_t compiled_output; + cp.query_logical_tensor(lt->id,
&compiled_output); + outputs_ts.emplace_back(compiled_output, eng); + } + + ASSERT_EQ(cp.execute(strm, test_tensor::to_graph_tensor(inputs_ts), + test_tensor::to_graph_tensor(outputs_ts)), + graph::status::success); + strm->wait(); +} + TEST(test_large_partition_execute, Int8GptMha) { graph::engine_t *eng = get_engine(); graph::stream_t *strm = get_stream(); @@ -650,6 +707,64 @@ TEST(test_large_partition_execute, Int8GptMha) { strm->wait(); } +TEST(test_large_partition_execute, Int8GptMhaWithoutBroadcast) { + graph::engine_t *eng = get_engine(); + graph::stream_t *strm = get_stream(); + + graph::graph_t g(eng->kind()); + utils::construct_select_int8_MHA_without_broadcast( + &g, 1, 32, 16, 4096, false, false, true); + g.finalize(); + + ASSERT_EQ(g.get_ops().size(), 14U); + + graph::pass::pass_base_ptr apass = get_pass("int8_fp32_gpt_sdp"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + graph::partition_t p; + p.init(part); + + auto partition_inputs = p.get_inputs(); + auto partition_outputs = p.get_outputs(); + ASSERT_EQ(partition_inputs.size(), 7U); + ASSERT_EQ(partition_outputs.size(), 1U); + + std::vector inputs, outputs; + for (auto < : partition_inputs) { + inputs.emplace_back(<); + } + for (auto < : partition_outputs) { + // set output to be strided + lt = utils::logical_tensor_init( + lt.id, lt.data_type, graph::layout_type::strided); + outputs.emplace_back(<); + } + + graph::compiled_partition_t cp(p); + ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), graph::status::success); + + std::vector inputs_ts, outputs_ts; + + for (auto < : inputs) { + inputs_ts.emplace_back(*lt, eng); + inputs_ts.back().fill(); + } + + for (auto < : outputs) { + graph::logical_tensor_t compiled_output; + cp.query_logical_tensor(lt->id, &compiled_output); + outputs_ts.emplace_back(compiled_output, eng); + } + + ASSERT_EQ(cp.execute(strm, test_tensor::to_graph_tensor(inputs_ts), + test_tensor::to_graph_tensor(outputs_ts)), + graph::status::success); + strm->wait(); +} + TEST(test_large_partition_execute, F32Mha) { graph::engine_t *eng = get_engine(); graph::stream_t *strm = get_stream(); @@ -770,6 +885,69 @@ TEST(test_large_partition_execute, F32DistilBertMha) { strm->wait(); } +TEST(test_large_partition_execute, F32DistilBertMhaWithoutBroadcast) { + graph::engine_t *eng = get_engine(); + graph::stream_t *strm = get_stream(); + + graph::graph_t g(eng->kind()); + utils::construct_select_float_MHA_without_broadcast(&g); + g.finalize(); + + ASSERT_EQ(g.get_ops().size(), 7U); + + graph::pass::pass_base_ptr apass = get_pass("float_sdp_fusion"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + graph::partition_t p; + p.init(part); + + auto partition_inputs = p.get_inputs(); + auto partition_outputs = p.get_outputs(); + ASSERT_EQ(partition_inputs.size(), 6U); + ASSERT_EQ(partition_outputs.size(), 1U); + + std::vector inputs, outputs; + for (auto < : partition_inputs) { + inputs.emplace_back(<); + } + for (auto < : partition_outputs) { + // set output to be strided + lt = utils::logical_tensor_init( + lt.id, lt.data_type, graph::layout_type::strided); + outputs.emplace_back(<); + } + + graph::compiled_partition_t cp(p); + ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), graph::status::success); + + using ltw = graph::logical_tensor_wrapper_t; + + std::vector inputs_ts, outputs_ts; + + for (auto < : inputs) { + inputs_ts.emplace_back(*lt, eng); + // For select op's bool input + if 
(ltw(lt).data_type() == graph::data_type::boolean) + inputs_ts.back().fill(); + else + inputs_ts.back().fill(); + } + + for (auto < : outputs) { + graph::logical_tensor_t compiled_output; + cp.query_logical_tensor(lt->id, &compiled_output); + outputs_ts.emplace_back(compiled_output, eng); + } + + ASSERT_EQ(cp.execute(strm, test_tensor::to_graph_tensor(inputs_ts), + test_tensor::to_graph_tensor(outputs_ts)), + graph::status::success); + strm->wait(); +} + TEST(test_large_partition_execute, F32GptMha) { graph::engine_t *eng = get_engine(); graph::stream_t *strm = get_stream(); @@ -834,6 +1012,70 @@ TEST(test_large_partition_execute, F32GptMha) { strm->wait(); } +TEST(test_large_partition_execute, F32GptMhaWithoutBroadcast) { + graph::engine_t *eng = get_engine(); + graph::stream_t *strm = get_stream(); + + graph::graph_t g(eng->kind()); + utils::construct_select_float_MHA_without_broadcast( + &g, graph::data_type::f32, 1, 32, 16, 4096, false, false, true); + g.finalize(); + + ASSERT_EQ(g.get_ops().size(), 8U); + + graph::pass::pass_base_ptr apass = get_pass("float_gpt_sdp"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + graph::partition_t p; + p.init(part); + + auto partition_inputs = p.get_inputs(); + auto partition_outputs = p.get_outputs(); + ASSERT_EQ(partition_inputs.size(), 7U); + ASSERT_EQ(partition_outputs.size(), 1U); + + std::vector inputs, outputs; + for (auto < : partition_inputs) { + inputs.emplace_back(<); + } + for (auto < : partition_outputs) { + // set output to be strided + lt = utils::logical_tensor_init( + lt.id, lt.data_type, graph::layout_type::strided); + outputs.emplace_back(<); + } + + graph::compiled_partition_t cp(p); + ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), graph::status::success); + + using ltw = graph::logical_tensor_wrapper_t; + + std::vector inputs_ts, outputs_ts; + + for (auto < : inputs) { + inputs_ts.emplace_back(*lt, eng); + // For select op's bool input + if (ltw(lt).data_type() == graph::data_type::boolean) + inputs_ts.back().fill(); + else + inputs_ts.back().fill(); + } + + for (auto < : outputs) { + graph::logical_tensor_t compiled_output; + cp.query_logical_tensor(lt->id, &compiled_output); + outputs_ts.emplace_back(compiled_output, eng); + } + + ASSERT_EQ(cp.execute(strm, test_tensor::to_graph_tensor(inputs_ts), + test_tensor::to_graph_tensor(outputs_ts)), + graph::status::success); + strm->wait(); +} + TEST(test_large_partition_execute, F32JaxMha) { graph::engine_t *eng = get_engine(); graph::stream_t *strm = get_stream(); @@ -1118,7 +1360,7 @@ TEST(test_large_partition_execute, Bf16GptMha) { strm->wait(); } -TEST(test_large_partition_execute, Bf16DistilBertMha) { +TEST(test_large_partition_execute, Bf16GptMhaWithoutBroadcast) { graph::engine_t *eng = get_engine(); graph::stream_t *strm = get_stream(); SKIP_IF(eng->kind() == graph::engine_kind::gpu, "skip on gpu"); @@ -1129,12 +1371,13 @@ TEST(test_large_partition_execute, Bf16DistilBertMha) { "Skip bf16 tests for systems that do not support avx512_core."); graph::graph_t g(eng->kind()); - utils::construct_select_float_MHA(&g, dnnl::impl::data_type::bf16); + utils::construct_select_float_MHA_without_broadcast( + &g, graph::data_type::bf16, 1, 32, 16, 4096, false, false, true); g.finalize(); - ASSERT_EQ(g.get_ops().size(), 7U); + ASSERT_EQ(g.get_ops().size(), 10U); - graph::pass::pass_base_ptr apass = get_pass("float_sdp_fusion"); + graph::pass::pass_base_ptr apass = get_pass("bfloat16_gpt_sdp"); apass->run(g); 
ASSERT_EQ(g.get_num_partitions(), 1U); auto part = g.get_partitions()[0]; @@ -1145,7 +1388,7 @@ TEST(test_large_partition_execute, Bf16DistilBertMha) { auto partition_inputs = p.get_inputs(); auto partition_outputs = p.get_outputs(); - ASSERT_EQ(partition_inputs.size(), 6U); + ASSERT_EQ(partition_inputs.size(), 7U); ASSERT_EQ(partition_outputs.size(), 1U); std::vector inputs, outputs; @@ -1187,10 +1430,9 @@ TEST(test_large_partition_execute, Bf16DistilBertMha) { strm->wait(); } -TEST(test_large_partition_execute, Int8Bf16Mha_CPU) { +TEST(test_large_partition_execute, Bf16DistilBertMha) { graph::engine_t *eng = get_engine(); graph::stream_t *strm = get_stream(); - SKIP_IF(eng->kind() == graph::engine_kind::gpu, "skip on gpu"); static auto isa = dnnl_get_effective_cpu_isa(); @@ -1199,12 +1441,12 @@ TEST(test_large_partition_execute, Int8Bf16Mha_CPU) { "Skip bf16 tests for systems that do not support avx512_core."); graph::graph_t g(eng->kind()); - utils::construct_int8_bf16_MHA(&g); + utils::construct_select_float_MHA(&g, dnnl::impl::data_type::bf16); g.finalize(); - ASSERT_EQ(g.get_ops().size(), 19U); + ASSERT_EQ(g.get_ops().size(), 7U); - graph::pass::pass_base_ptr apass = get_pass("int8_bf16_sdp_fusion"); + graph::pass::pass_base_ptr apass = get_pass("float_sdp_fusion"); apass->run(g); ASSERT_EQ(g.get_num_partitions(), 1U); auto part = g.get_partitions()[0]; @@ -1215,7 +1457,205 @@ TEST(test_large_partition_execute, Int8Bf16Mha_CPU) { auto partition_inputs = p.get_inputs(); auto partition_outputs = p.get_outputs(); - ASSERT_EQ(partition_inputs.size(), 5U); + ASSERT_EQ(partition_inputs.size(), 6U); + ASSERT_EQ(partition_outputs.size(), 1U); + + std::vector inputs, outputs; + for (auto < : partition_inputs) { + inputs.emplace_back(<); + } + for (auto < : partition_outputs) { + // set output to be strided + lt = utils::logical_tensor_init( + lt.id, lt.data_type, graph::layout_type::strided); + outputs.emplace_back(<); + } + + graph::compiled_partition_t cp(p); + ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), graph::status::success); + + using ltw = graph::logical_tensor_wrapper_t; + + std::vector inputs_ts, outputs_ts; + + for (auto < : inputs) { + inputs_ts.emplace_back(*lt, eng); + // For select op's bool input + if (ltw(lt).data_type() == graph::data_type::boolean) + inputs_ts.back().fill(); + else + inputs_ts.back().fill(); + } + + for (auto < : outputs) { + graph::logical_tensor_t compiled_output; + cp.query_logical_tensor(lt->id, &compiled_output); + outputs_ts.emplace_back(compiled_output, eng); + } + + ASSERT_EQ(cp.execute(strm, test_tensor::to_graph_tensor(inputs_ts), + test_tensor::to_graph_tensor(outputs_ts)), + graph::status::success); + strm->wait(); +} + +TEST(test_large_partition_execute, Bf16DistilBertMhaWithoutBroadcast) { + graph::engine_t *eng = get_engine(); + graph::stream_t *strm = get_stream(); + SKIP_IF(eng->kind() == graph::engine_kind::gpu, "skip on gpu"); + + static auto isa = dnnl_get_effective_cpu_isa(); + SKIP_IF((isa < dnnl_cpu_isa_avx512_core) + && eng->kind() == graph::engine_kind::cpu, + "Skip bf16 tests for systems that do not support avx512_core."); + + graph::graph_t g(eng->kind()); + utils::construct_select_float_MHA_without_broadcast( + &g, dnnl::impl::data_type::bf16); + g.finalize(); + + ASSERT_EQ(g.get_ops().size(), 7U); + + graph::pass::pass_base_ptr apass = get_pass("float_sdp_fusion"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + graph::partition_t p; + p.init(part); + + auto 
partition_inputs = p.get_inputs(); + auto partition_outputs = p.get_outputs(); + ASSERT_EQ(partition_inputs.size(), 6U); + ASSERT_EQ(partition_outputs.size(), 1U); + + std::vector inputs, outputs; + for (auto < : partition_inputs) { + inputs.emplace_back(<); + } + for (auto < : partition_outputs) { + // set output to be strided + lt = utils::logical_tensor_init( + lt.id, lt.data_type, graph::layout_type::strided); + outputs.emplace_back(<); + } + + graph::compiled_partition_t cp(p); + ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), graph::status::success); + + using ltw = graph::logical_tensor_wrapper_t; + + std::vector inputs_ts, outputs_ts; + + for (auto < : inputs) { + inputs_ts.emplace_back(*lt, eng); + // For select op's bool input + if (ltw(lt).data_type() == graph::data_type::boolean) + inputs_ts.back().fill(); + else + inputs_ts.back().fill(); + } + + for (auto < : outputs) { + graph::logical_tensor_t compiled_output; + cp.query_logical_tensor(lt->id, &compiled_output); + outputs_ts.emplace_back(compiled_output, eng); + } + + ASSERT_EQ(cp.execute(strm, test_tensor::to_graph_tensor(inputs_ts), + test_tensor::to_graph_tensor(outputs_ts)), + graph::status::success); + strm->wait(); +} + +TEST(test_large_partition_execute, Int8Bf16Mha_CPU) { + graph::engine_t *eng = get_engine(); + graph::stream_t *strm = get_stream(); + + SKIP_IF(eng->kind() == graph::engine_kind::gpu, "skip on gpu"); + + static auto isa = dnnl_get_effective_cpu_isa(); + SKIP_IF((isa < dnnl_cpu_isa_avx512_core) + && eng->kind() == graph::engine_kind::cpu, + "Skip bf16 tests for systems that do not support avx512_core."); + + graph::graph_t g(eng->kind()); + utils::construct_int8_bf16_MHA(&g); + g.finalize(); + + ASSERT_EQ(g.get_ops().size(), 19U); + + graph::pass::pass_base_ptr apass = get_pass("int8_bf16_sdp_fusion"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + graph::partition_t p; + p.init(part); + + auto partition_inputs = p.get_inputs(); + auto partition_outputs = p.get_outputs(); + ASSERT_EQ(partition_inputs.size(), 5U); + ASSERT_EQ(partition_outputs.size(), 1U); + + std::vector inputs, outputs; + for (auto < : partition_inputs) { + inputs.emplace_back(<); + } + for (auto < : partition_outputs) { + outputs.emplace_back(<); + } + + graph::compiled_partition_t cp(p); + ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), graph::status::success); + + std::vector inputs_ts, outputs_ts; + + for (auto < : partition_inputs) { + inputs_ts.emplace_back(lt, eng); + } + + for (auto < : partition_outputs) { + outputs_ts.emplace_back(lt, eng); + } + + ASSERT_EQ(cp.execute(strm, test_tensor::to_graph_tensor(inputs_ts), + test_tensor::to_graph_tensor(outputs_ts)), + graph::status::success); + strm->wait(); +} + +TEST(test_large_partition_execute, Int8Bf16MhaWithoutBroadcast_CPU) { + graph::engine_t *eng = get_engine(); + graph::stream_t *strm = get_stream(); + + SKIP_IF(eng->kind() == graph::engine_kind::gpu, "skip on gpu"); + + static auto isa = dnnl_get_effective_cpu_isa(); + SKIP_IF((isa < dnnl_cpu_isa_avx512_core) + && eng->kind() == graph::engine_kind::cpu, + "Skip bf16 tests for systems that do not support avx512_core."); + + graph::graph_t g(eng->kind()); + utils::construct_int8_bf16_MHA_without_broadcast(&g); + g.finalize(); + + ASSERT_EQ(g.get_ops().size(), 19U); + + graph::pass::pass_base_ptr apass = get_pass("int8_bf16_sdp_fusion"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + 
graph::partition_t p; + p.init(part); + + auto partition_inputs = p.get_inputs(); + auto partition_outputs = p.get_outputs(); + ASSERT_EQ(partition_inputs.size(), 5U); ASSERT_EQ(partition_outputs.size(), 1U); std::vector inputs, outputs; @@ -1304,6 +1744,65 @@ TEST(test_large_partition_execute, Int8Bf16DistilBertMha_CPU) { strm->wait(); } +TEST(test_large_partition_execute, Int8Bf16DistilBertMhaWithoutBroadcast_CPU) { + graph::engine_t *eng = get_engine(); + graph::stream_t *strm = get_stream(); + + SKIP_IF(eng->kind() == graph::engine_kind::gpu, "skip on gpu"); + + static auto isa = dnnl_get_effective_cpu_isa(); + SKIP_IF((isa < dnnl_cpu_isa_avx512_core) + && eng->kind() == graph::engine_kind::cpu, + "Skip bf16 tests for systems that do not support avx512_core."); + + graph::graph_t g(eng->kind()); + utils::construct_int8_bf16_MHA_without_broadcast( + &g, 1, 128, 12, 768, false, true, true, false); + g.finalize(); + + ASSERT_EQ(g.get_ops().size(), 19U); + + graph::pass::pass_base_ptr apass = get_pass("int8_bf16_sdp_fusion"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + graph::partition_t p; + p.init(part); + + auto partition_inputs = p.get_inputs(); + auto partition_outputs = p.get_outputs(); + ASSERT_EQ(partition_inputs.size(), 6U); + ASSERT_EQ(partition_outputs.size(), 1U); + + std::vector inputs, outputs; + for (auto < : partition_inputs) { + inputs.emplace_back(<); + } + for (auto < : partition_outputs) { + outputs.emplace_back(<); + } + + graph::compiled_partition_t cp(p); + ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), graph::status::success); + + std::vector inputs_ts, outputs_ts; + + for (auto < : partition_inputs) { + inputs_ts.emplace_back(lt, eng); + } + + for (auto < : partition_outputs) { + outputs_ts.emplace_back(lt, eng); + } + + ASSERT_EQ(cp.execute(strm, test_tensor::to_graph_tensor(inputs_ts), + test_tensor::to_graph_tensor(outputs_ts)), + graph::status::success); + strm->wait(); +} + TEST(test_large_partition_execute, Int8Bf16GptMha_CPU) { graph::engine_t *eng = get_engine(); graph::stream_t *strm = get_stream(); @@ -1362,3 +1861,62 @@ TEST(test_large_partition_execute, Int8Bf16GptMha_CPU) { graph::status::success); strm->wait(); } + +TEST(test_large_partition_execute, Int8Bf16GptMhaWithoutBroadcast_CPU) { + graph::engine_t *eng = get_engine(); + graph::stream_t *strm = get_stream(); + + SKIP_IF(eng->kind() == graph::engine_kind::gpu, "skip on gpu"); + + static auto isa = dnnl_get_effective_cpu_isa(); + SKIP_IF((isa < dnnl_cpu_isa_avx512_core) + && eng->kind() == graph::engine_kind::cpu, + "Skip bf16 tests for systems that do not support avx512_core."); + + graph::graph_t g(eng->kind()); + utils::construct_int8_bf16_MHA_without_broadcast( + &g, 1, 32, 16, 4096, false, true, false, true); + g.finalize(); + + ASSERT_EQ(g.get_ops().size(), 20U); + + graph::pass::pass_base_ptr apass = get_pass("int8_bf16_gpt_sdp"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + graph::partition_t p; + p.init(part); + + auto partition_inputs = p.get_inputs(); + auto partition_outputs = p.get_outputs(); + ASSERT_EQ(partition_inputs.size(), 7U); + ASSERT_EQ(partition_outputs.size(), 1U); + + std::vector inputs, outputs; + for (auto < : partition_inputs) { + inputs.emplace_back(<); + } + for (auto < : partition_outputs) { + outputs.emplace_back(<); + } + + graph::compiled_partition_t cp(p); + ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), 
graph::status::success); + + std::vector inputs_ts, outputs_ts; + + for (auto < : partition_inputs) { + inputs_ts.emplace_back(lt, eng); + } + + for (auto < : partition_outputs) { + outputs_ts.emplace_back(lt, eng); + } + + ASSERT_EQ(cp.execute(strm, test_tensor::to_graph_tensor(inputs_ts), + test_tensor::to_graph_tensor(outputs_ts)), + graph::status::success); + strm->wait(); +} diff --git a/tests/gtests/graph/unit/backend/dnnl/test_op_schema_cpu.cpp b/tests/gtests/graph/unit/backend/dnnl/test_op_schema_cpu.cpp index 812e02cf464..ea605ae6806 100644 --- a/tests/gtests/graph/unit/backend/dnnl/test_op_schema_cpu.cpp +++ b/tests/gtests/graph/unit/backend/dnnl/test_op_schema_cpu.cpp @@ -134,10 +134,11 @@ TEST(test_op_schema, DnnlBinary) { op_kind_t op_kind = dnnl_impl::op_kind::dnnl_binary; const size_t expected_in_size_lower = 2; const size_t expected_out_size = 2; - const size_t expected_attr_size = 7; + const size_t expected_attr_size = 8; const std::map attrs_data = {{op_attr::auto_broadcast, false}, - {dnnl_impl::op_attr::alg_kind, true}}; + {dnnl_impl::op_attr::alg_kind, true}, + {dnnl_impl::op_attr::is_select, false}}; verify_op_schema(op_kind, expected_in_size_lower, expected_out_size, expected_attr_size, attrs_data); diff --git a/tests/gtests/graph/unit/backend/dnnl/test_sdp_decomp.cpp b/tests/gtests/graph/unit/backend/dnnl/test_sdp_decomp.cpp index 5430a0276e9..1ac6c4904c1 100644 --- a/tests/gtests/graph/unit/backend/dnnl/test_sdp_decomp.cpp +++ b/tests/gtests/graph/unit/backend/dnnl/test_sdp_decomp.cpp @@ -1582,4 +1582,4 @@ TEST(test_sdp_decomp_execute, MultithreaSdpDecompCorr_CPU) { t1.join(); t2.join(); } -} +} \ No newline at end of file diff --git a/tests/gtests/graph/unit/backend/dnnl/test_select.cpp b/tests/gtests/graph/unit/backend/dnnl/test_select.cpp index 4044269851b..0241492ba17 100644 --- a/tests/gtests/graph/unit/backend/dnnl/test_select.cpp +++ b/tests/gtests/graph/unit/backend/dnnl/test_select.cpp @@ -26,8 +26,7 @@ using dim_t = dnnl_dim_t; using dims_t = dnnl_dims_t; using dims = std::vector; -TEST(test_select_execute, TestSelect) { - +TEST(test_select_execute, TestSelectBroadcast) { graph::engine_t *engine = get_engine(); std::vector cond(128, true); std::vector src0(1, -1); @@ -35,6 +34,7 @@ TEST(test_select_execute, TestSelect) { std::vector dst(12 * 128 * 128); for (int i = 0; i < 64; i++) cond[i] = false; + graph::op_t select_op(graph::op_kind::Select); graph::logical_tensor_t cond_lt = utils::logical_tensor_init( @@ -86,15 +86,181 @@ TEST(test_select_execute, TestSelect) { dst = dst_ts.as_vec_type(); for (size_t i = 0; i < 12 * 128; ++i) { for (size_t j = 0; j < 128; ++j) { - if (j < 64) + if (j < 64) { ASSERT_EQ(dst[i * 128 + j], 1); - else + } else { ASSERT_EQ(dst[i * 128 + j], -1); + } + } + } +} + +TEST(test_select_execute, TestSelectWithoutBroadcast) { + + graph::engine_t *engine = get_engine(); + std::vector cond(128, true); + std::vector src0(128, -1); + std::vector src1(128, 1); + std::vector dst(128); + for (int i = 0; i < 64; i++) + cond[i] = false; + + graph::op_t select_op(graph::op_kind::Select); + + graph::logical_tensor_t cond_lt = utils::logical_tensor_init( + 0, {1, 128}, graph::data_type::boolean); + + graph::logical_tensor_t src0_lt + = utils::logical_tensor_init(1, {1, 128}, graph::data_type::f32); + + graph::logical_tensor_t src1_lt + = utils::logical_tensor_init(2, {1, 128}, graph::data_type::f32); + + graph::logical_tensor_t dst_lt + = utils::logical_tensor_init(3, {1, 128}, graph::data_type::f32); + + select_op.add_input(cond_lt); + 
select_op.add_input(src0_lt); + select_op.add_input(src1_lt); + select_op.add_output(dst_lt); + + graph::graph_t g(engine->kind()); + g.add_op(&select_op); + g.finalize(); + + graph::pass::pass_base_ptr apass = get_pass("select_pass"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + + // compile + graph::partition_t p; + p.init(part); + graph::compiled_partition_t cp(p); + + std::vector inputs { + &cond_lt, &src0_lt, &src1_lt}; + std::vector outputs {&dst_lt}; + ASSERT_EQ(p.compile(&cp, inputs, outputs, engine), graph::status::success); + + graph::stream_t *stream = get_stream(); + test_tensor cond_ts(cond_lt, engine, cond); + test_tensor src0_ts(src0_lt, engine, src0); + test_tensor src1_ts(src1_lt, engine, src1); + test_tensor dst_ts(dst_lt, engine, dst); + + ASSERT_EQ(cp.execute(stream, {cond_ts.get(), src0_ts.get(), src1_ts.get()}, + {dst_ts.get()}), + graph::status::success); + stream->wait(); + dst = dst_ts.as_vec_type(); + for (size_t j = 0; j < 128; ++j) { + if (j < 64) { + ASSERT_EQ(dst[j], 1); + } else { + ASSERT_EQ(dst[j], -1); } } } -TEST(test_select_execute, MatmulSelect) { +TEST(test_select_execute, MatmulSelectWithoutBroadcast) { + graph::op_t matmul_op(0, graph::op_kind::MatMul, "MatMul"); + graph::op_t div_op(1, graph::op_kind::Divide, "div_op"); + graph::op_t select_op(2, graph::op_kind::Select, "Select"); + graph::engine_t *engine = get_engine(); + + std::vector src_data(32 * 56, 1); + std::vector weight_data(56 * 32, 1); + std::vector div_src1_data(1, 5); + std::vector select_src1_data(32 * 32, -1); + std::vector cond_data(32 * 32); + for (int i = 0; i < 32; i++) + for (int j = 0; j < 32; j++) + cond_data[i * 32 + j] = j > 1 ? true : false; + + std::vector ref_dst_data(32 * 32); + for (int i = 0; i < 1 * 32; i++) + for (int j = 0; j < 32; j++) + ref_dst_data[i * 32 + j] = j > 1 ? 
-1 : 11.2; + std::vector dst_data(ref_dst_data.size(), 0.0); + + // prepare logical tensor + graph::logical_tensor_t src = utils::logical_tensor_init( + 0, {1, 1, 32, 56}, graph::data_type::f32); + graph::logical_tensor_t weight = utils::logical_tensor_init( + 1, {1, 1, 56, 32}, graph::data_type::f32); + graph::logical_tensor_t matmul_dst = utils::logical_tensor_init( + 2, {1, 1, 32, 32}, graph::data_type::f32); + graph::logical_tensor_t div_src1 + = utils::logical_tensor_init(3, {1}, graph::data_type::f32); + graph::logical_tensor_t div_dst = utils::logical_tensor_init( + 4, {1, 1, 32, 32}, graph::data_type::f32); + graph::logical_tensor_t cond = utils::logical_tensor_init( + 5, {1, 1, 32, 32}, graph::data_type::boolean); + graph::logical_tensor_t select_src0 = utils::logical_tensor_init( + 6, {1, 1, 32, 32}, graph::data_type::f32); + graph::logical_tensor_t dst = utils::logical_tensor_init( + 7, {1, 1, 32, 32}, graph::data_type::f32); + + matmul_op.add_input(src); + matmul_op.add_input(weight); + matmul_op.add_output(matmul_dst); + + div_op.add_input(matmul_dst); + div_op.add_input(div_src1); + div_op.add_output(div_dst); + + select_op.add_input(cond); + select_op.add_input(select_src0); + select_op.add_input(div_dst); + select_op.add_output(dst); + + graph::graph_t g(engine->kind()); + g.add_op(&matmul_op); + g.add_op(&div_op); + g.add_op(&select_op); + g.finalize(); + + graph::pass::pass_base_ptr apass = get_pass("fp_matmul_post_ops"); + apass->run(g); + ASSERT_EQ(g.get_num_partitions(), 1U); + auto part = g.get_partitions()[0]; + ASSERT_EQ(part->get_ops().size(), 3U); + + // compile + graph::partition_t p; + p.init(part); + + graph::compiled_partition_t cp(p); + + std::vector inputs { + &src, &weight, &div_src1, &cond, &select_src0}; + std::vector outputs {&dst}; + + ASSERT_EQ(p.compile(&cp, inputs, outputs, engine), graph::status::success); + + test_tensor src_ts(src, engine, src_data); + test_tensor weight_ts(weight, engine, weight_data); + test_tensor div_src1_ts(div_src1, engine, div_src1_data); + test_tensor cond_ts(cond, engine, cond_data); + test_tensor select_src0_ts(select_src0, engine, select_src1_data); + test_tensor dst_ts(dst, engine, dst_data); + + graph::stream_t *strm = get_stream(); + ASSERT_EQ(cp.execute(strm, + {src_ts.get(), weight_ts.get(), div_src1_ts.get(), + cond_ts.get(), select_src0_ts.get()}, + {dst_ts.get()}), + graph::status::success); + + strm->wait(); + dst_data = dst_ts.as_vec_type(); + for (size_t i = 0; i < 32; ++i) + for (size_t j = 0; j < 32; ++j) + ASSERT_EQ(dst_data[i * 32 + j], ref_dst_data[i * 32 + j]); +} + +TEST(test_select_execute, MatmulSelectBroadcast) { graph::op_t matmul_op(0, graph::op_kind::MatMul, "MatMul"); graph::op_t div_op(1, graph::op_kind::Divide, "div_op"); graph::op_t select_op(2, graph::op_kind::Select, "Select"); diff --git a/tests/gtests/graph/unit/utils.hpp b/tests/gtests/graph/unit/utils.hpp index bf25145a17f..a7b63ddcf2e 100644 --- a/tests/gtests/graph/unit/utils.hpp +++ b/tests/gtests/graph/unit/utils.hpp @@ -171,16 +171,12 @@ static inline void verify_op_schema(const dnnl::impl::graph::op_kind_t op_kind_, const op_schema_t *op_schema_ = op_schema_registry_t::get_op_schema(op_kind_); EXPECT_TRUE(nullptr != op_schema_); - const std::set input_size = op_schema_->get_num_inputs(); EXPECT_TRUE(input_size.find(expected_in_size) != input_size.end()); - const std::set output_size = op_schema_->get_num_outputs(); EXPECT_TRUE(output_size.find(expected_out_size) != output_size.end()); - size_t attr_size = 
op_schema_->get_attrs().size(); EXPECT_EQ(attr_size, expected_attr_size); - for (const auto &attr_data : attrs_data) { const auto &attr_name = attr_data.first; const auto is_required = attr_data.second; @@ -1165,6 +1161,169 @@ inline void construct_select_float_MHA(dnnl::impl::graph::graph_t *agraph, agraph->add_op(&reorder_output); } +inline void construct_select_float_MHA_without_broadcast( + dnnl::impl::graph::graph_t *agraph, + impl::data_type_t dtype = impl::data_type::f32, int batch_size = 1, + int seq_len = 128, int num_head = 12, int head_dim = 768, + bool transpose = false, bool distil = true, bool gpt = false) { + using namespace dnnl::impl::graph; + using namespace dnnl::graph::tests; + + int size_per_head = head_dim / num_head; + dims QKV_RESHAPED_SHAPE = {batch_size, seq_len, num_head, size_per_head}; + dims EXTENDED_ATTENTION_MASK_SHAPE + = {batch_size, num_head, seq_len, seq_len}; + dims QKV_TRANSPOSED_SHAPE = {batch_size, num_head, seq_len, size_per_head}; + dims KEY_TRANSPOSED_SHAPE; + if (!transpose) + KEY_TRANSPOSED_SHAPE = {batch_size, num_head, size_per_head, seq_len}; + else + KEY_TRANSPOSED_SHAPE = {batch_size, num_head, seq_len, size_per_head}; + dims MATMUL_QK_OUTPUT_SHAPE = {batch_size, num_head, seq_len, seq_len}; + dims MATMUL_V_OUTPUT_SHAPE = {batch_size, num_head, seq_len, size_per_head}; + + dims CONST_SHAPE = {1}; + + dims QKV_TRANSPOSED_ORDER = {0, 2, 1, 3}; + dims KEY_TRANSPOSED_ORDER = {0, 1, 3, 2}; + + size_t lt_id = 0; + size_t op_id = 0; + + auto attention_mask_flt = unit::utils::logical_tensor_init( + lt_id++, EXTENDED_ATTENTION_MASK_SHAPE, dtype); + + auto query_input = unit::utils::logical_tensor_init( + lt_id++, QKV_TRANSPOSED_SHAPE, dtype); + + auto key_input = unit::utils::logical_tensor_init( + lt_id++, KEY_TRANSPOSED_SHAPE, dtype); + + auto matmul_qk_out = unit::utils::logical_tensor_init( + lt_id++, MATMUL_QK_OUTPUT_SHAPE, dtype); + + auto fscore_scale + = unit::utils::logical_tensor_init(lt_id++, CONST_SHAPE, dtype); + fscore_scale.property = property_type::constant; + auto fscore_div_out = unit::utils::logical_tensor_init( + lt_id++, MATMUL_QK_OUTPUT_SHAPE, dtype); + + auto condition_input = unit::utils::logical_tensor_init( + lt_id++, EXTENDED_ATTENTION_MASK_SHAPE, impl::data_type::boolean); + // For bianry-select primitive, the select_other_input should have the same + // shape as cond input, this is the requirement of the primitive creation. 
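+    // Hence select_other_input below reuses EXTENDED_ATTENTION_MASK_SHAPE, i.e. +    // the same shape as condition_input, so these graphs exercise the +    // non-broadcast select path.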
+ auto select_other_input = unit::utils::logical_tensor_init( + lt_id++, EXTENDED_ATTENTION_MASK_SHAPE, dtype); + auto select_output = unit::utils::logical_tensor_init( + lt_id++, MATMUL_QK_OUTPUT_SHAPE, dtype); + + auto softmax_input = unit::utils::logical_tensor_init( + lt_id++, MATMUL_QK_OUTPUT_SHAPE, dtype); + + auto softmax_out = unit::utils::logical_tensor_init( + lt_id++, MATMUL_QK_OUTPUT_SHAPE, dtype); + + auto tc1_out = unit::utils::logical_tensor_init( + lt_id++, MATMUL_QK_OUTPUT_SHAPE, impl::data_type::f32); + auto tc2_out = unit::utils::logical_tensor_init( + lt_id++, MATMUL_QK_OUTPUT_SHAPE, dtype); + + auto value_input = unit::utils::logical_tensor_init( + lt_id++, QKV_TRANSPOSED_SHAPE, dtype); + + auto matmul_v_out = unit::utils::logical_tensor_init( + lt_id++, MATMUL_V_OUTPUT_SHAPE, dtype); + + auto context_transpose_out = unit::utils::logical_tensor_init( + lt_id++, QKV_RESHAPED_SHAPE, dtype); + + auto context_reshape_out = unit::utils::logical_tensor_init( + lt_id++, QKV_RESHAPED_SHAPE, dtype); + + op_t matmul_qk {op_id++, op_kind::MatMul, "matmul_qk"}; + matmul_qk.set_attr(op_attr::transpose_b, transpose); + + op_t fscore_select {op_id++, op_kind::Select, "fscore_select"}; + + op_t fscore_div {op_id++, op_kind::Divide, "fscore_div"}; + fscore_div.set_attr(op_attr::auto_broadcast, std::string("numpy")); + op_t fscore_add {op_id++, op_kind::Add, "fscore_add"}; + fscore_add.set_attr(op_attr::auto_broadcast, std::string("numpy")); + op_t softmax {op_id++, op_kind::SoftMax, "softmax"}; + softmax.set_attr(op_attr::axis, (int64_t)3); + + op_t tc1 {op_id++, op_kind::TypeCast, "tc1"}; + op_t tc2 {op_id++, op_kind::TypeCast, "tc2"}; + op_t matmul_v {op_id++, op_kind::MatMul, "matmul_v"}; + + // transpose + reshape before output + op_t transpose_output { + op_id++, op_kind::StaticTranspose, "transpose_output"}; + transpose_output.set_attr>( + op_attr::order, QKV_TRANSPOSED_ORDER); + + op_t reorder_output {op_id++, op_kind::Reorder, "reorder_output"}; + + matmul_qk.add_input(query_input); + matmul_qk.add_input(key_input); + matmul_qk.add_output(matmul_qk_out); + if (gpt) { + fscore_select.add_input(condition_input); + fscore_select.add_input(matmul_qk_out); + fscore_select.add_input(select_other_input); + fscore_select.add_output(select_output); + fscore_div.add_input(select_output); + } else { + fscore_div.add_input(matmul_qk_out); + } + fscore_div.add_input(fscore_scale); + fscore_div.add_output(fscore_div_out); + + if (distil) { + fscore_select.add_input(condition_input); + fscore_select.add_input(select_other_input); + fscore_select.add_input(fscore_div_out); + fscore_select.add_output(select_output); + softmax.add_input(select_output); + } else if (gpt) { + fscore_add.add_input(fscore_div_out); + fscore_add.add_input(attention_mask_flt); + fscore_add.add_output(softmax_input); + softmax.add_input(softmax_input); + } + softmax.add_output(softmax_out); + if (gpt && dtype == impl::data_type::bf16) { + tc1.add_input(softmax_out); + tc1.add_output(tc1_out); + tc2.add_input(tc1_out); + tc2.add_output(tc2_out); + matmul_v.add_input(tc2_out); + } else { + matmul_v.add_input(softmax_out); + } + matmul_v.add_input(value_input); + matmul_v.add_output(matmul_v_out); + + transpose_output.add_input(matmul_v_out); + transpose_output.add_output(context_transpose_out); + + reorder_output.add_input(context_transpose_out); + reorder_output.add_output(context_reshape_out); + + agraph->add_op(&matmul_qk); + agraph->add_op(&fscore_div); + agraph->add_op(&fscore_select); + if (gpt) 
+    agraph->add_op(&softmax);
+    if (gpt && dtype == impl::data_type::bf16) {
+        agraph->add_op(&tc1);
+        agraph->add_op(&tc2);
+    }
+    agraph->add_op(&matmul_v);
+    agraph->add_op(&transpose_output);
+    agraph->add_op(&reorder_output);
+}
+
 inline void construct_dnnl_float_JAX_MHA(dnnl::impl::graph::graph_t *agraph,
         impl::data_type_t dtype = impl::data_type::f32, int batch_size = 1,
         int seq_len = 384, int num_head = 16, int head_dim = 1024) {
@@ -1785,6 +1944,221 @@ inline void construct_select_int8_MHA(dnnl::impl::graph::graph_t *agraph,
     agraph->add_op(&quantize_output);
 }
 
+inline void construct_select_int8_MHA_without_broadcast(
+        dnnl::impl::graph::graph_t *agraph, int batch_size = 1,
+        int seq_len = 128, int num_head = 12, int head_dim = 768,
+        bool transpose = false, bool distil = true, bool gpt = false) {
+    using namespace dnnl::impl::graph;
+    using namespace dnnl::graph::tests;
+
+    int size_per_head = head_dim / num_head;
+    dims QKV_RESHAPED_SHAPE = {batch_size, seq_len, num_head, size_per_head};
+    dims EXTENDED_ATTENTION_MASK_SHAPE
+            = {batch_size, num_head, seq_len, seq_len};
+    dims QKV_TRANSPOSED_SHAPE = {batch_size, num_head, seq_len, size_per_head};
+    dims KEY_TRANSPOSED_SHAPE;
+    if (!transpose)
+        KEY_TRANSPOSED_SHAPE = {batch_size, num_head, size_per_head, seq_len};
+    else
+        KEY_TRANSPOSED_SHAPE = {batch_size, num_head, seq_len, size_per_head};
+    dims MATMUL_QK_OUTPUT_SHAPE = {batch_size, num_head, seq_len, seq_len};
+    dims MATMUL_V_OUTPUT_SHAPE = {batch_size, num_head, seq_len, size_per_head};
+
+    dims CONST_SHAPE = {1};
+
+    dims QKV_TRANSPOSED_ORDER = {0, 2, 1, 3};
+    dims KEY_TRANSPOSED_ORDER = {0, 1, 3, 2};
+
+    size_t lt_id = 0;
+    size_t op_id = 0;
+
+    auto attention_mask_flt = unit::utils::logical_tensor_init(
+            lt_id++, EXTENDED_ATTENTION_MASK_SHAPE, data_type::f32);
+
+    auto query_input = unit::utils::logical_tensor_init(
+            lt_id++, QKV_TRANSPOSED_SHAPE, data_type::u8);
+    auto query_dequantize = unit::utils::logical_tensor_init(
+            lt_id++, QKV_TRANSPOSED_SHAPE, data_type::f32);
+
+    auto key_input = unit::utils::logical_tensor_init(
+            lt_id++, KEY_TRANSPOSED_SHAPE, data_type::u8);
+    auto key_dequantize = unit::utils::logical_tensor_init(
+            lt_id++, KEY_TRANSPOSED_SHAPE, data_type::f32);
+
+    auto matmul_qk_out = unit::utils::logical_tensor_init(
+            lt_id++, MATMUL_QK_OUTPUT_SHAPE, data_type::f32);
+
+    auto fscore_scale = unit::utils::logical_tensor_init(
+            lt_id++, CONST_SHAPE, data_type::f32);
+    fscore_scale.property = property_type::constant;
+    auto fscore_div_out = unit::utils::logical_tensor_init(
+            lt_id++, MATMUL_QK_OUTPUT_SHAPE, data_type::f32);
+
+    auto condition_input = unit::utils::logical_tensor_init(
+            lt_id++, EXTENDED_ATTENTION_MASK_SHAPE, impl::data_type::boolean);
+    // For the binary-select primitive, select_other_input must have the same
+    // shape as the cond input; this is a requirement of primitive creation.
+    auto select_other_input = unit::utils::logical_tensor_init(
+            lt_id++, EXTENDED_ATTENTION_MASK_SHAPE, data_type::f32);
+    auto select_output = unit::utils::logical_tensor_init(
+            lt_id++, MATMUL_QK_OUTPUT_SHAPE, data_type::f32);
+
+    auto softmax_input = unit::utils::logical_tensor_init(
+            lt_id++, MATMUL_QK_OUTPUT_SHAPE, data_type::f32);
+
+    auto softmax_out = unit::utils::logical_tensor_init(
+            lt_id++, MATMUL_QK_OUTPUT_SHAPE, data_type::f32);
+
+    auto softmax_out_q = unit::utils::logical_tensor_init(
+            lt_id++, MATMUL_QK_OUTPUT_SHAPE, data_type::u8);
+    auto softmax_out_deq = unit::utils::logical_tensor_init(
+            lt_id++, MATMUL_QK_OUTPUT_SHAPE, data_type::f32);
+
+    auto value_input = unit::utils::logical_tensor_init(
+            lt_id++, QKV_TRANSPOSED_SHAPE, data_type::u8);
+
+    auto value_dequantize = unit::utils::logical_tensor_init(
+            lt_id++, QKV_TRANSPOSED_SHAPE, data_type::f32);
+
+    auto matmul_v_out = unit::utils::logical_tensor_init(
+            lt_id++, MATMUL_V_OUTPUT_SHAPE, data_type::f32);
+
+    auto context_transpose_out = unit::utils::logical_tensor_init(
+            lt_id++, QKV_RESHAPED_SHAPE, data_type::f32);
+
+    auto context_reshape_out = unit::utils::logical_tensor_init(
+            lt_id++, QKV_RESHAPED_SHAPE, data_type::f32);
+    auto context_out = unit::utils::logical_tensor_init(
+            lt_id++, QKV_RESHAPED_SHAPE, data_type::u8);
+
+    op_t dequantize_query {op_id++, op_kind::Dequantize, "dequantize_query"};
+    dequantize_query.set_attr(op_attr::scales, std::vector<float>({0.12f}));
+    dequantize_query.set_attr(op_attr::zps, std::vector<int64_t>({2}));
+    dequantize_query.set_attr(op_attr::qtype, std::string("per_tensor"));
+    dequantize_query.set_attr(op_attr::axis, (int64_t)0);
+
+    op_t dequantize_key {op_id++, op_kind::Dequantize, "dequantize_key"};
+    dequantize_key.set_attr(op_attr::scales, std::vector<float>({0.12f}));
+    dequantize_key.set_attr(op_attr::zps, std::vector<int64_t>({2}));
+    dequantize_key.set_attr(op_attr::qtype, std::string("per_tensor"));
+    dequantize_key.set_attr(op_attr::axis, (int64_t)0);
+
+    op_t matmul_qk {op_id++, op_kind::MatMul, "matmul_qk"};
+    matmul_qk.set_attr(op_attr::transpose_b, transpose);
+
+    op_t fscore_select {op_id++, op_kind::Select, "fscore_select"};
+
+    op_t fscore_div {op_id++, op_kind::Divide, "fscore_div"};
+    fscore_div.set_attr(op_attr::auto_broadcast, std::string("numpy"));
+    op_t fscore_add {op_id++, op_kind::Add, "fscore_add"};
+    fscore_add.set_attr(op_attr::auto_broadcast, std::string("numpy"));
+    op_t softmax {op_id++, op_kind::SoftMax, "softmax"};
+    softmax.set_attr(op_attr::axis, (int64_t)3);
+
+    // quantize-dequantize softmax's output
+    op_t quantize_softmax {op_id++, op_kind::Quantize, "quantize_softmax"};
+    op_t dequantize_softmax {
+            op_id++, op_kind::Dequantize, "dequantize_softmax"};
+    quantize_softmax.set_attr(op_attr::scales, std::vector<float>({0.12f}));
+    quantize_softmax.set_attr(op_attr::zps, std::vector<int64_t>({0}));
+    quantize_softmax.set_attr(op_attr::qtype, std::string("per_tensor"));
+    quantize_softmax.set_attr(op_attr::axis, (int64_t)0);
+    dequantize_softmax.set_attr(op_attr::scales, std::vector<float>({0.12f}));
+    dequantize_softmax.set_attr(op_attr::zps, std::vector<int64_t>({2}));
+    dequantize_softmax.set_attr(op_attr::qtype, std::string("per_tensor"));
+    dequantize_softmax.set_attr(op_attr::axis, (int64_t)0);
+
+    op_t dequantize_value {op_id++, op_kind::Dequantize, "dequantize_value"};
+    dequantize_value.set_attr(op_attr::scales, std::vector<float>({0.12f}));
+    dequantize_value.set_attr(op_attr::zps, std::vector<int64_t>({2}));
+    dequantize_value.set_attr(op_attr::qtype, std::string("per_tensor"));
+    dequantize_value.set_attr(op_attr::axis, (int64_t)0);
+
+    op_t matmul_v {op_id++, op_kind::MatMul, "matmul_v"};
+
+    // transpose + reshape before output
+    op_t transpose_output {
+            op_id++, op_kind::StaticTranspose, "transpose_output"};
+    transpose_output.set_attr<std::vector<int64_t>>(
+            op_attr::order, QKV_TRANSPOSED_ORDER);
+
+    op_t reorder_output {op_id++, op_kind::Reorder, "reorder_output"};
+
+    op_t quantize_output {op_id++, op_kind::Quantize, "quantize_value"};
+    quantize_output.set_attr(op_attr::scales, std::vector<float>({0.12f}));
+    quantize_output.set_attr(op_attr::zps, std::vector<int64_t>({2}));
+    quantize_output.set_attr(op_attr::qtype, std::string("per_tensor"));
+    quantize_output.set_attr(op_attr::axis, (int64_t)0);
+
+    dequantize_query.add_input(query_input);
+    dequantize_query.add_output(query_dequantize);
+    dequantize_key.add_input(key_input);
+    dequantize_key.add_output(key_dequantize);
+    matmul_qk.add_input(query_dequantize);
+    matmul_qk.add_input(key_dequantize);
+    matmul_qk.add_output(matmul_qk_out);
+    if (gpt) {
+        fscore_select.add_input(condition_input);
+        fscore_select.add_input(matmul_qk_out);
+        fscore_select.add_input(select_other_input);
+        fscore_select.add_output(select_output);
+        fscore_div.add_input(select_output);
+    } else {
+        fscore_div.add_input(matmul_qk_out);
+    }
+    fscore_div.add_input(fscore_scale);
+    fscore_div.add_output(fscore_div_out);
+
+    if (distil) {
+        fscore_select.add_input(condition_input);
+        fscore_select.add_input(select_other_input);
+        fscore_select.add_input(fscore_div_out);
+        fscore_select.add_output(select_output);
+        softmax.add_input(select_output);
+    } else if (gpt) {
+        fscore_add.add_input(fscore_div_out);
+        fscore_add.add_input(attention_mask_flt);
+        fscore_add.add_output(softmax_input);
+        softmax.add_input(softmax_input);
+    }
+    softmax.add_output(softmax_out);
+    quantize_softmax.add_input(softmax_out);
+    quantize_softmax.add_output(softmax_out_q);
+    dequantize_softmax.add_input(softmax_out_q);
+    dequantize_softmax.add_output(softmax_out_deq);
+
+    dequantize_value.add_input(value_input);
+    dequantize_value.add_output(value_dequantize);
+
+    matmul_v.add_input(softmax_out_deq);
+    matmul_v.add_input(value_dequantize);
+    matmul_v.add_output(matmul_v_out);
+
+    transpose_output.add_input(matmul_v_out);
+    transpose_output.add_output(context_transpose_out);
+
+    reorder_output.add_input(context_transpose_out);
+    reorder_output.add_output(context_reshape_out);
+
+    quantize_output.add_input(context_reshape_out);
+    quantize_output.add_output(context_out);
+
+    agraph->add_op(&dequantize_query);
+    agraph->add_op(&dequantize_key);
+    agraph->add_op(&matmul_qk);
+    agraph->add_op(&fscore_div);
+    agraph->add_op(&fscore_select);
+    if (gpt) agraph->add_op(&fscore_add);
+    agraph->add_op(&softmax);
+    agraph->add_op(&quantize_softmax);
+    agraph->add_op(&dequantize_softmax);
+    agraph->add_op(&dequantize_value);
+    agraph->add_op(&matmul_v);
+    agraph->add_op(&transpose_output);
+    agraph->add_op(&reorder_output);
+    agraph->add_op(&quantize_output);
+}
+
 inline void construct_int8_bf16_MHA(dnnl::impl::graph::graph_t *agraph,
         int batch_size = 1, int seq_len = 384, int num_head = 16,
         int head_dim = 1024, bool transpose = false, bool attention_mask = true,
@@ -1869,6 +2243,91 @@ inline void construct_int8_bf16_MHA(dnnl::impl::graph::graph_t *agraph,
     }
 }
 
+inline void construct_int8_bf16_MHA_without_broadcast(
+        dnnl::impl::graph::graph_t *agraph, int batch_size = 1,
+        int seq_len = 384, int num_head = 16, int head_dim = 1024,
+        bool transpose = false, bool attention_mask = true, bool distil = false,
+        bool gpt = false) {
+    using namespace dnnl::impl::graph;
+    using namespace dnnl::graph::tests;
+
+    // construct an int8 MHA pattern first
+    if (!distil && !gpt)
+        construct_int8_MHA(agraph, batch_size, seq_len, num_head, head_dim,
+                transpose, attention_mask);
+    else
+        construct_select_int8_MHA_without_broadcast(agraph, batch_size, seq_len,
+                num_head, head_dim, transpose, distil, gpt);
+
+    // change the f32 logical tensors to bf16
+    for (auto &op : agraph->get_ops()) {
+        for (auto &val : op->get_input_values()) {
+            if (val->get_logical_tensor().data_type
+                    == impl::graph::data_type::f32)
+                val->set_data_type(impl::graph::data_type::bf16);
+        }
+
+        for (auto &val : op->get_output_values()) {
+            if (val->get_logical_tensor().data_type
+                    == impl::graph::data_type::f32)
+                val->set_data_type(impl::graph::data_type::bf16);
+        }
+    }
+
+    // insert a bf16->f32 typecast op before quantize and an f32->bf16 op after
+    // dequantize
+    std::vector<std::shared_ptr<op_t>> target_ops;
+    for (auto &op : agraph->get_ops()) {
+        if (op->get_kind() == impl::graph::op_kind::Quantize
+                || op->get_kind() == impl::graph::op_kind::Dequantize) {
+            target_ops.emplace_back(op);
+        }
+    }
+
+    std::vector<std::shared_ptr<op_t>> to_be_inserted;
+    size_t new_lt_id_start = 1000;
+    for (auto &op : target_ops) {
+        // insert a bf16->f32 typecast op before quantize
+        if (op->get_kind() == impl::graph::op_kind::Quantize) {
+            auto bf16_to_f32
+                    = agraph->create_op(op_kind::TypeCast, "bf16_to_f32");
+
+            auto in_val = op->get_input_value(0);
+            in_val->remove_consumer(*op, 0);
+            in_val->add_consumer(*bf16_to_f32, bf16_to_f32->num_inputs());
+            bf16_to_f32->add_input(in_val);
+
+            auto new_lt = in_val->get_logical_tensor();
+            new_lt.id = new_lt_id_start++;
+            new_lt.data_type = impl::graph::data_type::f32;
+            auto new_val
+                    = std::make_shared<value_t>(*bf16_to_f32, 0, new_lt, false);
+            bf16_to_f32->add_output(new_val);
+
+            new_val->add_consumer(*op, 0);
+            op->connect_input(0, new_val);
+        }
+
+        // insert an f32->bf16 op after dequantize
+        if (op->get_kind() == impl::graph::op_kind::Dequantize) {
+            auto f32_to_bf16
+                    = agraph->create_op(op_kind::TypeCast, "f32_to_bf16");
+
+            auto out_val = op->get_output_value(0);
+            f32_to_bf16->add_output(out_val);
+
+            auto new_lt = out_val->get_logical_tensor();
+            new_lt.id = new_lt_id_start++;
+            new_lt.data_type = impl::graph::data_type::f32;
+            auto new_val = std::make_shared<value_t>(*op, 0, new_lt, false);
+            op->connect_output(0, new_val);
+
+            new_val->add_consumer(*f32_to_bf16, f32_to_bf16->num_inputs());
+            f32_to_bf16->add_input(new_val);
+        }
+    }
+}
+
 inline void construct_chained_relu(dnnl::impl::graph::graph_t *agraph) {
     using namespace dnnl::impl::graph;
     using namespace dnnl::graph::tests;