generic:sycl: Inner Product Backward #2360

Open · wants to merge 4 commits into base: main
7 changes: 4 additions & 3 deletions src/gpu/generic/sycl/README.md
@@ -96,11 +96,12 @@ The implementation supports both forward and backward directions.

## Inner Product

The implementation supports the forward direction only.
The implementation supports both forward and backward directions.

* Supported formats: All plain formats are supported.
* Supported data types: All possible data combinations listed in the oneDNN specification are supported.
* Supported post-ops: All the post operations as mentioned in the specification are supported.
* Supported data types: All possible data combinations as listed in the specification are supported.
* Supported post-ops: All the post-ops as mentioned in the specification are supported.
Contributor (suggested change):

Current: * Supported post-ops: All the post-ops as mentioned in the specification are supported.
Suggested: * Supported post-ops: All post-ops mentioned in the specification are supported.

* The backward pass does not support post-ops. One should not use post-ops in the forward pass during training
Contributor: I think this sentence could be included in a note format as opposed to listing this as a bullet.

Suggested change:

Current: * The backward pass does not support post-ops. One should not use post-ops in the forward pass during training
Suggested: Note: The backward pass does not support post-ops. You should not use post-ops in the forward pass during training.

Contributor: Alternatively, "post-ops should not be used in the forward pass during training".


## Layer Normalization

135 changes: 125 additions & 10 deletions src/gpu/generic/sycl/ref_inner_product.cpp
@@ -16,27 +16,69 @@
*******************************************************************************/

#include "gpu/generic/sycl/ref_inner_product.hpp"
#include "common/primitive_desc_iterator.hpp"

namespace dnnl::impl::gpu::generic::sycl {

status_t ref_inner_product_fwd_t::pd_t::init_matmul(impl::engine_t *engine) {
namespace detail {
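// Creates a matmul primitive descriptor for the given memory descriptors and
// attributes by iterating over the available matmul implementations.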
status_t init_matmul_pd(impl::engine_t *engine,
const primitive_attr_t *attributes, const memory_desc_t *src_desc,
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
const memory_desc_t *dst_desc,
std::shared_ptr<primitive_desc_t> &matmul_pd) {

matmul_desc_t matmul_desc;
CHECK(matmul_desc_init(&matmul_desc, &src_md_reshaped, &weights_md_reshaped,
&bias_md_reshaped, arg_md(DNNL_ARG_DST)));
primitive_attr_t matmul_attr(*attr());
CHECK(matmul_desc_init(
&matmul_desc, src_desc, weights_desc, bias_desc, dst_desc));

primitive_desc_iterator_t it(engine,
reinterpret_cast<op_desc_t *>(&matmul_desc), &matmul_attr, nullptr);
if (!it.is_initialized()) return status::out_of_memory;
reinterpret_cast<op_desc_t *>(&matmul_desc), attributes, nullptr);

if (!it.is_initialized()) return status::invalid_arguments;
while (++it != it.end()) {
matmul_pd = *it;
if (matmul_pd) { break; }
if (*it) {
matmul_pd = *it;
break;
}
}
return status::success;
}

status_t flatten_md(const memory_desc_t &desc, memory_desc_t &flattened_md,
format_tag_t format_tag) {
// Always flattens from an nD to a 2D memory layout, with batch as the first dimension
assert(format_tag == format_tag::ab || format_tag == format_tag::ba);
const auto &dimensions = desc.dims;
int64_t flattened_dimension = 1;
for (int i = 1; i < desc.ndims; i++) {
flattened_dimension *= dimensions[i];
}
dims_t reshaped_dims;
if (format_tag == format_tag::ab) {
reshaped_dims[0] = desc.dims[0];
reshaped_dims[1] = flattened_dimension;
} else {
reshaped_dims[0] = flattened_dimension;
reshaped_dims[1] = desc.dims[0];
}
CHECK(memory_desc_init_by_tag(
flattened_md, 2, reshaped_dims, desc.data_type, format_tag));
return status::success;
}
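As a standalone illustration of the flattening arithmetic above (the 32 x 16 x 3 x 3 shape is an assumed example, not taken from the patch):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // A hypothetical 4D inner-product source: N x IC x IH x IW.
    const int64_t dims[4] = {32, 16, 3, 3};
    const int ndims = 4;

    // flatten_md keeps dims[0] (the batch) and folds the remaining dimensions
    // into one, so the matmul operand becomes 32 x 144 for format_tag::ab
    // (or 144 x 32 for format_tag::ba).
    int64_t flattened = 1;
    for (int i = 1; i < ndims; ++i)
        flattened *= dims[i];

    assert(flattened == 16 * 3 * 3); // 144
    return 0;
}
```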

// Returns the dimension indices sorted by increasing stride (innermost dimension first).
std::vector<int> get_dim_order(int ndims, const dims_t strides) {
std::vector<int> order(ndims);
for (int i = 0; i < ndims; ++i) {
order[i] = i;
}

std::sort(order.begin(), order.end(),
[&strides](size_t i, size_t j) { return strides[i] < strides[j]; });

return order;
}

} // namespace detail

status_t ref_inner_product_fwd_t::init(impl::engine_t *engine) {
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> p;
CHECK(pd()->matmul_pd->create_primitive_nested(p, engine));
@@ -52,4 +94,77 @@ status_t ref_inner_product_fwd_t::execute(const exec_ctx_t &ctx) const {
return matmul_primitive->execute(copied_ctx);
}

status_t ref_inner_product_bwd_data_t::init(impl::engine_t *engine) {
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> p;
CHECK(pd()->matmul_pd->create_primitive_nested(p, engine));
matmul_primitive = p.first;
return status::success;
}

status_t ref_inner_product_bwd_data_t::execute(const exec_ctx_t &ctx) const {
nested_scratchpad_t nested_scratchpad(
ctx, memory_tracking::names::key_nested, matmul_primitive);

exec_args_t args_copy(ctx.args());
// Map src and dst to diff_dst and diff_src respectively
args_copy[DNNL_ARG_SRC] = args_copy[DNNL_ARG_DIFF_DST];
args_copy[DNNL_ARG_DST] = args_copy[DNNL_ARG_DIFF_SRC];
exec_ctx_t copied_ctx(ctx.stream(), std::move(args_copy));

copied_ctx.set_scratchpad_grantor(nested_scratchpad.grantor());

return matmul_primitive->execute(copied_ctx);
}
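For reference, the argument remapping above works because the data gradient of an inner product is itself a matmul. With oneDNN's convention dst = src x W^T + b and W of shape (OC, IC), the identity is:

```latex
\nabla_{\mathrm{src}} = \nabla_{\mathrm{dst}} \cdot W,
\qquad (N \times IC) = (N \times OC)\,(OC \times IC)
```

Binding the matmul's SRC to diff_dst and its DST to diff_src therefore yields diff_src directly; the 2D reshaping of the operands is presumably handled when the matmul pd is created (via the flatten_md helper above).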

status_t ref_inner_product_bwd_weights_t::init(impl::engine_t *engine) {
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> p;
CHECK(pd()->matmul_pd->create_primitive_nested(p, engine));
matmul_primitive = p.first;

if (pd()->with_bias()) {
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t>
p_reduction;
CHECK(pd()->reduction_pd->create_primitive_nested(p_reduction, engine));
reduction_primitive = p_reduction.first;
}

return status::success;
}

status_t ref_inner_product_bwd_weights_t::execute(const exec_ctx_t &ctx) const {
nested_scratchpad_t nested_scratchpad(
ctx, memory_tracking::names::key_nested_multiple, matmul_primitive);

exec_args_t args_copy(ctx.args());
// Remap the matmul arguments: SRC <- DIFF_DST, WEIGHTS <- the original SRC, DST <- DIFF_WEIGHTS
auto src_memory_arg = args_copy[DNNL_ARG_SRC];
args_copy[DNNL_ARG_SRC] = args_copy[DNNL_ARG_DIFF_DST];
args_copy[DNNL_ARG_WEIGHTS] = src_memory_arg;
args_copy[DNNL_ARG_DST] = args_copy[DNNL_ARG_DIFF_WEIGHTS];
exec_ctx_t copied_ctx(ctx.stream(), std::move(args_copy));

copied_ctx.set_scratchpad_grantor(nested_scratchpad.grantor());
// calculates dL/dW
CHECK(matmul_primitive->execute(copied_ctx));

if (pd()->with_bias()) {
// calculates dL/dB
nested_scratchpad_t reduction_scratchpad(ctx,
memory_tracking::names::key_nested_multiple + 1,
reduction_primitive);
exec_args_t args_copy_reduction(ctx.args());
args_copy_reduction[DNNL_ARG_SRC]
= args_copy_reduction[DNNL_ARG_DIFF_DST];
args_copy_reduction[DNNL_ARG_DST]
= args_copy_reduction[DNNL_ARG_DIFF_BIAS];
exec_ctx_t copied_ctx_reduction(
ctx.stream(), std::move(args_copy_reduction));

copied_ctx_reduction.set_scratchpad_grantor(
reduction_scratchpad.grantor());
CHECK(reduction_primitive->execute(copied_ctx_reduction));
}
return status::success;
}
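Similarly, the weights and bias gradients computed here follow the standard identities (same convention as above):

```latex
\nabla_{W} = \nabla_{\mathrm{dst}}^{\top} \cdot \mathrm{src},
\qquad
\nabla_{b} = \sum_{n} \nabla_{\mathrm{dst}}[n, :]
```

The matmul call produces diff_weights from diff_dst and the original src, and the nested reduction primitive then sums diff_dst over the batch dimension to produce diff_bias.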

} // namespace dnnl::impl::gpu::generic::sycl