diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp
index 4adb730bd9b..56147bcc283 100644
--- a/src/cpu/cpu_inner_product_list.cpp
+++ b/src/cpu/cpu_inner_product_list.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@
 #if DNNL_X64
 #include "cpu/x64/gemm_bf16_inner_product.hpp"
 #include "cpu/x64/jit_brgemm_inner_product.hpp"
+#include "cpu/x64/matmul_inner_product.hpp"
 using namespace dnnl::impl::cpu::x64;
 #endif
 
@@ -43,6 +44,7 @@ using namespace dnnl::impl::prop_kind;
 #define BRGEMM_FP8_FWD_IP(dtsrc, dtwei, dtdst) \
     { \
         {forward, dtsrc, dtwei, dtdst}, { \
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t) \
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) \
             CPU_INSTANCE(ref_inner_product_fwd_t) nullptr, \
         } \
@@ -52,6 +54,7 @@ using namespace dnnl::impl::prop_kind;
 const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
     static const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_IP_P({
         {{forward, f32, f32, f32}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) // bf32
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t)
@@ -61,6 +64,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, bf16, bf16, f32}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t)
@@ -69,6 +73,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, bf16, bf16, bf16}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t)
@@ -77,6 +82,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, f16, f16, f32}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t)
@@ -84,6 +90,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, f16, f16, f16}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t)
@@ -187,6 +194,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         })},
         {{forward, s8, s8, f32}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
@@ -197,6 +205,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, s8, s8, s32}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
@@ -207,6 +216,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, s8, s8, s8}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
@@ -217,6 +227,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, s8, s8, u8}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
@@ -227,6 +238,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, u8, s8, f32}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
@@ -237,6 +249,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, u8, s8, s32}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
@@ -247,6 +260,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, u8, s8, s8}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
@@ -257,6 +271,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, u8, s8, u8}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
@@ -267,6 +282,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, s8, s8, bf16}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
@@ -275,6 +291,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             nullptr,
         }},
         {{forward, u8, s8, bf16}, {
+            CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
             CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
             CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t)
diff --git a/src/cpu/x64/jit_brgemm_inner_product.hpp b/src/cpu/x64/jit_brgemm_inner_product.hpp
index dca15bb151c..8caa350224d 100644
--- a/src/cpu/x64/jit_brgemm_inner_product.hpp
+++ b/src/cpu/x64/jit_brgemm_inner_product.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -67,7 +67,9 @@ struct brgemm_inner_product_fwd_t : public primitive_t {
             // better readability
             if (!mayiuse(isa)) return status::unimplemented;
 
-            VDISPATCH_INNER_PRODUCT(is_fwd(), VERBOSE_BAD_PROPKIND);
+            VDISPATCH_INNER_PRODUCT(
+                    get_prop_kind() == prop_kind::forward_training,
+                    VERBOSE_BAD_PROPKIND);
             VDISPATCH_INNER_PRODUCT(
                     expect_data_types(src_dt, wei_dt, data_type::undef, dst_dt,
                             data_type::undef),
diff --git a/src/cpu/x64/matmul_inner_product.cpp b/src/cpu/x64/matmul_inner_product.cpp
new file mode 100644
index 00000000000..9a4aa756066
--- /dev/null
+++ b/src/cpu/x64/matmul_inner_product.cpp
@@ -0,0 +1,276 @@
+/*******************************************************************************
+* Copyright 2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <map>
+
+#include "cpu/x64/matmul_inner_product.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace x64 {
+
+status_t create_matmul_pd(std::shared_ptr<primitive_desc_t> &matmul_pd,
+        engine_t *engine, const memory_desc_t *src_md,
+        const memory_desc_t *wei_md, const memory_desc_t *dst_md,
+        const memory_desc_t *bia_md, const primitive_attr_t *attr) {
+    auto matmul_desc = matmul_desc_t();
+
+    CHECK(matmul_desc_init(&matmul_desc, src_md, wei_md, bia_md, dst_md));
+
+    primitive_desc_iterator_t it(
+            engine, (op_desc_t *)&matmul_desc, attr, nullptr);
+
+    matmul_pd = *(++it);
+    if (!matmul_pd) return status::unimplemented;
+
+    return status::success;
+}
+
+status_t init_matmul_md(memory_desc_t &mm_md, const memory_desc_t &ip_md,
+        format_tag_t tag, bool swap_dims) {
+    auto p_dims = ip_md.dims;
+    auto p_dim1 = utils::array_product(p_dims + 1, ip_md.ndims - 1);
+
+    if (swap_dims) {
+        dims_t dims_2d = {p_dim1, p_dims[0]};
+        return memory_desc_init_by_tag(mm_md, 2, dims_2d, ip_md.data_type, tag);
+    } else {
+        dims_t dims_2d = {p_dims[0], p_dim1};
+        return memory_desc_init_by_tag(mm_md, 2, dims_2d, ip_md.data_type, tag);
+    }
+}
+
+int matmul_inner_product_fwd_t::pd_t::get_k_blk(format_tag_t tag) const {
+    using namespace format_tag;
+    switch (tag) {
+        case ba: return 0;
+        case BA8a8b:
+        case BA8a24b: return 8;
+        case BA16a16b:
+        case BA16a32b:
+        case BA16a48b:
+        case BA16a64b: return 16;
+        case BA16a16b2a:
+        case BA16a32b2a:
+        case BA16a48b2a:
+        case BA16a64b2a: return 32;
+        case BA16a16b4a:
+        case BA16a32b4a:
+        case BA16a48b4a:
+        case BA16a64b4a: return 64;
+        default: assert(!"unsupported tag"); return -1;
+    }
+}
+
+// This implementation is completely based on the MatMul primitive and is
+// currently enabled only for the `forward_inference` propagation kind.
+//
+// The implementation allows using blocked weights layouts directly or via
+// the special tag `any`.
+// The Inner Product weights must meet **ONE** of the following requirements to
+// enable using the blocked layouts:
+// - Weights don't have spatial.
+// - Weights have unit spatial.
+// - Weights have non-unit spatial but the number of input channels is a
+//   multiple of the K block (returned by `get_k_blk()`).
+//
+// If none of the above requirements are met then a plain layout will be
+// used.
+//
+// Note: this implementation is only guaranteed to work with a set of
+// pre-defined layouts, therefore there is no need to implement a generic
+// mechanism to map Inner Product weights layouts to the MatMul ones and
+// vice versa.
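+//
+// For example, 4D Inner Product weights in the AcdB16b16a layout are consumed
+// by the MatMul primitive as 2D (transposed) weights in BA16a16b; see the
+// mapping table in `init_matmul_params()` below.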
+status_t matmul_inner_product_fwd_t::pd_t::init_matmul_params(
+        engine_t *engine) {
+    using namespace format_tag;
+
+    // clang-format off
+    static const std::map<format_tag_t, std::vector<format_tag_t>>
+            mm_wei_to_ip_wei = {
+            {ba,         {ab,         acb,         acdb,         acdeb}},
+            {BA8a8b,     {AB8b8a,     AcB8b8a,     AcdB8b8a,     AcdeB8b8a}},
+            {BA8a24b,    {AB8b24a,    AcB8b24a,    AcdB8b24a,    AcdeB8b24a}},
+            {BA16a16b,   {AB16b16a,   AcB16b16a,   AcdB16b16a,   AcdeB16b16a}},
+            {BA16a32b,   {AB16b32a,   AcB16b32a,   AcdB16b32a,   AcdeB16b32a}},
+            {BA16a48b,   {AB16b48a,   AcB16b48a,   AcdB16b48a,   AcdeB16b48a}},
+            {BA16a64b,   {AB16b64a,   AcB16b64a,   AcdB16b64a,   AcdeB16b64a}},
+            {BA16a16b2a, {AB16b16a2b, AcB16b16a2b, AcdB16b16a2b, AcdeB16b16a2b}},
+            {BA16a32b2a, {AB16b32a2b, AcB16b32a2b, AcdB16b32a2b, AcdeB16b32a2b}},
+            {BA16a48b2a, {AB16b48a2b, AcB16b48a2b, AcdB16b48a2b, AcdeB16b48a2b}},
+            {BA16a64b2a, {AB16b64a2b, AcB16b64a2b, AcdB16b64a2b, AcdeB16b64a2b}},
+            {BA16a16b4a, {AB16b16a4b, AcB16b16a4b, AcdB16b16a4b, AcdeB16b16a4b}},
+            {BA16a32b4a, {AB16b32a4b, AcB16b32a4b, AcdB16b32a4b, AcdeB16b32a4b}},
+            {BA16a48b4a, {AB16b48a4b, AcB16b48a4b, AcdB16b48a4b, AcdeB16b48a4b}},
+            {BA16a64b4a, {AB16b64a4b, AcB16b64a4b, AcdB16b64a4b, AcdeB16b64a4b}}};
+    // clang-format on
+
+    auto mm_wei_tag = format_tag::undef;
+    // Try to initialize the Inner Product weights layout based on the
+    // user-provided layout.
+    if (weights_md()->format_kind != format_kind::any) {
+        for (const auto &v : mm_wei_to_ip_wei) {
+            if (memory_desc_matches_tag(
+                        *weights_md(), v.second[weights_md()->ndims - 2])) {
+                mm_wei_tag = v.first;
+                // Check if the user-provided blocked layout can be handled.
+                const bool has_spatial = KD() + KH() + KW() > 3;
+                const int k_blk = get_k_blk(mm_wei_tag);
+                const bool is_wtag_supported = !(weights_md()->ndims > 2
+                        && has_spatial && k_blk > 0 && IC() % k_blk != 0);
+                VDISPATCH_INNER_PRODUCT(is_wtag_supported,
+                        VERBOSE_UNSUPPORTED_TAG_S, "weights");
+                break;
+            }
+        }
+    } else {
+        mm_wei_tag = format_tag::any;
+    }
+
+    VDISPATCH_INNER_PRODUCT(mm_wei_tag != format_tag::undef,
+            VERBOSE_UNSUPPORTED_TAG_S, "weights");
+
+    memory_desc_t mm_src_md {};
+    memory_desc_t mm_wei_md {};
+    memory_desc_t mm_dst_md {};
+
+    if (bias_md_.format_kind == format_kind::any)
+        CHECK(memory_desc_init_by_tag(bias_md_, x));
+
+    CHECK(init_matmul_md(mm_src_md, *src_md(), format_tag::ab));
+    CHECK(init_matmul_md(mm_wei_md, *weights_md(), mm_wei_tag, true));
+    CHECK(init_matmul_md(mm_dst_md, *dst_md(), format_tag::ab));
+
+    const auto src_tag = utils::pick(src_md()->ndims - 2, ab, acb, acdb, acdeb);
+    if (src_md()->format_kind == format_kind::any)
+        CHECK(memory_desc_init_by_tag(src_md_, src_md_.ndims, src_md_.dims,
+                src_md_.data_type, src_tag));
+    else
+        VDISPATCH_INNER_PRODUCT(memory_desc_matches_tag(*src_md(), src_tag),
+                VERBOSE_UNSUPPORTED_TAG_S, "src");
+
+    const auto dst_tag = ab;
+    if (dst_md()->format_kind == format_kind::any)
+        CHECK(memory_desc_init_by_tag(dst_md_, dst_md_.ndims, dst_md_.dims,
+                dst_md_.data_type, dst_tag));
+    else
+        VDISPATCH_INNER_PRODUCT(memory_desc_matches_tag(*dst_md(), dst_tag),
+                VERBOSE_UNSUPPORTED_TAG_S, "dst");
+
+    VDISPATCH_INNER_PRODUCT_SC(
+            attr_.set_default_formats(dst_md(0)), VERBOSE_UNSUPPORTED_POSTOP);
+
+    primitive_attr_t matmul_attr = *attr();
+    const auto wei_mask = matmul_attr.scales_.get(DNNL_ARG_WEIGHTS).mask_;
+    if (wei_mask == 1)
+        VDISPATCH_INNER_PRODUCT_SC(matmul_attr.scales_.set(DNNL_ARG_WEIGHTS,
+                                           1 << (mm_wei_md.ndims - 1)),
+                VERBOSE_UNSUPPORTED_ATTR);
+    else if (wei_mask != 0)
+        VDISPATCH_INNER_PRODUCT(false, VERBOSE_UNSUPPORTED_SCALES_CFG);
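+    // Per-OC weights scales use mask == 1 on the OC-first Inner Product
+    // weights but apply to the last (N) dimension of the transposed 2D MatMul
+    // weights, hence the mask conversion above.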
+
+    memory_desc_t mm_bia_md {};
+    // Inner Product bias is always a vector while MatMul requires bias to have
+    // the same number of dimensions as the output tensor, therefore an
+    // adjustment is required.
+    if (with_bias()) {
+        assert(weights_md(1)->ndims == 1);
+        dims_t mm_bia_dims = {1, weights_md(1)->dims[0]};
+        CHECK(memory_desc_init_by_tag(mm_bia_md, 2, mm_bia_dims,
+                weights_md(1)->data_type, format_tag::ab));
+    }
+
+    CHECK(create_matmul_pd(matmul_pd_, engine, &mm_src_md, &mm_wei_md,
+            &mm_dst_md, with_bias() ? &mm_bia_md : nullptr, &matmul_attr));
+
+    // Falling back to the generic GEMM-based Inner Product is preferred over
+    // using reference or GEMM-based MatMul implementations here.
+    const bool is_desired_mm_impl
+            = std::string(matmul_pd_->name()).find("brg_matmul")
+            != std::string::npos;
+    VDISPATCH_INNER_PRODUCT(is_desired_mm_impl, VERBOSE_PRIMITIVE_CREATION_FAIL,
+            "matmul:brg_matmul");
+
+    // Try to initialize the Inner Product weights layout based on the one
+    // chosen by MatMul.
+    if (weights_md()->format_kind == format_kind::any) {
+        // If the table doesn't have the required layout then a fallback
+        // is needed.
+        bool is_fallback_required = true;
+        format_tag_t ip_wei_tag = format_tag::undef;
+        const auto &mm_queried_wei_md = *matmul_pd_->weights_md();
+        for (const auto &v : mm_wei_to_ip_wei) {
+            if (memory_desc_matches_tag(mm_queried_wei_md, v.first)) {
+                // Check if the implementation-defined blocked layout can be
+                // handled.
+                const bool has_spatial = KD() + KH() + KW() > 3;
+                const int k_blk = get_k_blk(v.first);
+                is_fallback_required = weights_md()->ndims > 2 && has_spatial
+                        && k_blk > 0 && IC() % k_blk != 0;
+
+                if (!is_fallback_required)
+                    ip_wei_tag = v.second[weights_md()->ndims - 2];
+                break;
+            }
+        }
+        if (is_fallback_required) {
+            // Re-initialize the MatMul weights memory descriptor with a plain
+            // layout.
+            CHECK(init_matmul_md(
+                    mm_wei_md, *weights_md(), format_tag::ba, true));
+            // Re-create the MatMul primitive descriptor.
+            CHECK(create_matmul_pd(matmul_pd_, engine, &mm_src_md, &mm_wei_md,
+                    &mm_dst_md, with_bias() ? &mm_bia_md : nullptr,
+                    &matmul_attr));
+            ip_wei_tag = utils::pick(
+                    weights_md()->ndims - 2, ab, acb, acdb, acdeb);
+        }
+        CHECK(memory_desc_init_by_tag(weights_md_, weights_md_.ndims,
+                weights_md_.dims, weights_md_.data_type, ip_wei_tag));
+        // Carry over the extra info from the MatMul weights memory descriptor.
+        if (!is_fallback_required && mm_queried_wei_md.extra.flags != 0) {
+            weights_md_.extra = mm_queried_wei_md.extra;
+            // Since IP weights are transposed we need to swap the mask bits
+            // (mask: 2 -> 1).
+            weights_md_.extra.compensation_mask = 1;
+        }
+    } else {
+        // At this point it's guaranteed that the table contains the requested
+        // layout and that it can be handled.
+        const auto &ip_wei_tags = mm_wei_to_ip_wei.at(mm_wei_tag);
+        const auto ip_wei_tag = ip_wei_tags[weights_md()->ndims - 2];
+        CHECK(memory_desc_init_by_tag(weights_md_, weights_md_.ndims,
+                weights_md_.dims, weights_md_.data_type, ip_wei_tag));
+    }
+
+    return status::success;
+}
+
+status_t matmul_inner_product_fwd_t::execute(const exec_ctx_t &ctx) const {
+    using namespace memory_tracking::names;
+
+    exec_args_t matmul_args = ctx.args();
+    exec_ctx_t matmul_ctx(ctx, std::move(matmul_args));
+
+    nested_scratchpad_t ns(ctx, key_nested, matmul_);
+    matmul_ctx.set_scratchpad_grantor(ns.grantor());
+
+    return matmul_->execute(matmul_ctx);
+}
+
+} // namespace x64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/x64/matmul_inner_product.hpp b/src/cpu/x64/matmul_inner_product.hpp
new file mode 100644
index 00000000000..03aa84a3fad
--- /dev/null
+++ b/src/cpu/x64/matmul_inner_product.hpp
@@ -0,0 +1,109 @@
+/*******************************************************************************
+* Copyright 2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_X64_MATMUL_INNER_PRODUCT_HPP
+#define CPU_X64_MATMUL_INNER_PRODUCT_HPP
+
+#include <memory>
+
+#include "common/c_types_map.hpp"
+#include "common/matmul_pd.hpp"
+#include "common/primitive.hpp"
+#include "common/primitive_desc_iterator.hpp"
+
+#include "cpu/cpu_inner_product_pd.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace x64 {
+
+status_t create_matmul_pd(std::shared_ptr<primitive_desc_t> &matmul_pd,
+        engine_t *engine, const memory_desc_t *a_md, const memory_desc_t *b_md,
+        const memory_desc_t *c_md, const memory_desc_t *ip_bia_md,
+        const primitive_attr_t *attr);
+
+status_t init_matmul_md(memory_desc_t &mm_md, const memory_desc_t &ip_md,
+        format_tag_t tag, bool swap_dims = false);
+
+struct matmul_inner_product_fwd_t : public primitive_t {
+    using primitive_t::primitive_t;
+    struct pd_t : public cpu_inner_product_fwd_pd_t {
+        using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
+
+        DECLARE_COMMON_PD_T((matmul_pd_ ? matmul_pd_->name() : "matmul"),
+                matmul_inner_product_fwd_t);
+
+        status_t init(impl::engine_t *engine) {
+            using namespace data_type;
+            using skip_mask_t = primitive_attr_t::skip_mask_t;
+
+            const auto src_dt = invariant_src_md()->data_type;
+            const auto wei_dt = invariant_wei_md()->data_type;
+            const auto dst_dt = invariant_dst_md()->data_type;
+            const bool is_int8 = utils::one_of(src_dt, u8, s8) && wei_dt == s8
+                    && utils::one_of(dst_dt, u8, s8, s32, f32, bf16);
+
+            auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt
+                    | skip_mask_t::fpmath_mode;
+            if (is_int8) skip_mask |= skip_mask_t::scales_runtime;
+
+            // This implementation is currently enabled only for inference.
+            VDISPATCH_INNER_PRODUCT(
+                    get_prop_kind() == prop_kind::forward_inference,
+                    VERBOSE_BAD_PROPKIND);
+            VDISPATCH_INNER_PRODUCT(
+                    !has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, "");
+            VDISPATCH_INNER_PRODUCT(attr()->has_default_values(skip_mask),
+                    VERBOSE_UNSUPPORTED_ATTR);
+            VDISPATCH_INNER_PRODUCT_SC(
+                    init_matmul_params(engine), "init_matmul_params");
+            init_scratchpad();
+
+            return status::success;
+        }
+
+        std::shared_ptr<primitive_desc_t> matmul_pd_;
+
+    private:
+        int get_k_blk(format_tag_t tag) const;
+        status_t init_matmul_params(engine_t *engine);
+
+        void init_scratchpad() {
+            auto scratchpad = scratchpad_registry().registrar();
+            scratchpad.book(memory_tracking::names::key_nested,
+                    matmul_pd_->scratchpad_registry());
+        }
+    };
+
+    status_t init(impl::engine_t *engine) override {
+        CHECK(pd()->matmul_pd_->create_primitive(matmul_, engine));
+        return status::success;
+    }
+
+    status_t execute(const exec_ctx_t &ctx) const override;
+
+private:
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+    std::shared_ptr<impl::primitive_t> matmul_;
+};
+
+} // namespace x64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif