cpu: x64: enable matmul-based IP for forward inference #2341

Status: Open · wants to merge 2 commits into main
19 changes: 18 additions & 1 deletion src/cpu/cpu_inner_product_list.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
- * Copyright 2019-2024 Intel Corporation
+ * Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@
#if DNNL_X64
#include "cpu/x64/gemm_bf16_inner_product.hpp"
#include "cpu/x64/jit_brgemm_inner_product.hpp"
#include "cpu/x64/matmul_inner_product.hpp"
using namespace dnnl::impl::cpu::x64;
#endif

@@ -43,6 +44,7 @@ using namespace dnnl::impl::prop_kind;
#define BRGEMM_FP8_FWD_IP(dtsrc, dtwei, dtdst) \
{ \
{forward, dtsrc, dtwei, dtdst}, { \
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t) \
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx10_1_512_amx_fp16>) \
CPU_INSTANCE(ref_inner_product_fwd_t) nullptr, \
} \
@@ -52,6 +54,7 @@ using namespace dnnl::impl::prop_kind;
const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
static const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_IP_P({
{{forward, f32, f32, f32}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>) // bf32
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2>)
@@ -61,6 +64,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, bf16, bf16, f32}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_bf16>)
CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t<f32>)
@@ -69,6 +73,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, bf16, bf16, bf16}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_bf16>)
CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t<bf16>)
@@ -77,13 +82,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, f16, f16, f32}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx_fp16>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_fp16>)
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
CPU_INSTANCE(ref_inner_product_fwd_t)
nullptr,
}},
{{forward, f16, f16, f16}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx_fp16>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_fp16>)
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
@@ -187,6 +194,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
})},
{{forward, s8, s8, f32}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
@@ -197,6 +205,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, s8, s8, s32}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
@@ -207,6 +216,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, s8, s8, s8}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
@@ -217,6 +227,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, s8, s8, u8}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
@@ -227,6 +238,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, u8, s8, f32}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
@@ -237,6 +249,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, u8, s8, s32}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
@@ -247,6 +260,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, u8, s8, s8}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
@@ -257,6 +271,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, u8, s8, u8}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
@@ -267,6 +282,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, s8, s8, bf16}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
@@ -275,6 +291,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
nullptr,
}},
{{forward, u8, s8, bf16}, {
+CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
6 changes: 4 additions & 2 deletions src/cpu/x64/jit_brgemm_inner_product.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
- * Copyright 2020-2024 Intel Corporation
+ * Copyright 2020-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -67,7 +67,9 @@ struct brgemm_inner_product_fwd_t : public primitive_t {
// better readability
if (!mayiuse(isa)) return status::unimplemented;

-        VDISPATCH_INNER_PRODUCT(is_fwd(), VERBOSE_BAD_PROPKIND);
+        VDISPATCH_INNER_PRODUCT(
+                get_prop_kind() == prop_kind::forward_training,
+                VERBOSE_BAD_PROPKIND);
Contributor commented:
Could you clarify why jit_brgemm_ip gets restricted to forward training?

IIUC, jit_brgemm_ip is after the new matmul_ip in the dispatch list, so there is a chance that if a forward inference case is not handled by matmul_ip (per the documented restrictions), it will also be skipped by jit_brgemm_ip and will go to a lower performance impl.
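
To make the concern concrete, here is a minimal toy sketch of ordered dispatch (not oneDNN's actual `impl_list_item_t` machinery; all names below are illustrative): implementations are tried in list order and the first one whose descriptor creation succeeds wins, so an inference case rejected by both matmul_ip and the now training-only brgemm_ip falls through to the reference implementation.

```cpp
// Toy model of oneDNN-style ordered dispatch; names are illustrative only.
enum class status_t { success, unimplemented };

struct ip_case_t {
    bool is_inference; // forward_inference vs. forward_training
    bool matmul_ip_ok; // whether matmul_ip's own checks accept the case
};

// matmul-based IP: targets forward inference, but may reject some cases.
status_t try_matmul_ip(const ip_case_t &c) {
    return (c.is_inference && c.matmul_ip_ok) ? status_t::success
                                              : status_t::unimplemented;
}

// brgemm IP after this PR: forward_training only, so it no longer acts as
// a backstop for inference cases that matmul_ip declined.
status_t try_brgemm_ip(const ip_case_t &c) {
    return !c.is_inference ? status_t::success : status_t::unimplemented;
}

// Reference IP: always succeeds, lowest performance.
status_t try_ref_ip(const ip_case_t &) { return status_t::success; }

// The first implementation that succeeds wins: list order is the priority.
const char *dispatch(const ip_case_t &c) {
    if (try_matmul_ip(c) == status_t::success) return "matmul_ip";
    if (try_brgemm_ip(c) == status_t::success) return "brgemm_ip";
    return "ref_ip"; // e.g. an inference case with matmul_ip_ok == false
}
```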

Contributor (author) replied:

> so there is a chance that if a forward inference case is not handled by matmul_ip (per the documented restrictions)

If a blocked layout cannot be used for the weights, the matmul-based IP falls back to a plain layout, so every case that brgemm-ip can handle can also be handled by the matmul-based IP.

The goal is to remove brgemm-ip completely so we can't use it as a complement.
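
For reference, a minimal sketch of the equivalence this fallback relies on, written against the public oneDNN C++ API (it constructs a standalone matmul rather than the internal `matmul_inner_product_fwd_t` path; the f32 data type, plain layouts, and placeholder dimensions are assumptions): a 2D inner product is a matmul against the transposed weights, and the transposed plain layout (`format_tag::ba`) always exists, so no case is lost when a blocked weight layout is unavailable.

```cpp
#include <dnnl.hpp>
using namespace dnnl;

// Sketch: dst[n][oc] = sum_ic src[n][ic] * wei[oc][ic], expressed two ways.
// Assumes f32 and plain layouts throughout; dimensions are placeholders.
void ip_vs_matmul(engine &eng, memory::dim N, memory::dim IC, memory::dim OC) {
    const auto dt = memory::data_type::f32;

    memory::desc src_md({N, IC}, dt, memory::format_tag::ab);
    memory::desc dst_md({N, OC}, dt, memory::format_tag::ab);

    // Inner-product view: weights are {OC, IC}, row-major.
    memory::desc ip_wei_md({OC, IC}, dt, memory::format_tag::ab);
    inner_product_forward::primitive_desc ip_pd(eng,
            prop_kind::forward_inference, src_md, ip_wei_md, dst_md);

    // Matmul view: the same buffer seen as its transpose, {IC, OC} with
    // swapped strides (format_tag::ba) -- a plain layout that always exists.
    memory::desc mm_wei_md({IC, OC}, dt, memory::format_tag::ba);
    matmul::primitive_desc mm_pd(eng, src_md, mm_wei_md, dst_md);

    (void)ip_pd;
    (void)mm_pd; // both primitive descriptors describe the same GEMM
}
```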

VDISPATCH_INNER_PRODUCT(
expect_data_types(src_dt, wei_dt, data_type::undef, dst_dt,
data_type::undef),