graph: backend: dnnl: support select with binary primitive #2349
Do you have any performance data to share?
@@ -2266,7 +2266,8 @@ status_t binary_canonicalization(std::shared_ptr<subgraph_t> &sg) {
 int32_t src1_ndims = src1_lt.ndims;
 int32_t target_ndims = std::max(src0_ndims, src1_ndims);
 std::vector<int32_t> in_ndims {src0_ndims, src1_ndims};
-for (size_t i = 0; i < cur_op->num_inputs(); ++i) {
+std::vector<size_t> input_indices = {0, 1};
+for (auto i : input_indices) {
Is this correct? Previously `num_inputs()` is 2 - 32 per the schema definition. Now the code only handles the first two?
This pass is applied before the post-op fusion pass, so the number of inputs was always 2 before. For this PR, although binary select has three inputs, the `cond` dims are guaranteed to be the same as those of `src0` by the pass `decompose_select_to_binary_ops`, so we only need to unsqueeze `src0` and `src1`.
If the `cond` dims are guaranteed to be the same as those of `src0`, then it should fall into the condition `if (in_ndims[i] == target_ndims) { continue; }`, so no unsqueeze is inserted. If that is the case, is there any need to limit the `input_indices`?
`in_ndims` only has two elements, so accessing a third element is not legal.
ok, then it seems the original code is designed for 2 elements
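To make the point of this thread concrete, here is a minimal standalone sketch of the canonicalization loop. It is not the code from the PR: `dims_t`, `unsqueeze_leading`, and `canonicalize_binary_inputs` are hypothetical names, and it assumes that unsqueezing means left-padding the shape with 1s up to `target_ndims`. Because `in_ndims` holds exactly two entries, the loop visits indices 0 and 1 only, and a third input such as `cond` is passed through unchanged.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

using dims_t = std::vector<int64_t>; // hypothetical shape type for this sketch

// Hypothetical helper: left-pad `dims` with 1s until it has `target_ndims`
// entries (assumed unsqueeze behavior, not taken from the PR).
dims_t unsqueeze_leading(const dims_t &dims, int32_t target_ndims) {
    dims_t out(dims);
    const size_t target = static_cast<size_t>(target_ndims);
    if (out.size() < target)
        out.insert(out.begin(), target - out.size(), int64_t {1});
    return out;
}

// Canonicalize only src0 and src1. `inputs` must hold at least two shapes;
// any further input (for example cond) is copied through untouched.
std::vector<dims_t> canonicalize_binary_inputs(
        const std::vector<dims_t> &inputs) {
    const int32_t src0_ndims = static_cast<int32_t>(inputs[0].size());
    const int32_t src1_ndims = static_cast<int32_t>(inputs[1].size());
    const int32_t target_ndims = std::max(src0_ndims, src1_ndims);
    // Two elements only: indexing in_ndims[2] would be out of bounds.
    const std::vector<int32_t> in_ndims {src0_ndims, src1_ndims};

    std::vector<dims_t> out(inputs);
    const std::vector<size_t> input_indices = {0, 1};
    for (auto i : input_indices) {
        if (in_ndims[i] == target_ndims) continue; // already canonical
        out[i] = unsqueeze_leading(inputs[i], target_ndims);
    }
    return out;
}
```

Looping up to the number of inputs instead would read `in_ndims[2]` for a three-input op, which is the out-of-bounds access pointed out above.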
@@ -32,9 +32,6 @@ namespace dnnl {
 namespace impl {
 namespace graph {

-// utils function
-namespace {
-
why remove it?
- We have many utils functions in `shape_infer.cpp`; they are defined in `shape_infer.cpp` and declared in `shape_infer.hpp`. This unnamed namespace only contains the function `dims2str` (it is declared and defined in this namespace and used in `shape_infer.cpp` only).
- `dims2str` can be used, and has been used for this PR, in other files, so it should be in the same namespace as the other utils functions, that is, `dnnl::impl::graph`.
where is this function used?
It's used in the shape inference process of the binary select op:
dims2str(input0_dims).c_str(), dims2str(input1_dims).c_str(),
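For readers without the source at hand, the sketch below shows roughly how a helper like `dims2str` can be defined in the named namespace and then used in a formatted message through `.c_str()`. The signature, the `x`-separated output format, and the `shape_mismatch_message` function with its message text are all assumptions made for illustration; the real helper in `shape_infer.cpp` may differ.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

namespace dnnl {
namespace impl {
namespace graph {

// Assumed signature and format: joins dims with 'x', e.g. {2, 3} -> "2x3".
// Defined in the named namespace so other translation units can use it once
// it is declared in a header.
std::string dims2str(const std::vector<int64_t> &dims) {
    std::string s;
    for (size_t i = 0; i < dims.size(); ++i) {
        if (i > 0) s += "x";
        s += std::to_string(dims[i]);
    }
    return s;
}

// Hypothetical caller mirroring the snippet above: formats both input shapes
// into one diagnostic string.
std::string shape_mismatch_message(const std::vector<int64_t> &input0_dims,
        const std::vector<int64_t> &input1_dims) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), "incompatible shapes: %s vs %s",
            dims2str(input0_dims).c_str(), dims2str(input1_dims).c_str());
    return std::string(buf);
}

} // namespace graph
} // namespace impl
} // namespace dnnl
```

Had `dims2str` stayed inside an unnamed namespace, it would have internal linkage and could not be called from another source file, which is the reason given above for moving it.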
Sure, I have attached it to the PR description.
Description
A `cond` input is defined for the dnnl binary op. For the `cond` input, we use the binary select primitive for the non-broadcast case only. The lowering logic is: always lower select to the binary primitive, then decide which impl path to use in pass `decompose_select_to_multiple_binary_ops`, and decompose it to multiple binary ops if necessary.
Performance
relative perf:
platform: Intel(R) Xeon(R) Platinum 8490H