
graph: backend: dnnl: support select with binary primitive #2349

Open · wants to merge 2 commits into base: main
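This PR lowers the graph Select op onto oneDNN's binary primitive, which accepts a third source tensor carrying the condition when the algorithm is binary_select. For orientation, a minimal sketch of that ternary form through the public C++ API (shapes and names are illustrative, and it assumes a oneDNN build recent enough to expose dnnl::algorithm::binary_select):

#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);

    // dst[i] = cond[i] ? src0[i] : src1[i]
    memory::desc src0_md({2, 3}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc src1_md({2, 3}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc cond_md({2, 3}, memory::data_type::s8, memory::format_tag::ab);
    memory::desc dst_md({2, 3}, memory::data_type::f32, memory::format_tag::ab);

    // The select algorithm takes the condition as a third source, mirroring
    // the primitive_desc overload used in create_desc() further down.
    auto pd = binary::primitive_desc(eng, algorithm::binary_select, src0_md,
            src1_md, cond_md, dst_md);
    (void)pd; // creation only; execution is sketched under op_executable.cpp
    return 0;
}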
1 change: 1 addition & 0 deletions src/graph/backend/dnnl/dnnl_op_def.hpp
@@ -701,6 +701,7 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_binary, 1,
.set_num_outputs(2)
.set_input(0, "a")
.set_input(1, "b")
.set_input(2, "cond")
.set_output(0, "output")
.set_output(1, "scratchpad")
// Attributes inherited from front binary ops (Add, Multiply,
66 changes: 57 additions & 9 deletions src/graph/backend/dnnl/dnnl_shape_infer.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
-* Copyright 2021-2024 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,9 +15,9 @@
*******************************************************************************/

#include <algorithm>
+#include <unordered_set>

#include "graph/interface/shape_infer.hpp"
#include "oneapi/dnnl/dnnl.hpp"
-#include <unordered_set>

#include "graph/backend/dnnl/dnnl_shape_infer.hpp"
#include "graph/backend/dnnl/internal_attrs.hpp"
@@ -484,17 +484,65 @@ status_t infer_dnnl_pool_bwd_output_shape(op_t *n,
return status::success;
}

status_t infer_binary_select_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs) {
auto in0 = logical_tensor_wrapper_t(inputs[0]);
auto in1 = logical_tensor_wrapper_t(inputs[1]);
auto in2 = logical_tensor_wrapper_t(inputs[2]);

const bool shapes_should_match = n->has_attr(op_attr::auto_broadcast)
? "none" == n->get_attr<std::string>(op_attr::auto_broadcast)
: false;

dims input0_dims = in0.vdims();
dims input1_dims = in1.vdims();
dims input2_dims = in2.vdims();
dims inferred_out_shape;

if (shapes_should_match) { // no broadcast
VCHECK_INVALID_SHAPE(
(input0_dims == input1_dims && input1_dims == input2_dims),
"%s, all input dims should match each other if there is no "
"broadcast. input0 dims: %s, input1 dims: %s, input2 dims: %s ",
op_t::kind2str(n->get_kind()).c_str(),
dims2str(input0_dims).c_str(), dims2str(input1_dims).c_str(),
dims2str(input2_dims).c_str());
inferred_out_shape = std::move(input0_dims);
} else { // can broadcast
status_t ret1 = broadcast(input0_dims, input1_dims, inferred_out_shape);
VCHECK_INVALID_SHAPE((ret1 == status::success),
"%s, failed to implement numpy broadcasting",
op_t::kind2str(n->get_kind()).c_str());
}

auto out0 = logical_tensor_wrapper_t(outputs[0]);
// check if given or partial set shape aligns with inferred shape
if (!out0.is_shape_unknown() || out0.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(inferred_out_shape, out0.vdims()),
"%s, inferred out shape and output shape are not compatible",
op_t::kind2str(n->get_kind()).c_str());
if (!out0.is_shape_unknown()) return status::success;
}

set_shape_and_strides(*outputs[0], inferred_out_shape);
return status::success;
}

status_t infer_dnnl_binary_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs) {
const bool is_bias_add = n->has_attr(op_attr::is_bias_add)
&& n->get_attr<bool>(op_attr::is_bias_add);

-auto ret = is_bias_add
-? infer_bias_add_output_shape(n, inputs, outputs)
-: infer_elemwise_arithmetic_output_shape(n, inputs, outputs);
-
-return ret;
+const algorithm algo = static_cast<dnnl::algorithm>(
+n->get_attr<int64_t>(op_attr::alg_kind));
+if (algo == algorithm::binary_select) {
+return infer_binary_select_output_shape(n, inputs, outputs);
+} else if (is_bias_add) {
+return infer_bias_add_output_shape(n, inputs, outputs);
+} else {
+return infer_elemwise_arithmetic_output_shape(n, inputs, outputs);
+}
}

} // namespace dnnl_impl
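Note that the broadcast path above combines only input0 and input1 when inferring the output shape; the condition's dims are presumably validated later by the primitive itself. For reference, a self-contained sketch of the numpy-style rule that broadcast() applies (numpy_broadcast is a hypothetical stand-in, not the backend's helper):

#include <algorithm>
#include <cstdint>
#include <vector>

using dims_t = std::vector<int64_t>;

// Right-align the two shapes; each aligned pair of sizes must be equal or
// contain a 1, which stretches to the larger size. Returns false on conflict.
bool numpy_broadcast(const dims_t &a, const dims_t &b, dims_t &out) {
    const size_t rank = std::max(a.size(), b.size());
    out.assign(rank, 1);
    for (size_t i = 0; i < rank; ++i) {
        const int64_t da = i < rank - a.size() ? 1 : a[i - (rank - a.size())];
        const int64_t db = i < rank - b.size() ? 1 : b[i - (rank - b.size())];
        if (da != db && da != 1 && db != 1) return false;
        out[i] = std::max(da, db);
    }
    return true;
}

// e.g. {8, 1, 16} and {4, 16} broadcast to {8, 4, 16};
// {8, 2, 16} and {4, 16} fail because 2 conflicts with 4.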
6 changes: 5 additions & 1 deletion src/graph/backend/dnnl/dnnl_shape_infer.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
-* Copyright 2021-2023 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -103,6 +103,10 @@ status_t infer_dnnl_binary_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

status_t infer_binary_select_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
2 changes: 2 additions & 0 deletions src/graph/backend/dnnl/kernels/large_partition.cpp
@@ -36,6 +36,8 @@ void larger_partition_kernel_t::setup_pipeline_stage1(
pass_pipeline_t &pipeline) {
// Directly lower down (1 to 1 mapping)
BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
// Decompose select to binary ops if necessary
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_binary_ops);

// Indirectly lower down (N to 1 mapping)
BACKEND_DNNL_ADD_PASS(pipeline, fuse_reciprocal_mul_to_div);
3 changes: 3 additions & 0 deletions src/graph/backend/dnnl/kernels/matmul.cpp
@@ -50,6 +50,9 @@ status_t matmul_t<quantized>::compile_impl(const dnnl_partition_impl_t *part,
pass_pipeline_t pipeline(vis);

BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
// Decompose select to binary ops if necessary
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_binary_ops);

BACKEND_DNNL_ADD_PASS(pipeline, fuse_bias_add);
// check if bias exists
BACKEND_DNNL_ADD_PASS(pipeline, check_with_bias);
4 changes: 3 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_decomp.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
-* Copyright 2024 Intel Corporation
+* Copyright 2024-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -60,6 +60,8 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
pass_pipeline_t pipeline = pass_pipeline_t(vis);
pass_pipeline_t select_pipeline = pass_pipeline_t(vis);
BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
// Decompose select to binary ops if necessary
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_binary_ops);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa);
// Fusion and canonicalization passes begin
if (quantized) {
19 changes: 18 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
-* Copyright 2024 Intel Corporation
+* Copyright 2024-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -58,6 +58,23 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
"Only supports single scale value, but got %lld", scale_sz);
}

// Check select cond and src0 shape
if (graph_inport[5] != -1 && graph_inport[6] != -1) {
const auto select_cond_dims = ltw(inputs[graph_inport[5]]).vdims();
const auto select_src0_dims = ltw(inputs[graph_inport[6]]).vdims();
VCHECK_SDP_DECOMP(select_cond_dims.size() == select_src0_dims.size(),
false,
"Select cond and src0 dims should be same, but got %zu and %zu",
select_cond_dims.size(), select_src0_dims.size());
for (size_t i = 0; i < select_cond_dims.size(); i++) {
VCHECK_SDP_DECOMP(select_cond_dims[i] == select_src0_dims[i], false,
"Select cond and src0 dims should be the same, but got %lld "
"and %lld",
select_cond_dims[i], select_src0_dims[i]);
}
}

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP
// RATIO is an empirical value used to determine the numerical relationship
// between batch_size, num_head_q and thread number to determine whether to use
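The new check requires an exact rank-and-size match between the select condition and src0 before the decomposed SDP path is taken; comparing axis by axis (rather than via operator== on the vectors) lets each error message report the offending sizes. A stand-alone equivalent of the loop above (sketch only):

#include <cstdint>
#include <vector>

// Exact match: same rank and the same size on every axis; no broadcasting.
bool dims_match_exactly(const std::vector<int64_t> &cond,
        const std::vector<int64_t> &src0) {
    if (cond.size() != src0.size()) return false;
    for (size_t i = 0; i < cond.size(); ++i)
        if (cond[i] != src0[i]) return false;
    return true;
}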
3 changes: 3 additions & 0 deletions src/graph/backend/dnnl/kernels/select.cpp
@@ -49,6 +49,9 @@ status_t select_t::compile_impl(const dnnl_partition_impl_t *part,
pass_pipeline_t pipeline(vis);

BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
// Decompose select to binary ops if necessary
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_binary_ops);

BACKEND_DNNL_ADD_PASS(pipeline, binary_canonicalization);

BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_ops);
17 changes: 14 additions & 3 deletions src/graph/backend/dnnl/op_executable.cpp
@@ -1252,8 +1252,15 @@ binary_executable_t::desc_t binary_executable_t::create_desc(
op->get_attr<int64_t>(op_attr::alg_kind));

dnnl::binary::primitive_desc pd;
-pd = dnnl::binary::primitive_desc(
-p_engine, algo, src0, src1, dst, prm_attr);
+if (algo == algorithm::binary_select) {
+auto src2 = make_dnnl_memory_desc(
+op->get_input_value(2)->get_logical_tensor());
+pd = dnnl::binary::primitive_desc(
+p_engine, algo, src0, src1, src2, dst, prm_attr);
+} else {
+pd = dnnl::binary::primitive_desc(
+p_engine, algo, src0, src1, dst, prm_attr);
+}

pd_cache.insert({op.get(), pd});

@@ -1891,12 +1898,16 @@ arg_indices_t matmul_executable_t::get_arg_indices(
arg_indices_t binary_executable_t::get_arg_indices(
const op_t *op, fusion_info_mgr_t &mgr) {
arg_indices_t arg_indices;
const algorithm algo = static_cast<dnnl::algorithm>(
op->get_attr<int64_t>(op_attr::alg_kind));

// add input args
size_t index = 0;
arg_indices.insert({DNNL_ARG_SRC_0, indices_t {input, index++}});
arg_indices.insert({DNNL_ARG_SRC_1, indices_t {input, index++}});

if (algo == algorithm::binary_select) {
arg_indices.insert({DNNL_ARG_SRC_2, indices_t {input, index++}});
}
get_arg_indices_for_post_ops(op, mgr, arg_indices, index);

// add output args
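At execution time the condition follows the same convention and is bound under DNNL_ARG_SRC_2, matching the index mapping added above. A usage sketch (run_select is a hypothetical helper; the primitive descriptor and memories are assumed to be created as in the earlier example):

#include "oneapi/dnnl/dnnl.hpp"

// Executes a binary-select primitive; the condition travels as SRC_2.
void run_select(dnnl::stream &strm, const dnnl::binary::primitive_desc &pd,
        const dnnl::memory &src0, const dnnl::memory &src1,
        const dnnl::memory &cond, const dnnl::memory &dst) {
    dnnl::binary prim(pd);
    prim.execute(strm,
            {{DNNL_ARG_SRC_0, src0}, {DNNL_ARG_SRC_1, src1},
                    {DNNL_ARG_SRC_2, cond}, {DNNL_ARG_DST, dst}});
    strm.wait();
}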
123 changes: 21 additions & 102 deletions src/graph/backend/dnnl/passes/lower.cpp
@@ -664,114 +664,33 @@ static status_t select_handler(
auto cond = in_vals[0];
auto src0 = in_vals[1];
auto src1 = in_vals[2];
-cond->set_data_type(dnnl::impl::data_type::u8);
-
-//TODO: This reorder can be removed once eltwise_clip support int8 input
-op_ptr type_cast = std::make_shared<op_t>(op_kind::dnnl_reorder);
-type_cast->set_attr<bool>(op_attr::change_layout, false);
-
-op_ptr clip = std::make_shared<op_t>(op_kind::dnnl_eltwise);
-clip->set_attr<int64_t>(op_attr::alg_kind,
-static_cast<int64_t>(dnnl::algorithm::eltwise_clip));
-clip->set_attr<float>(op_attr::alpha, 0.f);
-clip->set_attr<float>(op_attr::beta, 1.f);
-
-// After reorder and clip. The cond value is 0 or 1.
-// Then output = src0.*cond+src1.*(cond*-1 + 1)
-op_ptr mul1 = std::make_shared<op_t>(op_kind::dnnl_binary);
-mul1->set_attr<int64_t>(op_attr::alg_kind,
-static_cast<int64_t>(dnnl::algorithm::binary_mul));
-mul1->merge_attributes(op->get_attributes());
-
-op_ptr mul2 = std::make_shared<op_t>(op_kind::dnnl_binary);
-mul2->set_attr<int64_t>(op_attr::alg_kind,
-static_cast<int64_t>(dnnl::algorithm::binary_mul));
-mul2->merge_attributes(op->get_attributes());
-
-op_ptr linear = std::make_shared<op_t>(op_kind::dnnl_eltwise);
-linear->set_attr<int64_t>(op_attr::alg_kind,
-static_cast<int64_t>(dnnl::algorithm::eltwise_linear));
-const float alpha_value = -1.0f, beta_value = 1.0f;
-linear->set_attr<float>(op_attr::alpha, alpha_value);
-linear->set_attr<float>(op_attr::beta, beta_value);
-
-op_ptr add = std::make_shared<op_t>(op_kind::dnnl_binary);
-add->set_attr<int64_t>(op_attr::alg_kind,
-static_cast<int64_t>(dnnl::algorithm::binary_add));
+// For the binary select operation, the conditional input tensor can
+// only be of `s8` data type.
+cond->set_data_type(dnnl::impl::data_type::s8);
+
+op_ptr new_op = std::make_shared<op_t>(op_kind::dnnl_binary);
+new_op->set_attr<int64_t>(op_attr::alg_kind,
+static_cast<int64_t>(get_binary_alg_map().at(op->get_kind())));
+new_op->merge_attributes(op->get_attributes());

// reconnect
cond->remove_consumer(*op, 0);
src0->remove_consumer(*op, 1);
src1->remove_consumer(*op, 2);

-// first reorder and clip
-cond->add_consumer(*type_cast, 0);
-type_cast->add_input(cond);
-logical_tensor_t float_cond = empty_logical_tensor_with_default_id();
-auto float_cond_val
-= std::make_shared<value_t>(*type_cast, 0, float_cond, true);
-float_cond_val->set_data_type(dnnl::impl::data_type::f32);
-type_cast->add_output(float_cond_val);
-insert_empty_scratchpad(type_cast);
-
-float_cond_val->add_consumer(*clip, 0);
-clip->add_input(float_cond_val);
-logical_tensor_t clip_cond = empty_logical_tensor_with_default_id();
-auto clip_cond_val = std::make_shared<value_t>(*clip, 0, clip_cond, true);
-clip_cond_val->set_data_type(
-float_cond_val->get_logical_tensor().data_type);
-clip->add_output(clip_cond_val);
-insert_empty_scratchpad(clip);
-
-// first multiply
-src0->add_consumer(*mul1, 0);
-clip_cond_val->add_consumer(*mul1, 1);
-mul1->add_input(src0);
-mul1->add_input(clip_cond_val);
-
-logical_tensor_t src0_cond = empty_logical_tensor_with_default_id();
-auto src0_val = std::make_shared<value_t>(*mul1, 0, src0_cond, true);
-src0_val->set_data_type(src0->get_logical_tensor().data_type);
-mul1->add_output(src0_val);
-insert_empty_scratchpad(mul1);
-
-//cond.*{-1} + 1
-clip_cond_val->add_consumer(*linear, 0);
-linear->add_input(clip_cond_val);
-
-logical_tensor_t cond_inv = empty_logical_tensor_with_default_id();
-auto cond_inv_val = std::make_shared<value_t>(*linear, 0, cond_inv, true);
-cond_inv_val->set_data_type(clip_cond_val->get_logical_tensor().data_type);
-linear->add_output(cond_inv_val);
-insert_empty_scratchpad(linear);
-
-//src1.*(cond_inv)
-
-src1->add_consumer(*mul2, 0);
-cond_inv_val->add_consumer(*mul2, 1);
-mul2->add_input(src1);
-mul2->add_input(cond_inv_val);
-
-logical_tensor_t src1_cond = empty_logical_tensor_with_default_id();
-auto src1_val = std::make_shared<value_t>(*mul2, 0, src1_cond, true);
-src1_val->set_data_type(src1->get_logical_tensor().data_type);
-mul2->add_output(src1_val);
-insert_empty_scratchpad(mul2);
-
-src0_val->add_consumer(*add, 0);
-src1_val->add_consumer(*add, 1);
-add->add_input(src0_val);
-add->add_input(src1_val);
-add->add_output(out_vals[0]);
-insert_empty_scratchpad(add);
-
-// add new ops and delete select op
-rewriter.to_insert(type_cast);
-rewriter.to_insert(clip);
-rewriter.to_insert(mul1);
-rewriter.to_insert(linear);
-rewriter.to_insert(mul2);
-rewriter.to_insert(add);
+// binary select primitive places the condition input tensor as the
+// third input tensor.
+src0->add_consumer(*new_op, 0);
+src1->add_consumer(*new_op, 1);
+cond->add_consumer(*new_op, 2);
+
+new_op->add_input(src0);
+new_op->add_input(src1);
+new_op->add_input(cond);
+new_op->add_output(out_vals[0]);
+
+insert_empty_scratchpad(new_op);
+rewriter.to_insert(new_op);
rewriter.to_remove(op);

return status::success;
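The deleted handler emulated select with five ops: a reorder of the condition to f32, a clip to {0, 1}, then output = src0 * cond + src1 * (cond * -1 + 1) built from two multiplies, an eltwise_linear, and an add. The replacement emits a single dnnl_binary whose algorithm comes from get_binary_alg_map(), with the condition as the third input; judging by the "if necessary" comments in the kernel pipelines, cases the primitive cannot cover are presumably still decomposed by the new decompose_select_to_binary_ops pass. A scalar reference check that the two formulations agree (illustrative only):

#include <cassert>

// Direct select semantics: dst = cond ? src0 : src1.
float select_direct(signed char cond, float src0, float src1) {
    return cond ? src0 : src1;
}

// What the removed graph computed once cond was clipped to {0, 1}:
// src0 * cond + src1 * (cond * -1 + 1).
float select_decomposed(float cond01, float src0, float src1) {
    return src0 * cond01 + src1 * (cond01 * -1.0f + 1.0f);
}

int main() {
    assert(select_direct(1, 2.0f, 3.0f) == select_decomposed(1.0f, 2.0f, 3.0f));
    assert(select_direct(0, 2.0f, 3.0f) == select_decomposed(0.0f, 2.0f, 3.0f));
    return 0;
}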