Skip to content

Commit

Permalink
graph: backend: dnnl: support select with binary primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiexin-Zheng committed Jan 8, 2025
1 parent 72a9eb8 commit b94bafa
Show file tree
Hide file tree
Showing 25 changed files with 318 additions and 128 deletions.
2 changes: 2 additions & 0 deletions src/graph/backend/dnnl/dnnl_op_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_binary, 1,
.set_num_outputs(2)
.set_input(0, "a")
.set_input(1, "b")
.set_input(2, "cond")
.set_output(0, "output")
.set_output(1, "scratchpad")
// Attributes inherited from front binary ops (Add, Multiply,
Expand All @@ -713,6 +714,7 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_binary, 1,
{"NXC", "NCX"})
// New added attributes
.set_attr(op_attr::is_bias_add, false, attribute_kind::b, false)
.set_attr(op_attr::is_select, false, attribute_kind::b, false)
.set_attr(op_attr::fusion_info_key, false, attribute_kind::i,
(int64_t)-1)
.set_attr(op_attr::alg_kind, true, attribute_kind::i)
Expand Down
60 changes: 54 additions & 6 deletions src/graph/backend/dnnl/dnnl_shape_infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,17 +484,65 @@ status_t infer_dnnl_pool_bwd_output_shape(op_t *n,
return status::success;
}

status_t infer_binary_select_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs) {
auto in0 = logical_tensor_wrapper_t(inputs[0]);
auto in1 = logical_tensor_wrapper_t(inputs[1]);
auto in2 = logical_tensor_wrapper_t(inputs[2]);

const bool shapes_should_match = n->has_attr(op_attr::auto_broadcast)
? "none" == n->get_attr<std::string>(op_attr::auto_broadcast)
: false;

dims input0_dims = in0.vdims();
dims input1_dims = in1.vdims();
dims input2_dims = in2.vdims();
dims inferred_out_shape;

if (shapes_should_match) { // no broadcast
VCHECK_INVALID_SHAPE(
(input0_dims == input1_dims && input1_dims == input2_dims),
"%s, all input dims should match each other if there is no "
"broadcast. input0 dims: %s, input1 dims: %s, input2 dims: %s ",
op_t::kind2str(n->get_kind()).c_str(),
dims2str(input0_dims).c_str(), dims2str(input1_dims).c_str(),
dims2str(input2_dims).c_str());
inferred_out_shape = std::move(input0_dims);
} else { // can broadcast
status_t ret1 = broadcast(input0_dims, input1_dims, inferred_out_shape);
VCHECK_INVALID_SHAPE((ret1 == status::success),
"%s, failed to implement numpy broadcasting",
op_t::kind2str(n->get_kind()).c_str());
}

auto out0 = logical_tensor_wrapper_t(outputs[0]);
// check if given or partial set shape aligns with inferred shape
if (!out0.is_shape_unknown() || out0.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(inferred_out_shape, out0.vdims()),
"%s, inferred out shape and output shape are not compatible",
op_t::kind2str(n->get_kind()).c_str());
if (!out0.is_shape_unknown()) return status::success;
}

set_shape_and_strides(*outputs[0], inferred_out_shape);
return status::success;
}

status_t infer_dnnl_binary_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs) {
const bool is_bias_add = n->has_attr(op_attr::is_bias_add)
&& n->get_attr<bool>(op_attr::is_bias_add);

auto ret = is_bias_add
? infer_bias_add_output_shape(n, inputs, outputs)
: infer_elemwise_arithmetic_output_shape(n, inputs, outputs);

return ret;
const bool is_select = n->has_attr(op_attr::is_select)
&& n->get_attr<bool>(op_attr::is_select);
if (is_select) {
return infer_binary_select_output_shape(n, inputs, outputs);
} else if (is_bias_add) {
return infer_bias_add_output_shape(n, inputs, outputs);
} else {
return infer_elemwise_arithmetic_output_shape(n, inputs, outputs);
}
}

} // namespace dnnl_impl
Expand Down
4 changes: 4 additions & 0 deletions src/graph/backend/dnnl/dnnl_shape_infer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ status_t infer_dnnl_binary_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

status_t infer_binary_select_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
Expand Down
1 change: 1 addition & 0 deletions src/graph/backend/dnnl/internal_attrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const op_attr_t with_runtime_dst_zps = 0x1000c;
const op_attr_t is_bias_add = 0x1000d;
const op_attr_t with_sum = 0x1000e;
const op_attr_t keep_dst_layout = 0x1000f;
const op_attr_t is_select = 0x10010;

// int64_t
const op_attr_t alg_kind = 0x10100;
Expand Down
2 changes: 2 additions & 0 deletions src/graph/backend/dnnl/kernels/large_partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ void larger_partition_kernel_t::setup_pipeline_stage1(
pass_pipeline_t &pipeline) {
// Directly lower down (1 to 1 mapping)
BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
// Decompose select to multiple binary ops if necessary
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_multiple_binary_ops);

// Indirectly lower down (N to 1 mapping)
BACKEND_DNNL_ADD_PASS(pipeline, fuse_reciprocal_mul_to_div);
Expand Down
3 changes: 3 additions & 0 deletions src/graph/backend/dnnl/kernels/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ status_t matmul_t<quantized>::compile_impl(const dnnl_partition_impl_t *part,
pass_pipeline_t pipeline(vis);

BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
// Decompose select to multiple binary ops if necessary
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_multiple_binary_ops);

BACKEND_DNNL_ADD_PASS(pipeline, fuse_bias_add);
// check if bias exists
BACKEND_DNNL_ADD_PASS(pipeline, check_with_bias);
Expand Down
3 changes: 2 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_decomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
pass_pipeline_t pipeline = pass_pipeline_t(vis);
pass_pipeline_t select_pipeline = pass_pipeline_t(vis);
BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_multiple_binary_ops);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa);
// Fusion and canonicalization passes begin
if (quantized) {
Expand Down Expand Up @@ -391,4 +392,4 @@ template struct sdp_decomp_kernel_t<true, dnnl::memory::data_type::f32>;
} // namespace dnnl_impl
} // namespace graph
} // namespace impl
} // namespace dnnl
} // namespace dnnl
3 changes: 3 additions & 0 deletions src/graph/backend/dnnl/kernels/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ status_t select_t::compile_impl(const dnnl_partition_impl_t *part,
pass_pipeline_t pipeline(vis);

BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
// Decompose select to multiple binary ops if necessary
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_multiple_binary_ops);

BACKEND_DNNL_ADD_PASS(pipeline, binary_canonicalization);

BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_ops);
Expand Down
19 changes: 16 additions & 3 deletions src/graph/backend/dnnl/op_executable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,8 +1249,16 @@ binary_executable_t::desc_t binary_executable_t::create_desc(
op->get_attr<int64_t>(op_attr::alg_kind));

dnnl::binary::primitive_desc pd;
pd = dnnl::binary::primitive_desc(
p_engine, algo, src0, src1, dst, prm_attr);
if (op->has_attr(op_attr::is_select)
&& op->get_attr<bool>(op_attr::is_select)) {
auto src2 = make_dnnl_memory_desc(
op->get_input_value(2)->get_logical_tensor());
pd = dnnl::binary::primitive_desc(
p_engine, algo, src0, src1, src2, dst, prm_attr);
} else {
pd = dnnl::binary::primitive_desc(
p_engine, algo, src0, src1, dst, prm_attr);
}

pd_cache.insert({op.get(), pd});

Expand Down Expand Up @@ -1874,12 +1882,17 @@ arg_indices_t matmul_executable_t::get_arg_indices(
arg_indices_t binary_executable_t::get_arg_indices(
const op_t *op, fusion_info_mgr_t &mgr) {
arg_indices_t arg_indices;
const bool is_select = op->has_attr(op_attr::is_select)
? op->get_attr<bool>(op_attr::is_select)
: false;

// add input args
size_t index = 0;
arg_indices.insert({DNNL_ARG_SRC_0, indices_t {input, index++}});
arg_indices.insert({DNNL_ARG_SRC_1, indices_t {input, index++}});

if (is_select) {
arg_indices.insert({DNNL_ARG_SRC_2, indices_t {input, index++}});
}
get_arg_indices_for_post_ops(op, mgr, arg_indices, index);

// add output args
Expand Down
127 changes: 24 additions & 103 deletions src/graph/backend/dnnl/passes/lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,114 +658,34 @@ static status_t select_handler(
auto cond = in_vals[0];
auto src0 = in_vals[1];
auto src1 = in_vals[2];
cond->set_data_type(dnnl::impl::data_type::u8);

//TODO: This reorder can be removed once eltwise_clip support int8 input
op_ptr type_cast = std::make_shared<op_t>(op_kind::dnnl_reorder);
type_cast->set_attr<bool>(op_attr::change_layout, false);

op_ptr clip = std::make_shared<op_t>(op_kind::dnnl_eltwise);
clip->set_attr<int64_t>(op_attr::alg_kind,
static_cast<int64_t>(dnnl::algorithm::eltwise_clip));
clip->set_attr<float>(op_attr::alpha, 0.f);
clip->set_attr<float>(op_attr::beta, 1.f);

// After reorder and clip. The cond value is 0 or 1.
// Then output = src0.*cond+src1.*(cond*-1 + 1)
op_ptr mul1 = std::make_shared<op_t>(op_kind::dnnl_binary);
mul1->set_attr<int64_t>(op_attr::alg_kind,
static_cast<int64_t>(dnnl::algorithm::binary_mul));
mul1->merge_attributes(op->get_attributes());

op_ptr mul2 = std::make_shared<op_t>(op_kind::dnnl_binary);
mul2->set_attr<int64_t>(op_attr::alg_kind,
static_cast<int64_t>(dnnl::algorithm::binary_mul));
mul2->merge_attributes(op->get_attributes());

op_ptr linear = std::make_shared<op_t>(op_kind::dnnl_eltwise);
linear->set_attr<int64_t>(op_attr::alg_kind,
static_cast<int64_t>(dnnl::algorithm::eltwise_linear));
const float alpha_value = -1.0f, beta_value = 1.0f;
linear->set_attr<float>(op_attr::alpha, alpha_value);
linear->set_attr<float>(op_attr::beta, beta_value);

op_ptr add = std::make_shared<op_t>(op_kind::dnnl_binary);
add->set_attr<int64_t>(op_attr::alg_kind,
static_cast<int64_t>(dnnl::algorithm::binary_add));
// For the binary select operation, the conditional input tensor can
// only be of `s8` data type.
cond->set_data_type(dnnl::impl::data_type::s8);

op_ptr new_op = std::make_shared<op_t>(op_kind::dnnl_binary);
new_op->set_attr<bool>(op_attr::is_select, true);
new_op->set_attr<int64_t>(op_attr::alg_kind,
static_cast<int64_t>(get_binary_alg_map().at(op->get_kind())));
new_op->merge_attributes(op->get_attributes());

// reconnect
cond->remove_consumer(*op, 0);
src0->remove_consumer(*op, 1);
src1->remove_consumer(*op, 2);

// first reorder and clip
cond->add_consumer(*type_cast, 0);
type_cast->add_input(cond);
logical_tensor_t float_cond = empty_logical_tensor_with_default_id();
auto float_cond_val
= std::make_shared<value_t>(*type_cast, 0, float_cond, true);
float_cond_val->set_data_type(dnnl::impl::data_type::f32);
type_cast->add_output(float_cond_val);
insert_empty_scratchpad(type_cast);

float_cond_val->add_consumer(*clip, 0);
clip->add_input(float_cond_val);
logical_tensor_t clip_cond = empty_logical_tensor_with_default_id();
auto clip_cond_val = std::make_shared<value_t>(*clip, 0, clip_cond, true);
clip_cond_val->set_data_type(
float_cond_val->get_logical_tensor().data_type);
clip->add_output(clip_cond_val);
insert_empty_scratchpad(clip);

// first multiply
src0->add_consumer(*mul1, 0);
clip_cond_val->add_consumer(*mul1, 1);
mul1->add_input(src0);
mul1->add_input(clip_cond_val);

logical_tensor_t src0_cond = empty_logical_tensor_with_default_id();
auto src0_val = std::make_shared<value_t>(*mul1, 0, src0_cond, true);
src0_val->set_data_type(src0->get_logical_tensor().data_type);
mul1->add_output(src0_val);
insert_empty_scratchpad(mul1);

//cond.*{-1} + 1
clip_cond_val->add_consumer(*linear, 0);
linear->add_input(clip_cond_val);

logical_tensor_t cond_inv = empty_logical_tensor_with_default_id();
auto cond_inv_val = std::make_shared<value_t>(*linear, 0, cond_inv, true);
cond_inv_val->set_data_type(clip_cond_val->get_logical_tensor().data_type);
linear->add_output(cond_inv_val);
insert_empty_scratchpad(linear);

//src1.*(cond_inv)

src1->add_consumer(*mul2, 0);
cond_inv_val->add_consumer(*mul2, 1);
mul2->add_input(src1);
mul2->add_input(cond_inv_val);

logical_tensor_t src1_cond = empty_logical_tensor_with_default_id();
auto src1_val = std::make_shared<value_t>(*mul2, 0, src1_cond, true);
src1_val->set_data_type(src1->get_logical_tensor().data_type);
mul2->add_output(src1_val);
insert_empty_scratchpad(mul2);

src0_val->add_consumer(*add, 0);
src1_val->add_consumer(*add, 1);
add->add_input(src0_val);
add->add_input(src1_val);
add->add_output(out_vals[0]);
insert_empty_scratchpad(add);

// add new ops and delete select op
rewriter.to_insert(type_cast);
rewriter.to_insert(clip);
rewriter.to_insert(mul1);
rewriter.to_insert(linear);
rewriter.to_insert(mul2);
rewriter.to_insert(add);
// binary select primitive places the condition input tensor as the
// third input tensor.
src0->add_consumer(*new_op, 0);
src1->add_consumer(*new_op, 1);
cond->add_consumer(*new_op, 2);

new_op->add_input(src0);
new_op->add_input(src1);
new_op->add_input(cond);
new_op->add_output(out_vals[0]);

insert_empty_scratchpad(new_op);
rewriter.to_insert(new_op);
rewriter.to_remove(op);

return status::success;
Expand Down Expand Up @@ -895,7 +815,8 @@ status_t lower_down(std::shared_ptr<subgraph_t> &sg) {
auto kind = cur_op->get_kind();
if (!handler_table.count(kind)) {
assertm(false,
"All spec ops should be lowered to internal ops, except "
"All spec ops should be lowered to internal ops, "
"except "
"for some utility ops like End, Wildcard");
return status::invalid_graph_op;
}
Expand Down
Loading

0 comments on commit b94bafa

Please sign in to comment.