From 9b5bbe6d6ee948af4a924aa9b44f1b3aec618f9c Mon Sep 17 00:00:00 2001
From: Jiexin-Zheng
Date: Fri, 10 Jan 2025 06:48:58 +0000
Subject: [PATCH 1/2] graph: backend,interface: add select binary impl

---
 src/graph/backend/dnnl/dnnl_op_def.hpp        |   1 +
 src/graph/backend/dnnl/dnnl_shape_infer.cpp   |  66 ++++++--
 src/graph/backend/dnnl/dnnl_shape_infer.hpp   |   6 +-
 .../backend/dnnl/kernels/large_partition.cpp  |   2 +
 src/graph/backend/dnnl/kernels/matmul.cpp     |   3 +
 src/graph/backend/dnnl/kernels/sdp_decomp.cpp |   4 +-
 .../dnnl/kernels/sdp_decomp_config.cpp        |  19 ++-
 src/graph/backend/dnnl/kernels/select.cpp     |   3 +
 src/graph/backend/dnnl/op_executable.cpp      |  17 +-
 src/graph/backend/dnnl/passes/lower.cpp       | 123 +++-------------
 src/graph/backend/dnnl/passes/transform.cpp   | 146 +++++++++++++++++-
 src/graph/backend/dnnl/passes/transform.hpp   |  10 +-
 src/graph/backend/dnnl/passes/utils.cpp       |  18 ++-
 src/graph/backend/dnnl/passes/utils.hpp       |   5 +-
 src/graph/interface/shape_infer.cpp           |   7 +-
 src/graph/interface/shape_infer.hpp           |   4 +-
 16 files changed, 306 insertions(+), 128 deletions(-)

diff --git a/src/graph/backend/dnnl/dnnl_op_def.hpp b/src/graph/backend/dnnl/dnnl_op_def.hpp
index 148efc50817..5dd7a8e1776 100644
--- a/src/graph/backend/dnnl/dnnl_op_def.hpp
+++ b/src/graph/backend/dnnl/dnnl_op_def.hpp
@@ -701,6 +701,7 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_binary, 1,
                 .set_num_outputs(2)
                 .set_input(0, "a")
                 .set_input(1, "b")
+                .set_input(2, "cond")
                 .set_output(0, "output")
                 .set_output(1, "scratchpad")
                 // Attributes inherited from front binary ops (Add, Multiply,
diff --git a/src/graph/backend/dnnl/dnnl_shape_infer.cpp b/src/graph/backend/dnnl/dnnl_shape_infer.cpp
index 781fe979190..b94ab7a87aa 100644
--- a/src/graph/backend/dnnl/dnnl_shape_infer.cpp
+++ b/src/graph/backend/dnnl/dnnl_shape_infer.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2024 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -15,9 +15,9 @@
 *******************************************************************************/
 
 #include
-#include
-
 #include "graph/interface/shape_infer.hpp"
+#include "oneapi/dnnl/dnnl.hpp"
+#include
 
 #include "graph/backend/dnnl/dnnl_shape_infer.hpp"
 #include "graph/backend/dnnl/internal_attrs.hpp"
@@ -484,17 +484,65 @@ status_t infer_dnnl_pool_bwd_output_shape(op_t *n,
     return status::success;
 }
 
+status_t infer_binary_select_output_shape(op_t *n,
+        std::vector<logical_tensor_t *> &inputs,
+        std::vector<logical_tensor_t *> &outputs) {
+    auto in0 = logical_tensor_wrapper_t(inputs[0]);
+    auto in1 = logical_tensor_wrapper_t(inputs[1]);
+    auto in2 = logical_tensor_wrapper_t(inputs[2]);
+
+    const bool shapes_should_match = n->has_attr(op_attr::auto_broadcast)
+            ? "none" == n->get_attr<std::string>(op_attr::auto_broadcast)
+            : false;
+
+    dims input0_dims = in0.vdims();
+    dims input1_dims = in1.vdims();
+    dims input2_dims = in2.vdims();
+    dims inferred_out_shape;
+
+    if (shapes_should_match) { // no broadcast
+        VCHECK_INVALID_SHAPE(
+                (input0_dims == input1_dims && input1_dims == input2_dims),
+                "%s, all input dims should match each other if there is no "
+                "broadcast. input0 dims: %s, input1 dims: %s, input2 dims: %s ",
+                op_t::kind2str(n->get_kind()).c_str(),
+                dims2str(input0_dims).c_str(), dims2str(input1_dims).c_str(),
+                dims2str(input2_dims).c_str());
+        inferred_out_shape = std::move(input0_dims);
+    } else { // can broadcast
+        status_t ret1 = broadcast(input0_dims, input1_dims, inferred_out_shape);
+        VCHECK_INVALID_SHAPE((ret1 == status::success),
+                "%s, failed to implement numpy broadcasting",
+                op_t::kind2str(n->get_kind()).c_str());
+    }
+
+    auto out0 = logical_tensor_wrapper_t(outputs[0]);
+    // check if a given or partially set shape aligns with the inferred shape
+    if (!out0.is_shape_unknown() || out0.ndims() != -1) {
+        VCHECK_INVALID_SHAPE(validate(inferred_out_shape, out0.vdims()),
+                "%s, inferred out shape and output shape are not compatible",
+                op_t::kind2str(n->get_kind()).c_str());
+        if (!out0.is_shape_unknown()) return status::success;
+    }
+
+    set_shape_and_strides(*outputs[0], inferred_out_shape);
+    return status::success;
+}
+
 status_t infer_dnnl_binary_output_shape(op_t *n,
         std::vector<logical_tensor_t *> &inputs,
         std::vector<logical_tensor_t *> &outputs) {
     const bool is_bias_add = n->has_attr(op_attr::is_bias_add)
             && n->get_attr<bool>(op_attr::is_bias_add);
-
-    auto ret = is_bias_add
-            ? infer_bias_add_output_shape(n, inputs, outputs)
-            : infer_elemwise_arithmetic_output_shape(n, inputs, outputs);
-
-    return ret;
+    const algorithm algo = static_cast<algorithm>(
+            n->get_attr<int64_t>(op_attr::alg_kind));
+    if (algo == algorithm::binary_select) {
+        return infer_binary_select_output_shape(n, inputs, outputs);
+    } else if (is_bias_add) {
+        return infer_bias_add_output_shape(n, inputs, outputs);
+    } else {
+        return infer_elemwise_arithmetic_output_shape(n, inputs, outputs);
+    }
 }
 
 } // namespace dnnl_impl
diff --git a/src/graph/backend/dnnl/dnnl_shape_infer.hpp b/src/graph/backend/dnnl/dnnl_shape_infer.hpp
index 22ef21b65ae..78368597062 100644
--- a/src/graph/backend/dnnl/dnnl_shape_infer.hpp
+++ b/src/graph/backend/dnnl/dnnl_shape_infer.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2023 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -103,6 +103,10 @@ status_t infer_dnnl_binary_output_shape(op_t *n,
         std::vector<logical_tensor_t *> &inputs,
         std::vector<logical_tensor_t *> &outputs);
 
+status_t infer_binary_select_output_shape(op_t *n,
+        std::vector<logical_tensor_t *> &inputs,
+        std::vector<logical_tensor_t *> &outputs);
+
 } // namespace dnnl_impl
 } // namespace graph
 } // namespace impl
diff --git a/src/graph/backend/dnnl/kernels/large_partition.cpp b/src/graph/backend/dnnl/kernels/large_partition.cpp
index 9962d2f473a..d7514da0c1d 100644
--- a/src/graph/backend/dnnl/kernels/large_partition.cpp
+++ b/src/graph/backend/dnnl/kernels/large_partition.cpp
@@ -36,6 +36,8 @@ void larger_partition_kernel_t::setup_pipeline_stage1(
         pass_pipeline_t &pipeline) {
     // Directly lower down (1 to 1 mapping)
     BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
+    // Decompose select to binary ops if necessary
+    BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_binary_ops);
 
     // Indirectly lower down (N to 1 mapping)
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_reciprocal_mul_to_div);
diff --git a/src/graph/backend/dnnl/kernels/matmul.cpp b/src/graph/backend/dnnl/kernels/matmul.cpp
index f0fc7193e4a..17005554cba 100644
--- a/src/graph/backend/dnnl/kernels/matmul.cpp
+++ b/src/graph/backend/dnnl/kernels/matmul.cpp
@@ -50,6 +50,9 @@ status_t matmul_t<quantized>::compile_impl(const dnnl_partition_impl_t *part,
     pass_pipeline_t pipeline(vis);
 
     BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
+    // Decompose select to binary ops if necessary
+    BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_binary_ops);
+
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_bias_add);
     // check if bias exists
     BACKEND_DNNL_ADD_PASS(pipeline, check_with_bias);
diff --git a/src/graph/backend/dnnl/kernels/sdp_decomp.cpp b/src/graph/backend/dnnl/kernels/sdp_decomp.cpp
index ffde91622e0..9e1361d7add 100644
--- a/src/graph/backend/dnnl/kernels/sdp_decomp.cpp
+++ b/src/graph/backend/dnnl/kernels/sdp_decomp.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2024 Intel Corporation
+* Copyright 2024-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -60,6 +60,8 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
     pass_pipeline_t pipeline = pass_pipeline_t(vis);
     pass_pipeline_t select_pipeline = pass_pipeline_t(vis);
     BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
+    // Decompose select to binary ops if necessary
+    BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_binary_ops);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa);
     // Fusion and canonicalization passes begin
     if (quantized) {
diff --git a/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp b/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
index 8e49149d2c5..d09567286ad 100644
--- a/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
+++ b/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2024 Intel Corporation
+* Copyright 2024-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -58,6 +58,23 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
                 "Only supports single scale value, but got %lld", scale_sz);
     }
 
+    // Check select cond and src0 shape
+    if (graph_inport[5] != -1 && graph_inport[6] != -1) {
+        const auto select_cond_dims = ltw(inputs[graph_inport[5]]).vdims();
+        const auto select_src0_dims = ltw(inputs[graph_inport[6]]).vdims();
+        VCHECK_SDP_DECOMP(select_cond_dims.size() == select_src0_dims.size(),
+                false,
+                "Select cond and src0 dims should be the same, but got %zu and %zu",
+                select_cond_dims.size(), select_src0_dims.size());
+        for (size_t i = 0; i < select_cond_dims.size(); i++) {
+
+            VCHECK_SDP_DECOMP(select_cond_dims[i] == select_src0_dims[i], false,
+                    "Select cond and src0 dims should be the same, but got "
+                    "%lld and %lld",
+                    select_cond_dims[i], select_src0_dims[i]);
+        }
+    }
+
 #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP
 // RATIO is an empirical value used to determine the numerical relationship
 // between batch_size, num_head_q and thread number to determine whether to use
diff --git a/src/graph/backend/dnnl/kernels/select.cpp b/src/graph/backend/dnnl/kernels/select.cpp
index 1434c8bc2cd..9e7b60fe118 100644
--- a/src/graph/backend/dnnl/kernels/select.cpp
+++ b/src/graph/backend/dnnl/kernels/select.cpp
@@ -49,6 +49,9 @@ status_t select_t::compile_impl(const dnnl_partition_impl_t *part,
     pass_pipeline_t pipeline(vis);
 
     BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
+    // Decompose select to binary ops if necessary
+    BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_binary_ops);
+
     BACKEND_DNNL_ADD_PASS(pipeline, binary_canonicalization);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_ops);
 
diff --git a/src/graph/backend/dnnl/op_executable.cpp b/src/graph/backend/dnnl/op_executable.cpp
index 55b0c7e9f96..ef433a0a413 100644
--- a/src/graph/backend/dnnl/op_executable.cpp
+++ b/src/graph/backend/dnnl/op_executable.cpp
@@ -1252,8 +1252,15 @@ binary_executable_t::desc_t binary_executable_t::create_desc(
             op->get_attr<int64_t>(op_attr::alg_kind));
 
     dnnl::binary::primitive_desc pd;
-    pd = dnnl::binary::primitive_desc(
-            p_engine, algo, src0, src1, dst, prm_attr);
+    if (algo == algorithm::binary_select) {
+        auto src2 = make_dnnl_memory_desc(
+                op->get_input_value(2)->get_logical_tensor());
+        pd = dnnl::binary::primitive_desc(
+                p_engine, algo, src0, src1, src2, dst, prm_attr);
+    } else {
+        pd = dnnl::binary::primitive_desc(
+                p_engine, algo, src0, src1, dst, prm_attr);
+    }
 
     pd_cache.insert({op.get(), pd});
 
@@ -1891,12 +1898,16 @@ arg_indices_t matmul_executable_t::get_arg_indices(
 arg_indices_t binary_executable_t::get_arg_indices(
         const op_t *op, fusion_info_mgr_t &mgr) {
     arg_indices_t arg_indices;
+    const algorithm algo = static_cast<algorithm>(
+            op->get_attr<int64_t>(op_attr::alg_kind));
 
     // add input args
     size_t index = 0;
     arg_indices.insert({DNNL_ARG_SRC_0, indices_t {input, index++}});
     arg_indices.insert({DNNL_ARG_SRC_1, indices_t {input, index++}});
-
+    if (algo == algorithm::binary_select) {
+        arg_indices.insert({DNNL_ARG_SRC_2, indices_t {input, index++}});
+    }
     get_arg_indices_for_post_ops(op, mgr, arg_indices, index);
 
     // add output args
diff --git a/src/graph/backend/dnnl/passes/lower.cpp b/src/graph/backend/dnnl/passes/lower.cpp
index 18589444b90..888520dbb46 100644
--- a/src/graph/backend/dnnl/passes/lower.cpp
+++ b/src/graph/backend/dnnl/passes/lower.cpp
@@ -664,114 +664,33 @@ static status_t select_handler(
     auto cond = in_vals[0];
     auto src0 = in_vals[1];
     auto src1 = in_vals[2];
-    cond->set_data_type(dnnl::impl::data_type::u8);
-
-    //TODO: This reorder can be removed once eltwise_clip support int8 input
-    op_ptr type_cast = std::make_shared<op_t>(op_kind::dnnl_reorder);
-    type_cast->set_attr<bool>(op_attr::change_layout, false);
-
-    op_ptr clip = std::make_shared<op_t>(op_kind::dnnl_eltwise);
-    clip->set_attr<int64_t>(op_attr::alg_kind,
-            static_cast<int64_t>(dnnl::algorithm::eltwise_clip));
-    clip->set_attr<float>(op_attr::alpha, 0.f);
-    clip->set_attr<float>(op_attr::beta, 1.f);
-
-    // After reorder and clip. The cond value is 0 or 1.
-    // Then output = src0.*cond+src1.*(cond*-1 + 1)
-    op_ptr mul1 = std::make_shared<op_t>(op_kind::dnnl_binary);
-    mul1->set_attr<int64_t>(op_attr::alg_kind,
-            static_cast<int64_t>(dnnl::algorithm::binary_mul));
-    mul1->merge_attributes(op->get_attributes());
-
-    op_ptr mul2 = std::make_shared<op_t>(op_kind::dnnl_binary);
-    mul2->set_attr<int64_t>(op_attr::alg_kind,
-            static_cast<int64_t>(dnnl::algorithm::binary_mul));
-    mul2->merge_attributes(op->get_attributes());
-
-    op_ptr linear = std::make_shared<op_t>(op_kind::dnnl_eltwise);
-    linear->set_attr<int64_t>(op_attr::alg_kind,
-            static_cast<int64_t>(dnnl::algorithm::eltwise_linear));
-    const float alpha_value = -1.0f, beta_value = 1.0f;
-    linear->set_attr<float>(op_attr::alpha, alpha_value);
-    linear->set_attr<float>(op_attr::beta, beta_value);
-
-    op_ptr add = std::make_shared<op_t>(op_kind::dnnl_binary);
-    add->set_attr<int64_t>(op_attr::alg_kind,
-            static_cast<int64_t>(dnnl::algorithm::binary_add));
+    // For the binary select operation, the conditional input tensor can
+    // only be of `s8` data type.
+    cond->set_data_type(dnnl::impl::data_type::s8);
+
+    op_ptr new_op = std::make_shared<op_t>(op_kind::dnnl_binary);
+    new_op->set_attr<int64_t>(op_attr::alg_kind,
+            static_cast<int64_t>(get_binary_alg_map().at(op->get_kind())));
+    new_op->merge_attributes(op->get_attributes());
 
     // reconnect
     cond->remove_consumer(*op, 0);
     src0->remove_consumer(*op, 1);
     src1->remove_consumer(*op, 2);
 
-    // first reorder and clip
-    cond->add_consumer(*type_cast, 0);
-    type_cast->add_input(cond);
-    logical_tensor_t float_cond = empty_logical_tensor_with_default_id();
-    auto float_cond_val
-            = std::make_shared<value_t>(*type_cast, 0, float_cond, true);
-    float_cond_val->set_data_type(dnnl::impl::data_type::f32);
-    type_cast->add_output(float_cond_val);
-    insert_empty_scratchpad(type_cast);
-
-    float_cond_val->add_consumer(*clip, 0);
-    clip->add_input(float_cond_val);
-    logical_tensor_t clip_cond = empty_logical_tensor_with_default_id();
-    auto clip_cond_val = std::make_shared<value_t>(*clip, 0, clip_cond, true);
-    clip_cond_val->set_data_type(
-            float_cond_val->get_logical_tensor().data_type);
-    clip->add_output(clip_cond_val);
-    insert_empty_scratchpad(clip);
-
-    // first multiply
-    src0->add_consumer(*mul1, 0);
-    clip_cond_val->add_consumer(*mul1, 1);
-    mul1->add_input(src0);
-    mul1->add_input(clip_cond_val);
-
-    logical_tensor_t src0_cond = empty_logical_tensor_with_default_id();
-    auto src0_val = std::make_shared<value_t>(*mul1, 0, src0_cond, true);
-    src0_val->set_data_type(src0->get_logical_tensor().data_type);
-    mul1->add_output(src0_val);
-    insert_empty_scratchpad(mul1);
-
-    //cond.*{-1} + 1
-    clip_cond_val->add_consumer(*linear, 0);
-    linear->add_input(clip_cond_val);
-
-    logical_tensor_t cond_inv = empty_logical_tensor_with_default_id();
-    auto cond_inv_val = std::make_shared<value_t>(*linear, 0, cond_inv, true);
-    cond_inv_val->set_data_type(clip_cond_val->get_logical_tensor().data_type);
-    linear->add_output(cond_inv_val);
-    insert_empty_scratchpad(linear);
-
-    //src1.*(cond_inv)
-
-    src1->add_consumer(*mul2, 0);
-    cond_inv_val->add_consumer(*mul2, 1);
-    mul2->add_input(src1);
-    mul2->add_input(cond_inv_val);
-
-    logical_tensor_t src1_cond = empty_logical_tensor_with_default_id();
-    auto src1_val = std::make_shared<value_t>(*mul2, 0, src1_cond, true);
-    src1_val->set_data_type(src1->get_logical_tensor().data_type);
-    mul2->add_output(src1_val);
-    insert_empty_scratchpad(mul2);
-
-    src0_val->add_consumer(*add, 0);
-    src1_val->add_consumer(*add, 1);
-    add->add_input(src0_val);
-    add->add_input(src1_val);
-    add->add_output(out_vals[0]);
-    insert_empty_scratchpad(add);
-
-    // add new ops and delete select op
-    rewriter.to_insert(type_cast);
-    rewriter.to_insert(clip);
-    rewriter.to_insert(mul1);
-    rewriter.to_insert(linear);
-    rewriter.to_insert(mul2);
-    rewriter.to_insert(add);
+    // The binary select primitive places the condition input tensor as the
+    // third input tensor.
+    src0->add_consumer(*new_op, 0);
+    src1->add_consumer(*new_op, 1);
+    cond->add_consumer(*new_op, 2);
+
+    new_op->add_input(src0);
+    new_op->add_input(src1);
+    new_op->add_input(cond);
+    new_op->add_output(out_vals[0]);
+
+    insert_empty_scratchpad(new_op);
+    rewriter.to_insert(new_op);
     rewriter.to_remove(op);
 
     return status::success;
diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp
index a16106babad..b5133e6135b 100644
--- a/src/graph/backend/dnnl/passes/transform.cpp
+++ b/src/graph/backend/dnnl/passes/transform.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright 2021-2024 Intel Corporation
+ * Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -2268,6 +2268,9 @@ status_t binary_canonicalization(std::shared_ptr<subgraph_t> &sg) {
         std::vector<int32_t> in_ndims {src0_ndims, src1_ndims};
         for (size_t i = 0; i < cur_op->num_inputs(); ++i) {
+            // For binary select op, broadcast for the third input is
+            // unsupported. (Checked first: in_ndims only holds two entries.)
+            if (i == 2) { continue; }
             if (in_ndims[i] == target_ndims) { continue; }
 
             std::vector<int64_t> axes(target_ndims - in_ndims[i]);
             std::iota(axes.begin(), axes.end(), 0);
@@ -2297,6 +2300,147 @@ status_t binary_canonicalization(std::shared_ptr<subgraph_t> &sg) {
     return infer_shape(sg);
 }
 
+status_t decompose_select_to_binary_ops(std::shared_ptr<subgraph_t> &sg) {
+    subgraph_rewriter_t rewriter(sg);
+    for (auto &op : sg->get_ops()) {
+        if (op->get_kind() != op_kind::dnnl_binary) continue;
+        const algorithm algo = static_cast<algorithm>(
+                op->get_attr<int64_t>(op_attr::alg_kind));
+
+        if (algo != algorithm::binary_select) continue;
+
+        // For the binary select primitive, broadcast semantics are not
+        // supported for the third conditional input tensor. In this case,
+        // the shape of the conditional input tensor must match that of the
+        // source 0 tensor. The binary select primitive is also unsupported
+        // on GPU.
+        const bool require_broadcast = need_broadcast_for_inputs(op, 0, 2);
+        if (!require_broadcast && sg->get_engine_kind() != engine_kind::gpu)
+            continue;
+
+        auto in_vals = op->get_input_values();
+        auto out_vals = op->get_output_values();
+
+        auto src0 = in_vals[0];
+        auto src1 = in_vals[1];
+        auto cond = in_vals[2];
+        cond->set_data_type(dnnl::impl::data_type::u8);
+
+        // TODO: This reorder can be removed once eltwise_clip supports int8
+        // input
+        op_ptr type_cast = std::make_shared<op_t>(op_kind::dnnl_reorder);
+        type_cast->set_attr<bool>(op_attr::change_layout, false);
+
+        op_ptr clip = std::make_shared<op_t>(op_kind::dnnl_eltwise);
+        clip->set_attr<int64_t>(op_attr::alg_kind,
+                static_cast<int64_t>(dnnl::algorithm::eltwise_clip));
+        clip->set_attr<float>(op_attr::alpha, 0.f);
+        clip->set_attr<float>(op_attr::beta, 1.f);
+
+        // After reorder and clip, the cond value is 0 or 1.
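+        // (The clip maps any nonzero cond value to exactly 1, so the
+        // arithmetic blend below reproduces the select exactly: cond = 1
+        // yields src0, cond = 0 yields src1.)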
+        // Then output = src0.*cond + src1.*(cond*-1 + 1)
+        op_ptr mul1 = std::make_shared<op_t>(op_kind::dnnl_binary);
+        mul1->merge_attributes(op->get_attributes());
+        mul1->set_attr<int64_t>(op_attr::alg_kind,
+                static_cast<int64_t>(dnnl::algorithm::binary_mul));
+
+        op_ptr mul2 = std::make_shared<op_t>(op_kind::dnnl_binary);
+        mul2->merge_attributes(op->get_attributes());
+        mul2->set_attr<int64_t>(op_attr::alg_kind,
+                static_cast<int64_t>(dnnl::algorithm::binary_mul));
+
+        op_ptr linear = std::make_shared<op_t>(op_kind::dnnl_eltwise);
+        linear->set_attr<int64_t>(op_attr::alg_kind,
+                static_cast<int64_t>(dnnl::algorithm::eltwise_linear));
+        const float alpha_value = -1.0f, beta_value = 1.0f;
+        linear->set_attr<float>(op_attr::alpha, alpha_value);
+        linear->set_attr<float>(op_attr::beta, beta_value);
+
+        op_ptr add = std::make_shared<op_t>(op_kind::dnnl_binary);
+        add->set_attr<int64_t>(op_attr::alg_kind,
+                static_cast<int64_t>(dnnl::algorithm::binary_add));
+
+        // reconnect
+        src0->remove_consumer(*op, 0);
+        src1->remove_consumer(*op, 1);
+        cond->remove_consumer(*op, 2);
+
+        // first reorder and clip
+        cond->add_consumer(*type_cast, 0);
+        type_cast->add_input(cond);
+        logical_tensor_t float_cond = empty_logical_tensor_with_default_id();
+        auto float_cond_val
+                = std::make_shared<value_t>(*type_cast, 0, float_cond, true);
+        float_cond_val->set_data_type(dnnl::impl::data_type::f32);
+        type_cast->add_output(float_cond_val);
+        insert_empty_scratchpad(type_cast);
+
+        float_cond_val->add_consumer(*clip, 0);
+        clip->add_input(float_cond_val);
+        logical_tensor_t clip_cond = empty_logical_tensor_with_default_id();
+        auto clip_cond_val
+                = std::make_shared<value_t>(*clip, 0, clip_cond, true);
+        clip_cond_val->set_data_type(
+                float_cond_val->get_logical_tensor().data_type);
+        clip->add_output(clip_cond_val);
+        insert_empty_scratchpad(clip);
+
+        // first multiply
+        src0->add_consumer(*mul1, 0);
+        clip_cond_val->add_consumer(*mul1, 1);
+        mul1->add_input(src0);
+        mul1->add_input(clip_cond_val);
+
+        logical_tensor_t src0_cond = empty_logical_tensor_with_default_id();
+        auto src0_val = std::make_shared<value_t>(*mul1, 0, src0_cond, true);
+        src0_val->set_data_type(src0->get_logical_tensor().data_type);
+        mul1->add_output(src0_val);
+        insert_empty_scratchpad(mul1);
+
+        // cond.*{-1} + 1
+        clip_cond_val->add_consumer(*linear, 0);
+        linear->add_input(clip_cond_val);
+
+        logical_tensor_t cond_inv = empty_logical_tensor_with_default_id();
+        auto cond_inv_val
+                = std::make_shared<value_t>(*linear, 0, cond_inv, true);
+        cond_inv_val->set_data_type(
+                clip_cond_val->get_logical_tensor().data_type);
+        linear->add_output(cond_inv_val);
+        insert_empty_scratchpad(linear);
+
+        // src1.*(cond_inv)
+        src1->add_consumer(*mul2, 0);
+        cond_inv_val->add_consumer(*mul2, 1);
+        mul2->add_input(src1);
+        mul2->add_input(cond_inv_val);
+
+        logical_tensor_t src1_cond = empty_logical_tensor_with_default_id();
+        auto src1_val = std::make_shared<value_t>(*mul2, 0, src1_cond, true);
+        src1_val->set_data_type(src1->get_logical_tensor().data_type);
+        mul2->add_output(src1_val);
+        insert_empty_scratchpad(mul2);
+
+        src0_val->add_consumer(*add, 0);
+        src1_val->add_consumer(*add, 1);
+        add->add_input(src0_val);
+        add->add_input(src1_val);
+        add->add_output(out_vals[0]);
+        insert_empty_scratchpad(add);
+
+        // add new ops and delete select op
+        rewriter.to_insert(type_cast);
+        rewriter.to_insert(clip);
+        rewriter.to_insert(mul1);
+        rewriter.to_insert(linear);
+        rewriter.to_insert(mul2);
+        rewriter.to_insert(add);
+        rewriter.to_remove(op);
+    }
+    rewriter.run();
+    return infer_shape(sg);
+}
+
 status_t binary_broadcast_swap(std::shared_ptr<subgraph_t> &sg) {
     subgraph_rewriter_t rewriter(sg);
 
diff --git a/src/graph/backend/dnnl/passes/transform.hpp b/src/graph/backend/dnnl/passes/transform.hpp
index 4378c527e20..ef7329a4d89 100644
--- a/src/graph/backend/dnnl/passes/transform.hpp
+++ b/src/graph/backend/dnnl/passes/transform.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright 2021-2024 Intel Corporation
+ * Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -119,6 +119,14 @@ status_t fuse_to_dnnl_sum(std::shared_ptr<subgraph_t> &sg);
 // make the input shape meet the requirement of oneDNN binary primitive
 status_t binary_canonicalization(std::shared_ptr<subgraph_t> &sg);
 
+// For now, we support two impl paths for the select op: one uses the binary
+// primitive with the select algorithm, the other uses multiple binary ops
+// (called the "legacy impl" here). During the lowering pass, the front-end
+// select op is always lowered to a single binary select op; this pass then
+// decides which impl path to apply and decomposes the binary select op back
+// into multiple binary ops when the legacy impl is needed.
+status_t decompose_select_to_binary_ops(std::shared_ptr<subgraph_t> &sg);
+
 // This pass is used to swap two inputs to broadcast src1 which is optimized in
 // oneDNN binary primitive. Notice that this should be applied after
 // binary_canonicalization and infer_shape
diff --git a/src/graph/backend/dnnl/passes/utils.cpp b/src/graph/backend/dnnl/passes/utils.cpp
index 98e7093d668..88159700f96 100644
--- a/src/graph/backend/dnnl/passes/utils.cpp
+++ b/src/graph/backend/dnnl/passes/utils.cpp
@@ -250,7 +250,8 @@ const std::map<op_kind_t, dnnl::algorithm> &get_binary_alg_map() {
             {graph::op_kind::Maximum, dnnl::algorithm::binary_max},
             {graph::op_kind::Subtract, dnnl::algorithm::binary_sub},
             {graph::op_kind::BiasAdd, dnnl::algorithm::binary_add},
-            {graph::op_kind::GreaterEqual, dnnl::algorithm::binary_ge}};
+            {graph::op_kind::GreaterEqual, dnnl::algorithm::binary_ge},
+            {graph::op_kind::Select, dnnl::algorithm::binary_select}};
     return binary_alg_map;
 }
 
@@ -646,6 +647,21 @@ bool inverse_mul_scales(std::shared_ptr<op_t> &scale_op) {
     return true;
 }
 
+bool need_broadcast_for_inputs(
+        const std::shared_ptr<op_t> &op, size_t index1, size_t index2) {
+    auto in_vals = op->get_input_values();
+
+    const dims input1_dims
+            = logical_tensor_wrapper_t(in_vals[index1]->get_logical_tensor())
+                      .vdims();
+    const dims input2_dims
+            = logical_tensor_wrapper_t(in_vals[index2]->get_logical_tensor())
+                      .vdims();
+
+    if (input1_dims != input2_dims) { return true; }
+
+    return false;
+}
 } // namespace dnnl_impl
 } // namespace graph
 } // namespace impl
diff --git a/src/graph/backend/dnnl/passes/utils.hpp b/src/graph/backend/dnnl/passes/utils.hpp
index 912f2bc531b..6ab4536157d 100644
--- a/src/graph/backend/dnnl/passes/utils.hpp
+++ b/src/graph/backend/dnnl/passes/utils.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2024 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -348,6 +348,9 @@ std::shared_ptr<op_t> clone_mul_scales(const std::shared_ptr<op_t> &scale_op);
 // This function is used to inverse scales of a dnnl_mul_scales op
 bool inverse_mul_scales(std::shared_ptr<op_t> &scale_op);
 
+bool need_broadcast_for_inputs(
+        const std::shared_ptr<op_t> &op, size_t index1, size_t index2);
+
 } // namespace dnnl_impl
 } // namespace graph
 } // namespace impl
diff --git a/src/graph/interface/shape_infer.cpp b/src/graph/interface/shape_infer.cpp
index 8f1c8a3d94e..556eb631958 100644
--- a/src/graph/interface/shape_infer.cpp
+++ b/src/graph/interface/shape_infer.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2024 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -32,9 +32,6 @@ namespace dnnl {
 namespace impl {
 namespace graph {
 
-// utils function
-namespace {
-
 std::string dims2str(const dims &dims) {
     if (dims.empty()) return std::string("");
 
@@ -45,8 +42,6 @@ std::string dims2str(const dims &dims)
     return str;
 }
 
-} // namespace
-
 /// convert shape to ncx or oix
 dims canonicalize(const dims &shape, const std::string &format) {
     dims ret(shape);
diff --git a/src/graph/interface/shape_infer.hpp b/src/graph/interface/shape_infer.hpp
index 976e4c481ff..a9b72305cd3 100644
--- a/src/graph/interface/shape_infer.hpp
+++ b/src/graph/interface/shape_infer.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -74,6 +74,8 @@ status_t infer_auto_pad(const dim_t in_dim, const dim_t stride,
 /// TODO(xxx): 0-D broadcasting?
 status_t broadcast(const dims &lhs, const dims &rhs, dims &broadcasted);
 
+std::string dims2str(const dims &dims);
+
 status_t one_way_broadcast(const dims &lhs, const dims &rhs);
 
 /// This function assumes the size of all vectors are correct. Eg. size of
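For reference, the lowering in PATCH 1/2 maps the front-end Select op onto oneDNN's three-input binary primitive. The following minimal sketch (illustrative only, not part of the patch) exercises that primitive directly, assuming a oneDNN build that provides `algorithm::binary_select` and the three-input `binary::primitive_desc` constructor used in op_executable.cpp above; the shapes are arbitrary, and the condition tensor uses `s8`, matching what select_handler sets.

```cpp
#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;

    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    // Illustrative shapes; cond shares the full shape of src0 since the
    // third input of binary select does not support broadcast.
    const memory::dims shape = {1, 1, 4, 128};
    memory::desc data_md(
            shape, memory::data_type::f32, memory::format_tag::abcd);
    memory::desc cond_md(
            shape, memory::data_type::s8, memory::format_tag::abcd);

    // dst[i] = cond[i] ? src0[i] : src1[i]
    binary::primitive_desc pd(
            eng, algorithm::binary_select, data_md, data_md, cond_md, data_md);

    memory src0(data_md, eng), src1(data_md, eng), cond(cond_md, eng),
            dst(data_md, eng);

    // The condition tensor is bound to DNNL_ARG_SRC_2, mirroring
    // binary_executable_t::get_arg_indices in the patch.
    binary(pd).execute(strm,
            {{DNNL_ARG_SRC_0, src0}, {DNNL_ARG_SRC_1, src1},
                    {DNNL_ARG_SRC_2, cond}, {DNNL_ARG_DST, dst}});
    strm.wait();
    return 0;
}
```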
From 66e2b1f6ee1aea54f4cc2a7b72940f4218228d43 Mon Sep 17 00:00:00 2001
From: Jiexin-Zheng
Date: Fri, 10 Jan 2025 06:52:47 +0000
Subject: [PATCH 2/2] benchdnn: graph: add select broadcast cases

---
 tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all | 2 ++
 tests/benchdnn/inputs/graph/op/harness_bf16_all            | 2 ++
 tests/benchdnn/inputs/graph/op/harness_f16_all             | 2 ++
 tests/benchdnn/inputs/graph/op/harness_f32_all             | 1 +
 4 files changed, 7 insertions(+)

diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all
index 13d8e7ccd6d..6f1d9b9680c 100644
--- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all
+++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all
@@ -35,6 +35,7 @@
 --reset --dt=f32,bf16,f16 --in-shapes=3:20x16x384x64+4:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
 --reset --dt=f32,bf16,f16 --in-shapes=3:10x16x384x64+4:10x1x64x384+0:10x1x384x64+1:10x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
 --reset --dt=f32,bf16,f16 --in-shapes=4:56x12x128x64+5:56x12x64x128+0:56x12x128x64+1:56x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
+--reset --dt=f32,bf16,f16 --in-shapes=2:1x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
 --reset --dt=f32,bf16,f16 --in-shapes=0:56x8x1024x80+1:56x8x77x80+2:56x8x77x80 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json
 --reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=5:20x117x48x128+6:20x1x128x117+19:20x1x117x128 --case=complex_fusion/mha/MHA-starcoder-inf-fp32-bs1.json
 --reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=2514:32x16x512x64+2518:32x16x512x64+2543:32x1x512x512+2547:32x16x512x512+2525:32x16x512x64 --op-attrs=4837:shape:16384x1024 --case=complex_fusion/mha/MHA_forward-Bert_large-train-fp32-bs4.json
@@ -51,6 +52,7 @@
 --reset --expected-n-partitions=0 --in-shapes=4:4x32x32x128+3:4x32x128x33+0:4x32x33x128+1:4x1x32x33 --case=complex_fusion/mha/MHA-LLaMa-inf-int8-bs1.json
 --reset --in-shapes=4:20x16x384x64+3:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-int8-bs1.json
 --reset --in-shapes=5:56x12x128x64+4:56x12x64x128+0:56x12x128x64+1:56x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
+--reset --in-shapes=2:1x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
 --reset --expected-n-partitions=0 --in-shapes=4:20x117x48x128+3:20x1x128x117+0:20x1x117x128 --case=complex_fusion/mha/MHA-starcoder-inf-int8-bs1.json
 --reset --expected-n-partitions=0 --in-shapes=4:32x16x384x64+3:32x16x64x384+0:32x16x384x64+1:32x1x1x384 --case=complex_fusion/mha/dynamic_quantized_mha-Bert_large-inf-int8-bs1-fake.json
 --reset --in-shapes=4:20x16x384x64+3:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/sdpa-plain-wo-scale-int8-bs1.json
diff --git a/tests/benchdnn/inputs/graph/op/harness_bf16_all b/tests/benchdnn/inputs/graph/op/harness_bf16_all
index 7e4e8abc9aa..70847b2f2b9 100644
--- a/tests/benchdnn/inputs/graph/op/harness_bf16_all
+++ b/tests/benchdnn/inputs/graph/op/harness_bf16_all
@@ -153,6 +153,8 @@
 --reset --dt=bf16 --in-shapes=1:1x1x1x1 --case=op/f32/greaterequal.json
 --reset --dt=bf16 --in-shapes=1:1 --case=op/f32/greaterequal.json
 
+# select
+--reset --dt=bf16 --in-shapes=2:1x1x1x128 --case=op/f32/select.json
 # concat
 --reset --dt=bf16 --in-shapes=0:1x4096x14x14+1:1x4096x14x14 --case=op/f32/concat.json
 --reset --dt=bf16 --in-shapes=0:64x128x28x28+1:64x128x28x28 --op-attrs=0:axis:1 --case=op/f32/concat.json
diff --git a/tests/benchdnn/inputs/graph/op/harness_f16_all b/tests/benchdnn/inputs/graph/op/harness_f16_all
index ee77a8943d8..b6539efd726 100644
--- a/tests/benchdnn/inputs/graph/op/harness_f16_all
+++ b/tests/benchdnn/inputs/graph/op/harness_f16_all
@@ -153,6 +153,8 @@
 --reset --dt=f16 --in-shapes=1:1x1x1x1 --case=op/f32/greaterequal.json
 --reset --dt=f16 --in-shapes=1:1 --case=op/f32/greaterequal.json
 
+# select
+--reset --dt=f16 --in-shapes=2:1x1x1x128 --case=op/f32/select.json
 # concat
 --reset --dt=f16 --in-shapes=0:1x4096x14x14+1:1x4096x14x14 --case=op/f32/concat.json
 --reset --dt=f16 --in-shapes=0:64x128x28x28+1:64x128x28x28 --op-attrs=0:axis:1 --case=op/f32/concat.json
diff --git a/tests/benchdnn/inputs/graph/op/harness_f32_all b/tests/benchdnn/inputs/graph/op/harness_f32_all
index ff8781a57e1..da402ae12ff 100644
--- a/tests/benchdnn/inputs/graph/op/harness_f32_all
+++ b/tests/benchdnn/inputs/graph/op/harness_f32_all
@@ -948,6 +948,7 @@
 --reset --in-shapes=0:2x9x3x5x7*acdeb+1:2x9x2x8x12*acdeb --op-attrs=0:sizes:2x8x12*mode:linear --case=op/f32/interpolate_bwd.json
 --reset --in-shapes=0:2x9x3x8x6*acdeb+1:2x9x2x5x12*acdeb --op-attrs=0:sizes:2x5x12*mode:linear --case=op/f32/interpolate_bwd.json
 --reset --in-shapes=0:2x9x3x8x7*acdeb+1:2x9x2x5x12*acdeb --op-attrs=0:sizes:2x5x12*mode:linear --case=op/f32/interpolate_bwd.json
+--reset --in-shapes=2:1x1x1x128 --case=op/f32/select.json
 --reset --case=op/f32/select.json
 --reset --case=op/f32/gnorm.json
 --reset --case=op/f32/static_reshape.json
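The new cases can be replayed locally through benchdnn's graph driver. A representative invocation is sketched below; the binary location and relative paths depend on the build layout, so treat the exact paths as illustrative.

```sh
# Replay the whole f32 op harness, including the new select broadcast case:
./benchdnn --graph --batch=inputs/graph/op/harness_f32_all

# Or run the single new case directly:
./benchdnn --graph --reset --in-shapes=2:1x1x1x128 --case=op/f32/select.json
```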