From fc994fd403bb1df53acdb62e5e9a287594462a71 Mon Sep 17 00:00:00 2001 From: Svetlozar Georgiev Date: Wed, 11 Dec 2024 16:24:28 +0000 Subject: [PATCH 1/3] gpu: generic: add simple SYCL reduction implementation --- src/gpu/generic/sycl/simple_reduction.cpp | 57 ++++++++ src/gpu/generic/sycl/simple_reduction.hpp | 84 +++++++++++ .../generic/sycl/simple_reduction_kernels.hpp | 132 ++++++++++++++++++ src/gpu/generic/sycl/sycl_primitive_conf.hpp | 12 ++ src/gpu/gpu_reduction_list.cpp | 5 + 5 files changed, 290 insertions(+) create mode 100644 src/gpu/generic/sycl/simple_reduction.cpp create mode 100644 src/gpu/generic/sycl/simple_reduction.hpp create mode 100644 src/gpu/generic/sycl/simple_reduction_kernels.hpp diff --git a/src/gpu/generic/sycl/simple_reduction.cpp b/src/gpu/generic/sycl/simple_reduction.cpp new file mode 100644 index 00000000000..629cfe44a7d --- /dev/null +++ b/src/gpu/generic/sycl/simple_reduction.cpp @@ -0,0 +1,57 @@ +#include "simple_reduction.hpp" + +#include "gpu/generic/sycl/engine.hpp" +#include "gpu/generic/sycl/simple_reduction_kernels.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +status_t simple_reduction_t::pd_t::init_conf() { + conf_.alg = desc()->alg_kind; + conf_.src_md = xpu::sycl::md_t(src_md()); + conf_.dst_md = xpu::sycl::md_t(dst_md()); + conf_.p = desc()->p; + conf_.eps = desc()->eps; + + auto src_wrap = memory_desc_wrapper(src_md()); + auto dst_wrap = memory_desc_wrapper(dst_md()); + dst_nelems_ = dst_wrap.nelems(); + + const auto ndims = dst_wrap.ndims(); + for (int d = 0; d < xpu::sycl::md_t::max_dims; d++) { + conf_.reduce_dims[d] = dim_t {1}; + if (d < ndims) { + if (src_wrap.dims()[d] != dst_wrap.dims()[d]) { + conf_.reduce_dims[d] = src_wrap.dims()[d]; + conf_.reduce_size *= conf_.reduce_dims[d]; + } + } + } + + conf_.post_ops = sycl_post_ops_t(attr(), dst_wrap); + + return status::success; +} + +status_t simple_reduction_t::init(impl::engine_t *engine) { + const auto kid = ::sycl::get_kernel_id(); + CHECK(create_kernel(engine, kid, &kernel_)); + + return status::success; +} + +status_t simple_reduction_t::execute(const exec_ctx_t &ctx) const { + return parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) { + reduction_kernel_fwd_t reduction_kernel(pd()->conf_, cgh, ctx); + cgh.parallel_for(::sycl::range<1>(pd()->dst_nelems_), reduction_kernel); + }); +} + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl diff --git a/src/gpu/generic/sycl/simple_reduction.hpp b/src/gpu/generic/sycl/simple_reduction.hpp new file mode 100644 index 00000000000..f8801882719 --- /dev/null +++ b/src/gpu/generic/sycl/simple_reduction.hpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef GPU_GENERIC_SYCL_SIMPLE_REDUCTION_HPP +#define GPU_GENERIC_SYCL_SIMPLE_REDUCTION_HPP + +#include "common/primitive_desc_iterator.hpp" +#include "common/reorder.hpp" +#include "common/reorder_pd.hpp" +#include "gpu/generic/sycl/sycl_gpu_primitive.hpp" +#include "gpu/generic/sycl/sycl_io_helper.hpp" +#include "gpu/generic/sycl/sycl_post_ops.hpp" +#include "gpu/generic/sycl/sycl_primitive_conf.hpp" +#include "gpu/generic/sycl/sycl_utils.hpp" +#include "gpu/gpu_reduction_pd.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +struct simple_reduction_t : public gpu::generic::sycl::primitive_t { + using gpu::generic::sycl::primitive_t::primitive_t; + + struct pd_t : public gpu_reduction_pd_t { + using gpu_reduction_pd_t::gpu_reduction_pd_t; + + DECLARE_COMMON_PD_T("dpcpp:ref:any", simple_reduction_t); + + status_t init(impl::engine_t *engine) { + using sm = primitive_attr_t::skip_mask_t; + + memory_desc_wrapper src_wrap(src_md()); + memory_desc_wrapper dst_wrap(dst_md()); + + bool ok = set_default_params() == status::success + && attr()->has_default_values(sm::post_ops) + && sycl_post_ops_t::post_ops_ok(attr()) + && attr_.set_default_formats(dst_md()) == status::success + && src_wrap.is_plain() && dst_wrap.is_plain() + && src_wrap.ndims() == dst_wrap.ndims() + && md_dims_in_range(src_md()) && md_dims_in_range(dst_md()); + if (!ok) return status::unimplemented; + + return init_conf(); + } + + sycl_simple_reduction_conf_t conf_; + dim_t dst_nelems_; + + private: + status_t init_conf(); + }; + + status_t init(impl::engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + kernel_t kernel_; + std::shared_ptr reorder_p_; +}; + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/gpu/generic/sycl/simple_reduction_kernels.hpp b/src/gpu/generic/sycl/simple_reduction_kernels.hpp new file mode 100644 index 00000000000..c2c1cdb2125 --- /dev/null +++ b/src/gpu/generic/sycl/simple_reduction_kernels.hpp @@ -0,0 +1,132 @@ + +#ifndef GPU_GENERIC_SYCL_SIMPLE_REDUCTION_KERNELS_HPP +#define GPU_GENERIC_SYCL_SIMPLE_REDUCTION_KERNELS_HPP + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/primitive_exec_types.hpp" +#include "common/utils.hpp" +#include "gpu/generic/sycl/sycl_io_helper.hpp" +#include "gpu/generic/sycl/sycl_math_utils.hpp" +#include "gpu/generic/sycl/sycl_primitive_conf.hpp" +#include "xpu/sycl/memory_storage_base.hpp" +#include "xpu/sycl/types.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +struct Reducer { + dnnl_alg_kind_t alg_; + float p_, eps_; + + Reducer(dnnl_alg_kind_t alg, float p, float eps) + : alg_(alg), p_(p), eps_(eps) {} + + float identity() const { + if (alg_ == dnnl_reduction_min) { + return std::numeric_limits::max(); + } else if (alg_ == dnnl_reduction_max) { + return std::numeric_limits::lowest(); + } else if (alg_ == dnnl_reduction_mul) { + return 1.f; + } + + return 0.f; + } + + float reduce(float lhs, float rhs) const { + if (alg_ == dnnl_reduction_sum || alg_ == dnnl_reduction_mean) { + return lhs + rhs; + } else if (alg_ == dnnl_reduction_min) { + return ::sycl::min(lhs, rhs); + } else if (alg_ == dnnl_reduction_max) { + return ::sycl::max(lhs, rhs); + 
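            // The remaining branches handle multiplication and the Lp-norm
            // family; the norm variants accumulate |x|^p here and leave the
            // p-th root and eps handling to finalize().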
} else if (alg_ == dnnl_reduction_mul) { + return lhs * rhs; + } else if (alg_ == dnnl_reduction_norm_lp_max + || alg_ == dnnl_reduction_norm_lp_sum + || alg_ == dnnl_reduction_norm_lp_power_p_max + || alg_ == dnnl_reduction_norm_lp_power_p_sum) { + return lhs + ::sycl::pow(::sycl::fabs(rhs), p_); + } + + return ::sycl::nan(0U); + } + + float finalize(float val, int size) const { + if (alg_ == dnnl_reduction_mean) { + return val / size; + } else if (alg_ == dnnl_reduction_norm_lp_max) { + return ::sycl::rootn(::sycl::max(val, eps_), p_); + } else if (alg_ == dnnl_reduction_norm_lp_sum) { + return ::sycl::rootn(val + eps_, p_); + } else if (alg_ == dnnl_reduction_norm_lp_power_p_max) { + return ::sycl::max(val, eps_); + } else if (alg_ == dnnl_reduction_norm_lp_power_p_sum) { + return val + eps_; + } + + return val; + } +}; + +struct reduction_kernel_fwd_t { + sycl_simple_reduction_conf_t conf_; + xpu::sycl::in_memory_arg_t src_; + xpu::sycl::out_memory_arg_t dst_; + post_op_input_args po_args_; + + reduction_kernel_fwd_t(const sycl_simple_reduction_conf_t &conf, + ::sycl::handler &cgh, const exec_ctx_t &ctx) + : conf_(conf) + , src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC)) + , dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST)) + , po_args_(cgh, ctx, conf_.post_ops) {} + + void operator()(::sycl::item<1> item) const { + Reducer reducer(conf_.alg, conf_.p, conf_.eps); + + memory_tensor_t<::sycl::access_mode::read> src(src_, conf_.src_md); + memory_tensor_t<::sycl::access_mode::write> dst(dst_, conf_.dst_md); + const int id = item.get_linear_id(); + + const auto &dst_md = conf_.dst_md; + dims_t pos; + int l_offset = id; + for (int i = 0; i < dst_md.ndims(); i++) { + const int d = dst_md.ndims() - 1 - i; + const dim_t cur_dim = dst_md.dims()[d]; + pos[d] = l_offset % cur_dim; + l_offset = l_offset / cur_dim; + } + + float acc = reducer.identity(); + for (off_t d0 = 0; d0 < conf_.reduce_dims[0]; d0++) + for (off_t d1 = 0; d1 < conf_.reduce_dims[1]; d1++) + for (off_t d2 = 0; d2 < conf_.reduce_dims[2]; d2++) + for (off_t d3 = 0; d3 < conf_.reduce_dims[3]; d3++) + for (off_t d4 = 0; d4 < conf_.reduce_dims[4]; d4++) + for (off_t d5 = 0; d5 < conf_.reduce_dims[5]; + d5++) { + dims_t src_off = {pos[0] + d0, pos[1] + d1, + pos[2] + d2, pos[3] + d3, pos[4] + d4, + pos[5] + d5}; + const float val = src.load_md(src_off); + acc = reducer.reduce(acc, val); + } + + float result = reducer.finalize(acc, conf_.reduce_size); + result = conf_.post_ops.apply(result, dst.load_md(pos), po_args_, pos); + dst.store_md(result, pos); + } +}; + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl +#endif diff --git a/src/gpu/generic/sycl/sycl_primitive_conf.hpp b/src/gpu/generic/sycl/sycl_primitive_conf.hpp index ec4e812cebb..1074af9e1c4 100644 --- a/src/gpu/generic/sycl/sycl_primitive_conf.hpp +++ b/src/gpu/generic/sycl/sycl_primitive_conf.hpp @@ -415,6 +415,17 @@ struct sycl_pooling_bwd_conf_t : public sycl_pooling_base_conf_t { xpu::sycl::md_t diff_dst_md; }; +struct sycl_simple_reduction_conf_t { + dnnl_alg_kind_t alg = dnnl_alg_kind_undef; + xpu::sycl::md_t src_md; + xpu::sycl::md_t dst_md; + float p; + float eps; + sycl_post_ops_t post_ops; + dim_t reduce_dims[xpu::sycl::md_t::max_dims]; + int reduce_size = 1; +}; + CHECK_SYCL_KERNEL_ARG_TYPE(sycl_binary_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_prelu_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_shuffle_conf_t); @@ -431,6 +442,7 @@ CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_bwd_conf_t); 
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_fwd_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_data_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_weights_conf_t); +CHECK_SYCL_KERNEL_ARG_TYPE(sycl_simple_reduction_conf_t); } // namespace sycl } // namespace generic diff --git a/src/gpu/gpu_reduction_list.cpp b/src/gpu/gpu_reduction_list.cpp index b29c238e04a..c04977f46eb 100644 --- a/src/gpu/gpu_reduction_list.cpp +++ b/src/gpu/gpu_reduction_list.cpp @@ -36,6 +36,10 @@ #include "gpu/amd/miopen_reduction.hpp" #endif +#ifdef GENERIC_SYCL_KERNELS_ENABLED +#include "gpu/generic/sycl/simple_reduction.hpp" +#endif + namespace dnnl { namespace impl { namespace gpu { @@ -51,6 +55,7 @@ constexpr impl_list_item_t impl_list[] = REG_REDUCTION_P({ GPU_INSTANCE_INTEL(intel::ocl::reusable_ref_reduction_t) GPU_INSTANCE_NVIDIA(nvidia::cudnn_reduction_t) GPU_INSTANCE_AMD(amd::miopen_reduction_t) + GPU_INSTANCE_GENERIC_SYCL(generic::sycl::simple_reduction_t) nullptr, }); // clang-format on From 7d6052c7db7d50b4b2c726a29895b8dd76ee2e1a Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 7 Nov 2024 16:06:27 +0000 Subject: [PATCH 2/3] generic:sycl: Inner Product FWD Co-authored-by: Atharva Dubey --- src/gpu/generic/sycl/README.md | 8 + src/gpu/generic/sycl/ref_inner_product.cpp | 55 ++++++ src/gpu/generic/sycl/ref_inner_product.hpp | 175 ++++++++++++++++++++ src/gpu/gpu_inner_product_list.cpp | 5 + src/gpu/nvidia/cudnn_matmul_executor.hpp | 8 +- src/gpu/nvidia/cudnn_matmul_lt_impl.hpp | 2 +- tests/gtests/test_inner_product_forward.cpp | 38 ++++- 7 files changed, 279 insertions(+), 12 deletions(-) create mode 100644 src/gpu/generic/sycl/ref_inner_product.cpp create mode 100644 src/gpu/generic/sycl/ref_inner_product.hpp diff --git a/src/gpu/generic/sycl/README.md b/src/gpu/generic/sycl/README.md index e7ff444462d..f1c64ed91d0 100644 --- a/src/gpu/generic/sycl/README.md +++ b/src/gpu/generic/sycl/README.md @@ -94,6 +94,14 @@ The implementation supports both forward and backward directions. * Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`, `NC`, `N` * Supported data types: `f32`, `bf16`, `f16`, `s32`, `s8`, `u8` +## Inner Product + +The implementation supports the forward direction only. + +* Supported formats: All plain formats are supported. +* Supported data types: All possible data combinations listed in the oneDNN specification are supported. +* Supported post-ops: All the post operations as mentioned in the specification are supported. + ## Layer Normalization The implementation supports both forward and backward directions. diff --git a/src/gpu/generic/sycl/ref_inner_product.cpp b/src/gpu/generic/sycl/ref_inner_product.cpp new file mode 100644 index 00000000000..e8e1f7f8d20 --- /dev/null +++ b/src/gpu/generic/sycl/ref_inner_product.cpp @@ -0,0 +1,55 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* Copyright 2024 Codeplay Software Limited +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "gpu/generic/sycl/ref_inner_product.hpp" +#include "common/primitive_desc_iterator.hpp" + +namespace dnnl::impl::gpu::generic::sycl { + +status_t ref_inner_product_fwd_t::pd_t::init_matmul(impl::engine_t *engine) { + matmul_desc_t matmul_desc; + CHECK(matmul_desc_init(&matmul_desc, &src_md_reshaped, &weights_md_reshaped, + &bias_md_reshaped, arg_md(DNNL_ARG_DST))); + primitive_attr_t matmul_attr(*attr()); + + primitive_desc_iterator_t it(engine, + reinterpret_cast(&matmul_desc), &matmul_attr, nullptr); + if (!it.is_initialized()) return status::invalid_arguments; + while (++it != it.end()) { + matmul_pd = *it; + if (matmul_pd) { break; } + } + if (!matmul_pd) { return status::invalid_arguments; } + return status::success; +} + +status_t ref_inner_product_fwd_t::init(impl::engine_t *engine) { + std::pair, cache_state_t> p; + CHECK(pd()->matmul_pd->create_primitive_nested(p, engine)); + matmul_primitive = p.first; + return status::success; +} + +status_t ref_inner_product_fwd_t::execute(const exec_ctx_t &ctx) const { + nested_scratchpad_t nested_scratchpad( + ctx, memory_tracking::names::key_nested, matmul_primitive); + exec_ctx_t copied_ctx(ctx); + copied_ctx.set_scratchpad_grantor(nested_scratchpad.grantor()); + return matmul_primitive->execute(copied_ctx); +} + +} // namespace dnnl::impl::gpu::generic::sycl diff --git a/src/gpu/generic/sycl/ref_inner_product.hpp b/src/gpu/generic/sycl/ref_inner_product.hpp new file mode 100644 index 00000000000..648d17bca49 --- /dev/null +++ b/src/gpu/generic/sycl/ref_inner_product.hpp @@ -0,0 +1,175 @@ +/******************************************************************************* +* Copyright 2023-2024 Intel Corporation +* Copyright 2024-2025 Codeplay Software Limited +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GPU_GENERIC_SYCL_REF_INNER_PRODUCT_HPP +#define GPU_GENERIC_SYCL_REF_INNER_PRODUCT_HPP + +#include "gpu/generic/sycl/ref_matmul.hpp" +#include "gpu/generic/sycl/sycl_gpu_primitive.hpp" +#include "gpu/generic/sycl/sycl_post_ops.hpp" +#include "gpu/generic/sycl/sycl_primitive_conf.hpp" +#include "gpu/generic/sycl/sycl_utils.hpp" +#include "gpu/gpu_inner_product_pd.hpp" +#include "gpu/gpu_primitive.hpp" + +namespace dnnl::impl::gpu::generic::sycl { +struct ref_inner_product_fwd_t : public gpu::generic::sycl::primitive_t { + using gpu::generic::sycl::primitive_t::primitive_t; + + struct pd_t : public gpu_inner_product_fwd_pd_t { + using gpu_inner_product_fwd_pd_t::gpu_inner_product_fwd_pd_t; + using sm = primitive_attr_t::skip_mask_t; + + DECLARE_COMMON_PD_T("dpcpp:ref:any", ref_inner_product_fwd_t); + + status_t init(impl::engine_t *engine) { + auto src_dt = arg_md(DNNL_ARG_SRC)->data_type; + auto weights_dt = arg_md(DNNL_ARG_WEIGHTS)->data_type; + auto dst_dt = arg_md(DNNL_ARG_DST)->data_type; + auto bias_dt = with_bias() ? 
arg_md(DNNL_ARG_BIAS)->data_type + : data_type::undef; + + const bool ok = (set_default_params() == status::success) + && is_fwd() + && check_if_dtypes_valid( + src_dt, dst_dt, bias_dt, weights_dt) + && sycl_post_ops_t::post_ops_ok(attr()) + && (attr_.set_default_formats(dst_md()) == status::success) + // Blocked memory formats are not supported + && memory_desc_wrapper(src_md()).is_plain() + && memory_desc_wrapper(dst_md()).is_plain() + && memory_desc_wrapper(weights_md()).is_plain(); + + if (!ok) { return status::unimplemented; } + CHECK(create_ip_mds()); + CHECK(init_matmul(engine)); + + // book scratchpad for the matmul + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_nested, + matmul_pd->scratchpad_registry()); + return status::success; + } + + std::shared_ptr matmul_pd; + + private: + bool check_if_dtypes_valid(const data_type_t &src_dt, + const data_type_t &dst_dt, const data_type_t &bias_dt, + const data_type_t &weight_dt) const { + using namespace data_type; + return (utils::one_of(src_dt, f32) && utils::one_of(weight_dt, f32) + && utils::one_of(dst_dt, f32) + && utils::one_of(bias_dt, f32, undef)) + || (utils::one_of(src_dt, f16) + && utils::one_of(weight_dt, f16) + && utils::one_of(dst_dt, f16, f32, s8, u8) + && utils::one_of(bias_dt, f16, f32, undef)) + || (utils::one_of(src_dt, u8, s8) + && utils::one_of(weight_dt, s8) + && utils::one_of(dst_dt, u8, s8, s32, bf16, f32) + && utils::one_of( + bias_dt, u8, s8, s32, bf16, f32, undef)) + || (utils::one_of(src_dt, bf16) + && utils::one_of(weight_dt, bf16) + && utils::one_of(dst_dt, f32, bf16) + && utils::one_of(bias_dt, f32, bf16, undef)); + } + + std::vector get_dim_order(int ndims, const dims_t strides) { + std::vector order(ndims); + for (int i = 0; i < ndims; ++i) { + order[i] = i; + } + + std::sort( + order.begin(), order.end(), [&strides](size_t i, size_t j) { + return strides[i] < strides[j]; + }); + + return order; + } + + status_t create_ip_mds() { + auto accumulate_dimensions = [](const dims_t dimensions, int start, + int end) -> int64_t { + int64_t accum = 1; + for (int i = start; i < end; i++) { + accum *= dimensions[i]; + } + return accum; + }; + + const auto src_md_ = arg_md(DNNL_ARG_SRC); + const auto weights_md_ = arg_md(DNNL_ARG_WEIGHTS); + const auto bias_md_ = arg_md(DNNL_ARG_BIAS); + auto src_wrap = memory_desc_wrapper(src_md_); + auto w_wrap = memory_desc_wrapper(weights_md_); + + // src and weights dims need to be in the same order + if (get_dim_order(src_wrap.ndims(), src_wrap.strides()) + != get_dim_order(w_wrap.ndims(), w_wrap.strides())) { + return status::unimplemented; + } + + // Reshape input into the form of Batch x (\prod_{dim_{n-1}}^dim_0) + if (src_md_->ndims == 2) { + src_md_reshaped = *src_md_; + } else { + int64_t src_flattened_dimension = accumulate_dimensions( + src_md_->dims, 1, src_md_->ndims); + dims_t src_reshaped_dims { + src_md_->dims[0], src_flattened_dimension}; + CHECK(memory_desc_init_by_tag(src_md_reshaped, 2, + src_reshaped_dims, src_md_->data_type, format_tag::ab)); + } + + // Reshape weights as (OC x (\prod_{dim_{n-1}}^dim_0))^T + int weights_flattened_dimensions = accumulate_dimensions( + weights_md_->dims, 1, weights_md_->ndims); + dims_t weights_reshaped_dims { + weights_flattened_dimensions, weights_md_->dims[0]}; + CHECK(memory_desc_init_by_tag(weights_md_reshaped, 2, + weights_reshaped_dims, weights_md_->data_type, + format_tag::ba)); + if (with_bias()) { + dims_t bias_reshaped_dims {1, bias_md_->dims[0]}; + 
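                // Bias is viewed as a 1 x OC row vector so the nested matmul
                // broadcasts it across the batch dimension of the output.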
CHECK(memory_desc_init_by_tag(bias_md_reshaped, 2, + bias_reshaped_dims, bias_md_->data_type, + format_tag::ab)); + } + return status::success; + } + + status_t init_matmul(impl::engine_t *engine); + // Memory descriptors to contain reshaped tensors from nD to 2D for IP + memory_desc_t src_md_reshaped; + memory_desc_t weights_md_reshaped; + memory_desc_t bias_md_reshaped; + }; + + status_t init(impl::engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + kernel_t kernel_; + std::shared_ptr matmul_primitive; +}; +} // namespace dnnl::impl::gpu::generic::sycl + +#endif diff --git a/src/gpu/gpu_inner_product_list.cpp b/src/gpu/gpu_inner_product_list.cpp index b13f990a9a5..dccaedc1681 100644 --- a/src/gpu/gpu_inner_product_list.cpp +++ b/src/gpu/gpu_inner_product_list.cpp @@ -32,6 +32,10 @@ #include "gpu/amd/miopen_gemm_inner_product.hpp" #endif +#ifdef GENERIC_SYCL_KERNELS_ENABLED +#include "gpu/generic/sycl/ref_inner_product.hpp" +#endif + namespace dnnl { namespace impl { namespace gpu { @@ -49,6 +53,7 @@ const std::map> GPU_INSTANCE_NVIDIA(nvidia::cudnn_gemm_inner_product_fwd_t) GPU_INSTANCE_NVIDIA(nvidia::cudnn_conv_inner_product_fwd_t) GPU_INSTANCE_AMD(amd::miopen_gemm_inner_product_fwd_t) + GPU_INSTANCE_GENERIC_SYCL(generic::sycl::ref_inner_product_fwd_t) nullptr, }}, {{backward}, REG_BWD_PK({ diff --git a/src/gpu/nvidia/cudnn_matmul_executor.hpp b/src/gpu/nvidia/cudnn_matmul_executor.hpp index f78cd853d0a..1209995de93 100644 --- a/src/gpu/nvidia/cudnn_matmul_executor.hpp +++ b/src/gpu/nvidia/cudnn_matmul_executor.hpp @@ -392,12 +392,12 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t { memory_tracking::names::key_matmul_dst_in_acc_dt) : xpu::sycl::interop_memory_arg_t< ::sycl::access::mode::read_write>(); - auto arg_block_a_scratch = params->source_size_ != 0 + auto arg_block_a_scratch = params->weight_size_ != 0 ? CTX_SCRATCH_SYCL_MEMORY( memory_tracking::names::key_gemm_blocked_a) : xpu::sycl::interop_memory_arg_t< ::sycl::access::mode::read_write>(); - auto arg_block_b_scratch = params->weight_size_ != 0 + auto arg_block_b_scratch = params->source_size_ != 0 ? 
CTX_SCRATCH_SYCL_MEMORY( memory_tracking::names::key_gemm_blocked_b) : xpu::sycl::interop_memory_arg_t< @@ -457,10 +457,10 @@ struct cudnn_matmul_lt_runtime_args_exec_t final matmul_params->reorder_scratch_size_, cuda_stream->queue()); uint8_t *block_a_scratch_ptr - = alloc_ptr(matmul_params->source_size_, cuda_stream->queue()); + = alloc_ptr(matmul_params->weight_size_, cuda_stream->queue()); uint8_t *block_b_scratch_ptr - = alloc_ptr(matmul_params->weight_size_, cuda_stream->queue()); + = alloc_ptr(matmul_params->source_size_, cuda_stream->queue()); uint8_t *block_c_scratch_ptr = alloc_ptr(matmul_params->dest_size_, cuda_stream->queue()); diff --git a/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp b/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp index fda74f94e51..f529cd67000 100644 --- a/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp +++ b/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp @@ -717,7 +717,7 @@ struct cudnn_matmul_lt_impl_t { } if (!params->w_blocked_) { transform_matrix(lt_handle, params, a_layout, a, - blocked_a_layout, block_a_scratch, !params->trans_a_, + blocked_a_layout, block_a_scratch, params->trans_a_, streamId); a = block_a_scratch; } diff --git a/tests/gtests/test_inner_product_forward.cpp b/tests/gtests/test_inner_product_forward.cpp index a92be9571ba..c5672163926 100644 --- a/tests/gtests/test_inner_product_forward.cpp +++ b/tests/gtests/test_inner_product_forward.cpp @@ -88,16 +88,18 @@ class inner_product_test_t protected: void SetUp() override { auto p = ::testing::TestWithParam::GetParam(); - SKIP_IF_CUDA(!cuda_check_format_tags(p.src_format, p.weights_format, - p.bias_format, p.dst_format), + SKIP_IF_CUDA(!cuda_generic_check_format_tags(p.src_format, + p.weights_format, p.bias_format, p.dst_format), + "Unsupported format tag"); + SKIP_IF_GENERIC(!cuda_generic_check_format_tags(p.src_format, + p.weights_format, p.bias_format, p.dst_format), "Unsupported format tag"); SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions"); - SKIP_IF_GENERIC(true, "Primitive not implemented"); catch_expected_failures( [&]() { Test(); }, p.expect_to_fail, p.expected_status); } - bool cuda_check_format_tags(memory::format_tag src_format, + bool cuda_generic_check_format_tags(memory::format_tag src_format, memory::format_tag wei_format, memory::format_tag bia_format, memory::format_tag dst_format) { bool src_ok = src_format == memory::format_tag::ncdhw @@ -130,6 +132,20 @@ class inner_product_test_t return src_ok && wei_ok && bia_ok && dst_ok; } + std::vector get_dim_order(const memory::dims &strides) { + size_t ndims = strides.size(); + std::vector order(ndims); + for (size_t i = 0; i < ndims; ++i) { + order[i] = i; + } + + std::sort(order.begin(), order.end(), [&strides](size_t i, size_t j) { + return strides[i] < strides[j]; + }); + + return order; + } + void Test() { auto p = ::testing::TestWithParam::GetParam(); test_inner_product_descr_t ipd = p.test_ipd; @@ -169,6 +185,10 @@ class inner_product_test_t : create_md({}, data_type, p.bias_format); auto ip_dst_desc = create_md({ipd.mb, ipd.oc}, data_type, p.dst_format); + SKIP_IF_GENERIC(get_dim_order(ip_src_desc.get_strides()) + != get_dim_order(ip_weights_desc.get_strides()), + "Unsupported case for generic"); + auto ip_primitive_desc = with_bias ? 
pd_t(eng, p.aprop_kind, ip_src_desc, ip_weights_desc, ip_bias_desc, ip_dst_desc) @@ -176,11 +196,15 @@ class inner_product_test_t ip_dst_desc); auto aa = allows_attr_t {false}; - aa.po_binary = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); aa.po_eltwise = true; - aa.po_prelu = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); aa.po_sum = true; - +#ifdef DNNL_SYCL_GENERIC + aa.po_binary = true; + aa.po_prelu = true; +#else + aa.po_binary = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); + aa.po_prelu = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); +#endif test_fwd_pd_constructors(ip_primitive_desc, aa, p.aprop_kind, ip_src_desc, ip_weights_desc, ip_bias_desc, ip_dst_desc); From dc02a2cee4107a55985e4f000cc77329a16a27bd Mon Sep 17 00:00:00 2001 From: Svetlozar Georgiev Date: Fri, 13 Dec 2024 15:04:56 +0000 Subject: [PATCH 3/3] generic: sycl: inner product backward --- src/gpu/generic/sycl/README.md | 7 +- src/gpu/generic/sycl/ref_inner_product.cpp | 133 ++++++++- src/gpu/generic/sycl/ref_inner_product.hpp | 279 ++++++++++++++---- src/gpu/gpu_inner_product_list.cpp | 2 + src/gpu/nvidia/cudnn_matmul_impl.hpp | 4 - src/gpu/nvidia/cudnn_matmul_lt_impl.hpp | 4 - .../test_inner_product_backward_data.cpp | 26 +- .../test_inner_product_backward_weights.cpp | 31 +- 8 files changed, 396 insertions(+), 90 deletions(-) diff --git a/src/gpu/generic/sycl/README.md b/src/gpu/generic/sycl/README.md index f1c64ed91d0..95c1a4a51b8 100644 --- a/src/gpu/generic/sycl/README.md +++ b/src/gpu/generic/sycl/README.md @@ -96,11 +96,12 @@ The implementation supports both forward and backward directions. ## Inner Product -The implementation supports the forward direction only. +The implementation supports both forward and backward directions. * Supported formats: All plain formats are supported. -* Supported data types: All possible data combinations listed in the oneDNN specification are supported. -* Supported post-ops: All the post operations as mentioned in the specification are supported. +* Supported data types: All possible data combinations as listed in the specification are supported. +* Supported post-ops: All the post-ops as mentioned in the specification are supported. +* The backward pass does not support post-ops. 
One should not use post-ops in the forward pass during training ## Layer Normalization diff --git a/src/gpu/generic/sycl/ref_inner_product.cpp b/src/gpu/generic/sycl/ref_inner_product.cpp index e8e1f7f8d20..cd6df6fd6a7 100644 --- a/src/gpu/generic/sycl/ref_inner_product.cpp +++ b/src/gpu/generic/sycl/ref_inner_product.cpp @@ -16,27 +16,69 @@ *******************************************************************************/ #include "gpu/generic/sycl/ref_inner_product.hpp" -#include "common/primitive_desc_iterator.hpp" namespace dnnl::impl::gpu::generic::sycl { -status_t ref_inner_product_fwd_t::pd_t::init_matmul(impl::engine_t *engine) { +namespace detail { +status_t init_matmul_pd(impl::engine_t *engine, + const primitive_attr_t *attributes, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, + std::shared_ptr &matmul_pd) { + matmul_desc_t matmul_desc; - CHECK(matmul_desc_init(&matmul_desc, &src_md_reshaped, &weights_md_reshaped, - &bias_md_reshaped, arg_md(DNNL_ARG_DST))); - primitive_attr_t matmul_attr(*attr()); + CHECK(matmul_desc_init( + &matmul_desc, src_desc, weights_desc, bias_desc, dst_desc)); primitive_desc_iterator_t it(engine, - reinterpret_cast(&matmul_desc), &matmul_attr, nullptr); + reinterpret_cast(&matmul_desc), attributes, nullptr); + if (!it.is_initialized()) return status::invalid_arguments; while (++it != it.end()) { - matmul_pd = *it; - if (matmul_pd) { break; } + if (*it) { + matmul_pd = *it; + break; + } + } + return status::success; +} + +status_t flatten_md(const memory_desc_t &desc, memory_desc_t &flattened_md, + format_tag_t format_tag) { + // Always flattens from a nD to a 2D memory layout, with batch as the first dimension + assert(format_tag == format_tag::ab || format_tag == format_tag::ba); + const auto &dimensions = desc.dims; + int64_t flattened_dimension = 1; + for (int i = 1; i < desc.ndims; i++) { + flattened_dimension *= dimensions[i]; } - if (!matmul_pd) { return status::invalid_arguments; } + dims_t reshaped_dims; + if (format_tag == format_tag::ab) { + reshaped_dims[0] = desc.dims[0]; + reshaped_dims[1] = flattened_dimension; + } else { + reshaped_dims[0] = flattened_dimension; + reshaped_dims[1] = desc.dims[0]; + } + CHECK(memory_desc_init_by_tag( + flattened_md, 2, reshaped_dims, desc.data_type, format_tag)); return status::success; } +std::vector get_dim_order(int ndims, const dims_t strides) { + std::vector order(ndims); + for (int i = 0; i < ndims; ++i) { + order[i] = i; + } + + std::sort(order.begin(), order.end(), + [&strides](size_t i, size_t j) { return strides[i] < strides[j]; }); + + return order; +} + +} // namespace detail + status_t ref_inner_product_fwd_t::init(impl::engine_t *engine) { std::pair, cache_state_t> p; CHECK(pd()->matmul_pd->create_primitive_nested(p, engine)); @@ -52,4 +94,77 @@ status_t ref_inner_product_fwd_t::execute(const exec_ctx_t &ctx) const { return matmul_primitive->execute(copied_ctx); } +status_t ref_inner_product_bwd_data_t::init(impl::engine_t *engine) { + std::pair, cache_state_t> p; + CHECK(pd()->matmul_pd->create_primitive_nested(p, engine)); + matmul_primitive = p.first; + return status::success; +} + +status_t ref_inner_product_bwd_data_t::execute(const exec_ctx_t &ctx) const { + nested_scratchpad_t nested_scratchpad( + ctx, memory_tracking::names::key_nested, matmul_primitive); + + exec_args_t args_copy(ctx.args()); + // Map src and dst to diff_dst and diff_src respectively + args_copy[DNNL_ARG_SRC] = 
args_copy[DNNL_ARG_DIFF_DST]; + args_copy[DNNL_ARG_DST] = args_copy[DNNL_ARG_DIFF_SRC]; + exec_ctx_t copied_ctx(ctx.stream(), std::move(args_copy)); + + copied_ctx.set_scratchpad_grantor(nested_scratchpad.grantor()); + + return matmul_primitive->execute(copied_ctx); +} + +status_t ref_inner_product_bwd_weights_t::init(impl::engine_t *engine) { + std::pair, cache_state_t> p; + CHECK(pd()->matmul_pd->create_primitive_nested(p, engine)); + matmul_primitive = p.first; + + if (pd()->with_bias()) { + std::pair, cache_state_t> + p_reduction; + CHECK(pd()->reduction_pd->create_primitive_nested(p_reduction, engine)); + reduction_primitive = p_reduction.first; + } + + return status::success; +} + +status_t ref_inner_product_bwd_weights_t::execute(const exec_ctx_t &ctx) const { + nested_scratchpad_t nested_scratchpad( + ctx, memory_tracking::names::key_nested_multiple, matmul_primitive); + + exec_args_t args_copy(ctx.args()); + // Map src and dst to diff_dst and diff_src respectively + auto src_memory_arg = args_copy[DNNL_ARG_SRC]; + args_copy[DNNL_ARG_SRC] = args_copy[DNNL_ARG_DIFF_DST]; + args_copy[DNNL_ARG_WEIGHTS] = src_memory_arg; + args_copy[DNNL_ARG_DST] = args_copy[DNNL_ARG_DIFF_WEIGHTS]; + exec_ctx_t copied_ctx(ctx.stream(), std::move(args_copy)); + + copied_ctx.set_scratchpad_grantor(nested_scratchpad.grantor()); + // calcules dL/dW; + CHECK(matmul_primitive->execute(copied_ctx)); + + if (pd()->with_bias()) { + //calculates dL/dB + nested_scratchpad_t reduction_scratchpad(ctx, + memory_tracking::names::key_nested_multiple + 1, + reduction_primitive); + exec_args_t args_copy_reduction(ctx.args()); + args_copy_reduction[DNNL_ARG_SRC] + = args_copy_reduction[DNNL_ARG_DIFF_DST]; + args_copy_reduction[DNNL_ARG_DST] + = args_copy_reduction[DNNL_ARG_DIFF_BIAS]; + exec_ctx_t copied_ctx_reduction( + ctx.stream(), std::move(args_copy_reduction)); + + copied_ctx_reduction.set_scratchpad_grantor( + reduction_scratchpad.grantor()); + CHECK(reduction_primitive->execute(copied_ctx_reduction)); + } + return status::success; +} + } // namespace dnnl::impl::gpu::generic::sycl diff --git a/src/gpu/generic/sycl/ref_inner_product.hpp b/src/gpu/generic/sycl/ref_inner_product.hpp index 648d17bca49..f7824d872c2 100644 --- a/src/gpu/generic/sycl/ref_inner_product.hpp +++ b/src/gpu/generic/sycl/ref_inner_product.hpp @@ -18,6 +18,8 @@ #ifndef GPU_GENERIC_SYCL_REF_INNER_PRODUCT_HPP #define GPU_GENERIC_SYCL_REF_INNER_PRODUCT_HPP +#include "common/primitive_desc_iterator.hpp" +#include "common/reduction_pd.hpp" #include "gpu/generic/sycl/ref_matmul.hpp" #include "gpu/generic/sycl/sycl_gpu_primitive.hpp" #include "gpu/generic/sycl/sycl_post_ops.hpp" @@ -27,6 +29,20 @@ #include "gpu/gpu_primitive.hpp" namespace dnnl::impl::gpu::generic::sycl { + +namespace detail { +status_t init_matmul_pd(impl::engine_t *engine, + const primitive_attr_t *attributes, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, + std::shared_ptr &matmul_pd); + +status_t flatten_md(const memory_desc_t &desc, memory_desc_t &flattened_md, + format_tag_t format_tag); + +std::vector get_dim_order(int ndims, const dims_t strides); +} // namespace detail + struct ref_inner_product_fwd_t : public gpu::generic::sycl::primitive_t { using gpu::generic::sycl::primitive_t::primitive_t; @@ -43,6 +59,9 @@ struct ref_inner_product_fwd_t : public gpu::generic::sycl::primitive_t { auto bias_dt = with_bias() ? 
arg_md(DNNL_ARG_BIAS)->data_type : data_type::undef; + auto src_wrap = memory_desc_wrapper(src_md()); + auto wei_wrap = memory_desc_wrapper(weights_md()); + const bool ok = (set_default_params() == status::success) && is_fwd() && check_if_dtypes_valid( @@ -55,8 +74,32 @@ struct ref_inner_product_fwd_t : public gpu::generic::sycl::primitive_t { && memory_desc_wrapper(weights_md()).is_plain(); if (!ok) { return status::unimplemented; } - CHECK(create_ip_mds()); - CHECK(init_matmul(engine)); + + if (detail::get_dim_order(src_wrap.ndims(), src_wrap.strides()) + != detail::get_dim_order( + wei_wrap.ndims(), wei_wrap.strides())) { + return status::unimplemented; + } + + memory_desc_t src_reshaped; + memory_desc_t weights_reshaped; + memory_desc_t bias_reshaped = types::zero_md(); + CHECK(detail::flatten_md( + *arg_md(DNNL_ARG_SRC), src_reshaped, format_tag::ab)); + CHECK(detail::flatten_md(*arg_md(DNNL_ARG_WEIGHTS), + weights_reshaped, format_tag::ba)); + if (with_bias()) { + const auto bias_md = arg_md(DNNL_ARG_BIAS); + //Reshape bias to 1 x OC; + dims_t reshaped_bias_dims {1, bias_md->dims[0]}; + CHECK(memory_desc_init_by_tag(bias_reshaped, 2, + reshaped_bias_dims, bias_md->data_type, + format_tag::ab)); + } + + CHECK(gpu::generic::sycl::detail::init_matmul_pd(engine, attr(), + &src_reshaped, &weights_reshaped, &bias_reshaped, + arg_md(DNNL_ARG_DST), matmul_pd)); // book scratchpad for the matmul auto scratchpad = scratchpad_registry().registrar(); @@ -89,77 +132,188 @@ struct ref_inner_product_fwd_t : public gpu::generic::sycl::primitive_t { && utils::one_of(dst_dt, f32, bf16) && utils::one_of(bias_dt, f32, bf16, undef)); } + }; - std::vector get_dim_order(int ndims, const dims_t strides) { - std::vector order(ndims); - for (int i = 0; i < ndims; ++i) { - order[i] = i; - } + status_t init(impl::engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override; - std::sort( - order.begin(), order.end(), [&strides](size_t i, size_t j) { - return strides[i] < strides[j]; - }); +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + std::shared_ptr matmul_primitive; +}; - return order; +struct ref_inner_product_bwd_data_t : public gpu::generic::sycl::primitive_t { + + using gpu::generic::sycl::primitive_t::primitive_t; + + struct pd_t : public gpu_inner_product_bwd_data_pd_t { + using gpu_inner_product_bwd_data_pd_t::gpu_inner_product_bwd_data_pd_t; + DECLARE_COMMON_PD_T("dpcpp:ref:any", ref_inner_product_bwd_data_t); + + status_t init(impl::engine_t *engine) { + auto src_dt = arg_md(DNNL_ARG_DIFF_DST)->data_type; + auto weights_dt = arg_md(DNNL_ARG_WEIGHTS)->data_type; + auto dst_dt = arg_md(DNNL_ARG_DIFF_SRC)->data_type; + + bool ok = !is_fwd() && (set_default_params() == status::success) + && check_bwd_data_dtypes(src_dt, dst_dt, weights_dt) + && attr()->has_default_values() // no post-op is supported + && memory_desc_wrapper(arg_md(DNNL_ARG_DIFF_DST)).is_plain() + && memory_desc_wrapper(arg_md(DNNL_ARG_DIFF_SRC)) + .is_plain(); // Blocked memory formats are not supported + if (!ok) { return status::unimplemented; } + + // dL/dX = (dL/dY) x W (hence no transpose required here) + auto empty_bias_desc = types:: + zero_md(); // empty memory descriptor to signify bias is not applied + + // Temporary memory descriptors to initialize matmul_pd; diff_dst will always be 2D + memory_desc_t reshaped_diff_src_md; + memory_desc_t reshaped_weights_md; + CHECK(detail::flatten_md(*arg_md(DNNL_ARG_DIFF_SRC), + reshaped_diff_src_md, format_tag::ab)); + 
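            // Weights are flattened to a 2D (OC x IC*spatial) view so that
            // diff_src = diff_dst x W maps directly onto the nested matmul.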
CHECK(detail::flatten_md(*arg_md(DNNL_ARG_WEIGHTS), + reshaped_weights_md, format_tag::ab)); + + CHECK(gpu::generic::sycl::detail::init_matmul_pd(engine, attr(), + arg_md(DNNL_ARG_DIFF_DST), &reshaped_weights_md, + &empty_bias_desc, &reshaped_diff_src_md, matmul_pd)); + + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_nested, + matmul_pd->scratchpad_registry()); + return status::success; } - status_t create_ip_mds() { - auto accumulate_dimensions = [](const dims_t dimensions, int start, - int end) -> int64_t { - int64_t accum = 1; - for (int i = start; i < end; i++) { - accum *= dimensions[i]; - } - return accum; - }; - - const auto src_md_ = arg_md(DNNL_ARG_SRC); - const auto weights_md_ = arg_md(DNNL_ARG_WEIGHTS); - const auto bias_md_ = arg_md(DNNL_ARG_BIAS); - auto src_wrap = memory_desc_wrapper(src_md_); - auto w_wrap = memory_desc_wrapper(weights_md_); - - // src and weights dims need to be in the same order - if (get_dim_order(src_wrap.ndims(), src_wrap.strides()) - != get_dim_order(w_wrap.ndims(), w_wrap.strides())) { - return status::unimplemented; - } + std::shared_ptr matmul_pd; - // Reshape input into the form of Batch x (\prod_{dim_{n-1}}^dim_0) - if (src_md_->ndims == 2) { - src_md_reshaped = *src_md_; - } else { - int64_t src_flattened_dimension = accumulate_dimensions( - src_md_->dims, 1, src_md_->ndims); - dims_t src_reshaped_dims { - src_md_->dims[0], src_flattened_dimension}; - CHECK(memory_desc_init_by_tag(src_md_reshaped, 2, - src_reshaped_dims, src_md_->data_type, format_tag::ab)); - } + private: + bool check_bwd_data_dtypes(const data_type_t &src_dt, + const data_type_t &dst_dt, const data_type_t &weight_dt) { + using namespace data_type; + return (utils::one_of(src_dt, f32) + && utils::one_of(dst_dt, f32, f16, bf16) + && utils::one_of(weight_dt, f32, bf16, f16)) + || (utils::one_of(src_dt, bf16) + && utils::one_of(dst_dt, bf16) + && utils::one_of(weight_dt, bf16)) + || (utils::one_of(src_dt, f16) && utils::one_of(dst_dt, f16) + && utils::one_of(weight_dt, f16)); + } + }; + + status_t init(impl::engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + std::shared_ptr matmul_primitive; +}; + +struct ref_inner_product_bwd_weights_t + : public gpu::generic::sycl::primitive_t { + + using gpu::generic::sycl::primitive_t::primitive_t; + + struct pd_t : public gpu_inner_product_bwd_weights_pd_t { + using gpu_inner_product_bwd_weights_pd_t:: + gpu_inner_product_bwd_weights_pd_t; + DECLARE_COMMON_PD_T("dpcpp:ref:any", ref_inner_product_bwd_weights_t); + + status_t init(impl::engine_t *engine) { + auto src_dt = arg_md(DNNL_ARG_DIFF_DST)->data_type; + auto weights_dt = arg_md(DNNL_ARG_SRC)->data_type; + auto dst_dt = arg_md(DNNL_ARG_DIFF_WEIGHTS)->data_type; + auto bias_dt = arg_md(DNNL_ARG_DIFF_BIAS)->data_type; + + bool ok = !is_fwd() && (set_default_params() == status::success) + && check_bwd_weights_dtypes( + src_dt, dst_dt, weights_dt, bias_dt) + && attr()->has_default_values() // no post-op is supported + && memory_desc_wrapper(arg_md(DNNL_ARG_DIFF_DST)).is_plain() + && memory_desc_wrapper(arg_md(DNNL_ARG_SRC)).is_plain() + && memory_desc_wrapper(arg_md(DNNL_ARG_DIFF_WEIGHTS)) + .is_plain(); - // Reshape weights as (OC x (\prod_{dim_{n-1}}^dim_0))^T - int weights_flattened_dimensions = accumulate_dimensions( - weights_md_->dims, 1, weights_md_->ndims); - dims_t weights_reshaped_dims { - 
weights_flattened_dimensions, weights_md_->dims[0]}; - CHECK(memory_desc_init_by_tag(weights_md_reshaped, 2, - weights_reshaped_dims, weights_md_->data_type, - format_tag::ba)); + if (!ok) { return status::unimplemented; }; + + memory_desc_t reshaped_src_md; + memory_desc_t reshaped_diff_wt_md; + memory_desc_t reshaped_diff_dst_md; + auto empty_bias_desc = types:: + zero_md(); // empty memory descriptor to signify bias is not applied + // (dL / dW) = (dL/dY) ^ T x X; + CHECK(detail::flatten_md( + *arg_md(DNNL_ARG_SRC), reshaped_src_md, format_tag::ab)); + CHECK(detail::flatten_md(*arg_md(DNNL_ARG_DIFF_DST), + reshaped_diff_dst_md, format_tag::ba)); + CHECK(detail::flatten_md(*arg_md(DNNL_ARG_DIFF_WEIGHTS), + reshaped_diff_wt_md, format_tag::ab)); + + // Create matmul_pd for dL/dW + CHECK(detail::init_matmul_pd(engine, attr(), &reshaped_diff_dst_md, + &reshaped_src_md, &empty_bias_desc, &reshaped_diff_wt_md, + matmul_pd)); + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_nested_multiple, + matmul_pd->scratchpad_registry()); + + //Create reduction_pd for dL/dB if (with_bias()) { - dims_t bias_reshaped_dims {1, bias_md_->dims[0]}; - CHECK(memory_desc_init_by_tag(bias_md_reshaped, 2, - bias_reshaped_dims, bias_md_->data_type, - format_tag::ab)); + CHECK(init_reduction_pd(engine, arg_md(DNNL_ARG_DIFF_DST), + arg_md(DNNL_ARG_DIFF_BIAS))); + // book scratchpad for reduction + scratchpad.book(memory_tracking::names::key_nested_multiple + 1, + reduction_pd->scratchpad_registry()); } return status::success; } - status_t init_matmul(impl::engine_t *engine); - // Memory descriptors to contain reshaped tensors from nD to 2D for IP - memory_desc_t src_md_reshaped; - memory_desc_t weights_md_reshaped; - memory_desc_t bias_md_reshaped; + std::shared_ptr matmul_pd; + std::shared_ptr reduction_pd; + + private: + bool check_bwd_weights_dtypes(const data_type_t &src_dt, + const data_type_t &dst_dt, const data_type_t &weight_dt, + const data_type_t &bias_dt) { + using namespace data_type; + return (utils::one_of(src_dt, f32) && utils::one_of(dst_dt, f32) + && utils::one_of(weight_dt, f32) + && utils::one_of(bias_dt, f32, undef)) + || (utils::one_of(src_dt, bf16) + && utils::one_of(dst_dt, bf16) + && utils::one_of(weight_dt, f32, bf16) + && utils::one_of(bias_dt, f32, bf16, undef)) + || (utils::one_of(src_dt, f16) && utils::one_of(dst_dt, f16) + && utils::one_of(weight_dt, f32, f16) + && utils::one_of(bias_dt, f32, f16, undef)); + } + + status_t init_reduction_pd(impl::engine_t *engine, + const memory_desc_t *src_desc, const memory_desc_t *dest_desc) { + reduction_desc_t reduction_descriptor; + //diff_bias is 1D, diff_dst will be 2D, reshape diff_bias to 1xOC + dims_t diff_bias_reshaped_dims {1, dest_desc->dims[0]}; + memory_desc_t diff_bias_reshaped; + CHECK(memory_desc_init_by_tag(diff_bias_reshaped, 2, + diff_bias_reshaped_dims, dest_desc->data_type, + format_tag::ab)); + CHECK(reduction_desc_init(&reduction_descriptor, + alg_kind::reduction_sum, src_desc, &diff_bias_reshaped, + 0.0f, 0.0f)); + primitive_desc_iterator_t it(engine, + reinterpret_cast(&reduction_descriptor), + attr(), nullptr); + + if (!it.is_initialized()) return status::invalid_arguments; + while (++it != it.end()) { + if (*it) { + reduction_pd = *it; + break; + } + } + return status::success; + } }; status_t init(impl::engine_t *engine) override; @@ -167,9 +321,10 @@ struct ref_inner_product_fwd_t : public gpu::generic::sycl::primitive_t { private: const pd_t *pd() const { return (const 
pd_t *)primitive_t::pd().get(); } - kernel_t kernel_; std::shared_ptr matmul_primitive; + std::shared_ptr reduction_primitive; }; + } // namespace dnnl::impl::gpu::generic::sycl #endif diff --git a/src/gpu/gpu_inner_product_list.cpp b/src/gpu/gpu_inner_product_list.cpp index dccaedc1681..00e20fed8c5 100644 --- a/src/gpu/gpu_inner_product_list.cpp +++ b/src/gpu/gpu_inner_product_list.cpp @@ -67,6 +67,8 @@ const std::map> GPU_INSTANCE_NVIDIA(nvidia::cudnn_conv_inner_product_bwd_weights_t) GPU_INSTANCE_AMD(amd::miopen_gemm_inner_product_bwd_data_t) GPU_INSTANCE_AMD(amd::miopen_gemm_inner_product_bwd_weights_t) + GPU_INSTANCE_GENERIC_SYCL(generic::sycl::ref_inner_product_bwd_data_t) + GPU_INSTANCE_GENERIC_SYCL(generic::sycl::ref_inner_product_bwd_weights_t) nullptr, })}, }); diff --git a/src/gpu/nvidia/cudnn_matmul_impl.hpp b/src/gpu/nvidia/cudnn_matmul_impl.hpp index bdbb25358c2..2be67966873 100644 --- a/src/gpu/nvidia/cudnn_matmul_impl.hpp +++ b/src/gpu/nvidia/cudnn_matmul_impl.hpp @@ -459,10 +459,6 @@ struct cudnn_matmul_impl_t { cudnn_handle, c, bias, reorder_scratch, host_dst_scale); } - ~cudnn_matmul_impl_t() { - if (matmul_params_) { matmul_params_->cleanup(); } - } - private: std::shared_ptr matmul_params_; }; diff --git a/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp b/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp index f529cd67000..01854a48e8f 100644 --- a/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp +++ b/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp @@ -809,10 +809,6 @@ struct cudnn_matmul_lt_impl_t { } } - ~cudnn_matmul_lt_impl_t() { - if (matmul_params_) { matmul_params_->cleanup(); } - } - private: void transform_matrix(cublasLtHandle_t handle, const std::shared_ptr ¶ms, diff --git a/tests/gtests/test_inner_product_backward_data.cpp b/tests/gtests/test_inner_product_backward_data.cpp index a3ff4677750..ff0e36ae722 100644 --- a/tests/gtests/test_inner_product_backward_data.cpp +++ b/tests/gtests/test_inner_product_backward_data.cpp @@ -96,16 +96,18 @@ class inner_product_test_bwd_data_t protected: void SetUp() override { auto p = ::testing::TestWithParam::GetParam(); - SKIP_IF_CUDA(!cuda_check_format_tags(p.diff_src_format, + SKIP_IF_CUDA(!cuda_generic_check_format_tags(p.diff_src_format, p.weights_format, p.diff_dst_format), "Unsupported format tag"); SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions"); - SKIP_IF_GENERIC(true, "Primitive not implemented"); + SKIP_IF_GENERIC(!cuda_generic_check_format_tags(p.diff_src_format, + p.weights_format, p.diff_dst_format), + "Unsupported format tag"); catch_expected_failures( [&]() { Test(); }, p.expect_to_fail, p.expected_status); } - bool cuda_check_format_tags(memory::format_tag diff_src_format, + bool cuda_generic_check_format_tags(memory::format_tag diff_src_format, memory::format_tag wei_format, memory::format_tag diff_dst_format) { bool diff_src_ok = diff_src_format == memory::format_tag::ncdhw || diff_src_format == memory::format_tag::ndhwc @@ -133,6 +135,20 @@ class inner_product_test_bwd_data_t return diff_src_ok && wei_ok && diff_dst_ok; } + std::vector get_dim_order(const memory::dims &strides) { + size_t ndims = strides.size(); + std::vector order(ndims); + for (size_t i = 0; i < ndims; ++i) { + order[i] = i; + } + + std::sort(order.begin(), order.end(), [&strides](size_t i, size_t j) { + return strides[i] < strides[j]; + }); + + return order; + } + void Test() { auto p = ::testing::TestWithParam::GetParam(); test_inner_product_descr_t ipd = p.test_ipd; @@ -171,6 +187,10 @@ class inner_product_test_bwd_data_t auto ip_diff_dst_desc = 
create_md({ipd.mb, ipd.oc}, data_type, p.diff_dst_format); + SKIP_IF_GENERIC(get_dim_order(ip_diff_src_desc.get_strides()) + != get_dim_order(ip_weights_desc.get_strides()), + "Unsupported case for generic"); + // Create inner product forward (hint for backward) auto ip_fwd_pdesc = hint_pd_t(eng, prop_kind::forward, ip_diff_src_desc, ip_weights_desc, ip_diff_dst_desc); diff --git a/tests/gtests/test_inner_product_backward_weights.cpp b/tests/gtests/test_inner_product_backward_weights.cpp index c5c45c57e98..fd8b974ca38 100644 --- a/tests/gtests/test_inner_product_backward_weights.cpp +++ b/tests/gtests/test_inner_product_backward_weights.cpp @@ -124,17 +124,20 @@ class inner_product_test_bwd_weights_t protected: void SetUp() override { auto p = ::testing::TestWithParam::GetParam(); - SKIP_IF_CUDA( - !cuda_check_format_tags(p.src_format, p.diff_weights_format, - p.diff_bias_format, p.diff_dst_format), + SKIP_IF_CUDA(!cuda_generic_check_format_tags(p.src_format, + p.diff_weights_format, p.diff_bias_format, + p.diff_dst_format), "Unsupported format tag"); SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions"); - SKIP_IF_GENERIC(true, "Primitive not implemented"); + SKIP_IF_GENERIC(!cuda_generic_check_format_tags(p.src_format, + p.diff_weights_format, p.diff_bias_format, + p.diff_dst_format), + "Unsupported format tag"); catch_expected_failures( [&]() { Test(); }, p.expect_to_fail, p.expected_status); } - bool cuda_check_format_tags(memory::format_tag src_format, + bool cuda_generic_check_format_tags(memory::format_tag src_format, memory::format_tag diff_wei_format, memory::format_tag diff_bia_format, memory::format_tag diff_dst_format) { @@ -168,6 +171,20 @@ class inner_product_test_bwd_weights_t return src_ok && diff_wei_ok && diff_bia_ok && diff_dst_ok; } + std::vector get_dim_order(const memory::dims &strides) { + size_t ndims = strides.size(); + std::vector order(ndims); + for (size_t i = 0; i < ndims; ++i) { + order[i] = i; + } + + std::sort(order.begin(), order.end(), [&strides](size_t i, size_t j) { + return strides[i] < strides[j]; + }); + + return order; + } + void Test() { auto p = ::testing::TestWithParam::GetParam(); test_inner_product_descr_t ipd = p.test_ipd; @@ -211,6 +228,10 @@ class inner_product_test_bwd_weights_t ? create_md({ipd.oc}, data_type, p.diff_bias_format) : create_md({}, data_type, p.diff_bias_format); + SKIP_IF_GENERIC(get_dim_order(ip_src_desc.get_strides()) + != get_dim_order(ip_diff_weights_desc.get_strides()), + "Unsupported case for generic"); + // Create inner product forward (hint for backward) auto ip_fwd_pdesc = inner_product_forward::primitive_desc(eng, prop_kind::forward,