Skip to content

Commit

Permalink
graph: backend: dnnl: support select pattern and op with binary primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiexin-Zheng committed Jan 7, 2025
1 parent 64b4c34 commit 4ea4e67
Show file tree
Hide file tree
Showing 14 changed files with 1,576 additions and 154 deletions.
2 changes: 2 additions & 0 deletions src/graph/backend/dnnl/dnnl_op_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_binary, 1,
.set_num_outputs(2)
.set_input(0, "a")
.set_input(1, "b")
.set_input(2, "cond")
.set_output(0, "output")
.set_output(1, "scratchpad")
// Attributes inherited from front binary ops (Add, Multiply,
Expand All @@ -713,6 +714,7 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_binary, 1,
{"NXC", "NCX"})
// New added attributes
.set_attr(op_attr::is_bias_add, false, attribute_kind::b, false)
.set_attr(op_attr::is_select, false, attribute_kind::b, false)
.set_attr(op_attr::fusion_info_key, false, attribute_kind::i,
(int64_t)-1)
.set_attr(op_attr::alg_kind, true, attribute_kind::i)
Expand Down
15 changes: 9 additions & 6 deletions src/graph/backend/dnnl/dnnl_shape_infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,12 +489,15 @@ status_t infer_dnnl_binary_output_shape(op_t *n,
std::vector<logical_tensor_t *> &outputs) {
const bool is_bias_add = n->has_attr(op_attr::is_bias_add)
&& n->get_attr<bool>(op_attr::is_bias_add);

auto ret = is_bias_add
? infer_bias_add_output_shape(n, inputs, outputs)
: infer_elemwise_arithmetic_output_shape(n, inputs, outputs);

return ret;
const bool is_select = n->has_attr(op_attr::is_select)
&& n->get_attr<bool>(op_attr::is_select);
if (is_select) {
return infer_select_output_shape(n, inputs, outputs);
} else if (is_bias_add) {
return infer_bias_add_output_shape(n, inputs, outputs);
} else {
return infer_elemwise_arithmetic_output_shape(n, inputs, outputs);
}
}

} // namespace dnnl_impl
Expand Down
1 change: 1 addition & 0 deletions src/graph/backend/dnnl/internal_attrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const op_attr_t with_runtime_dst_zps = 0x1000c;
const op_attr_t is_bias_add = 0x1000d;
const op_attr_t with_sum = 0x1000e;
const op_attr_t keep_dst_layout = 0x1000f;
const op_attr_t is_select = 0x10010;

// int64_t
const op_attr_t alg_kind = 0x10100;
Expand Down
3 changes: 2 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_decomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
pass_pipeline_t pipeline = pass_pipeline_t(vis);
pass_pipeline_t select_pipeline = pass_pipeline_t(vis);
BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_multiple_binary_ops);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa);
// Fusion and canonicalization passes begin
if (quantized) {
Expand Down Expand Up @@ -391,4 +392,4 @@ template struct sdp_decomp_kernel_t<true, dnnl::memory::data_type::f32>;
} // namespace dnnl_impl
} // namespace graph
} // namespace impl
} // namespace dnnl
} // namespace dnnl
19 changes: 16 additions & 3 deletions src/graph/backend/dnnl/op_executable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,8 +1249,16 @@ binary_executable_t::desc_t binary_executable_t::create_desc(
op->get_attr<int64_t>(op_attr::alg_kind));

dnnl::binary::primitive_desc pd;
pd = dnnl::binary::primitive_desc(
p_engine, algo, src0, src1, dst, prm_attr);
if (op->has_attr(op_attr::is_select)
&& op->get_attr<bool>(op_attr::is_select)) {
auto src2 = make_dnnl_memory_desc(
op->get_input_value(2)->get_logical_tensor());
pd = dnnl::binary::primitive_desc(
p_engine, algo, src0, src1, src2, dst, prm_attr);
} else {
pd = dnnl::binary::primitive_desc(
p_engine, algo, src0, src1, dst, prm_attr);
}

pd_cache.insert({op.get(), pd});

Expand Down Expand Up @@ -1874,12 +1882,17 @@ arg_indices_t matmul_executable_t::get_arg_indices(
arg_indices_t binary_executable_t::get_arg_indices(
const op_t *op, fusion_info_mgr_t &mgr) {
arg_indices_t arg_indices;
const bool is_select = op->has_attr(op_attr::is_select)
? op->get_attr<bool>(op_attr::is_select)
: false;

// add input args
size_t index = 0;
arg_indices.insert({DNNL_ARG_SRC_0, indices_t {input, index++}});
arg_indices.insert({DNNL_ARG_SRC_1, indices_t {input, index++}});

if (is_select) {
arg_indices.insert({DNNL_ARG_SRC_2, indices_t {input, index++}});
}
get_arg_indices_for_post_ops(op, mgr, arg_indices, index);

// add output args
Expand Down
Loading

0 comments on commit 4ea4e67

Please sign in to comment.