update the example
shu1chen committed Nov 27, 2024
1 parent 3655bfc commit 0f1ab00
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions examples/tutorials/matmul/shu_int4_weights_decompression_1126.cpp
@@ -100,7 +100,8 @@ matmul::primitive_desc matmul_pd_create(
attr.set_scales(DNNL_ARG_WEIGHTS,
/* mask */ (1 << 0) + (1 << 1), {G_SC, 1}, memory::data_type::f16);

- // Set zero points with s4 data type both along K and N dimensions
+ // Set zero points with s4 data type, both along the K and N dimensions and with groups along K.
+ // oneDNN APIs treat zero-points as INT4 elements; they are NOT packed into INT32 values.
attr.set_zero_points(
DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {G_ZP, 1}, memory::data_type::s4);
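For context on the mask and group arguments above: mask (1 << 0) + (1 << 1) selects both the K and N dimensions of the K x N weights, and group dims {G, 1} mean one quantization parameter per G x 1 block, i.e. (K/G) * N values in total. A minimal sketch under the example's own sizes (num_group_params is an illustrative helper, not a oneDNN API):

    #include <cstdint>
    #include <cstdio>

    // One quantization parameter per G x 1 block of a K x N weights tensor,
    // as implied by mask = (1 << 0) + (1 << 1) with group dims {G, 1}.
    int64_t num_group_params(int64_t K, int64_t N, int64_t G) {
        return (K / G) * N;
    }

    int main() {
        const int64_t K = 96, N = 1000; // sizes used later in this example
        std::printf("scales: %lld\n", (long long)num_group_params(K, N, K / 2)); // 2000
        std::printf("zero points: %lld\n", (long long)num_group_params(K, N, K / 4)); // 4000
        return 0;
    }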

@@ -117,8 +118,11 @@ void prepare_input(memory &A_f16_mem, memory &sc_B_mem, memory &zp_B_mem) {
int64_t N = sc_B_mem.get_desc().get_dims()[0];
int64_t K = A_f16_mem.get_desc().get_dims()[1];
int64_t NUM_G_SC = sc_B_mem.get_desc().get_dims()[1];
- int64_t NUM_G_ZP_s4 = zp_B_mem.get_desc().get_dims()[1];
-
+ // In the host staging buffer, 8 INT4 values are packed into one INT32 in the N direction.
+ // oneDNN APIs count zero-points as INT4 elements; they are NOT packed into INT32 values.
+ int64_t NUM_G_ZP_N = zp_B_mem.get_desc().get_dims()[0] / 8;
+ int64_t NUM_G_ZP_K = zp_B_mem.get_desc().get_dims()[1];
+
std::vector<float> A_f32(M * K);
init_vector(A_f32);
// Fill A_f16_mem f16 data from A_f32 data filled as f32
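As a concrete illustration of the "8 INT4 per INT32" packing the new comments describe, a hypothetical host-side helper (not part of the commit) could look like:

    #include <cstdint>

    // Pack 8 signed 4-bit values, each in [-8, 7], into one int32 with the
    // first value in the lowest nibble. This models the host staging-buffer
    // layout only; oneDNN's s4 memory counts individual INT4 elements.
    int32_t pack8_s4(const int8_t (&v)[8]) {
        uint32_t out = 0;
        for (int i = 0; i < 8; ++i)
            out |= (static_cast<uint32_t>(v[i]) & 0xFu) << (4 * i);
        return static_cast<int32_t>(out);
    }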
@@ -129,13 +133,11 @@ void prepare_input(memory &A_f16_mem, memory &sc_B_mem, memory &zp_B_mem) {
// Fill sc_B_mem f16 data from sc_B data filled as f32
write_to_dnnl_memory(sc_B.data(), sc_B_mem);

- // 8 INT4 values are packed as INT32
- int64_t NUM_G_ZP_s32 = NUM_G_ZP_s4 / 8 + 1;
- std::vector<int32_t> zp_transpose_B(NUM_G_ZP_s32 * 1);
+ std::vector<int32_t> zp_transpose_B(NUM_G_ZP_K * NUM_G_ZP_N);
init_vector(zp_transpose_B);
// Transpose the s4 data to match the format tag::ab
- std::vector<int32_t> zp_B(NUM_G_ZP_s32 * 1);
- transpose_s4(zp_transpose_B, zp_B, NUM_G_ZP_s32 , 1);
+ std::vector<int32_t> zp_B(NUM_G_ZP_K * NUM_G_ZP_N);
+ transpose_s4(zp_transpose_B, zp_B, NUM_G_ZP_K, NUM_G_ZP_N);
// Fill zp_B_mem s4 data from zp_B data filled as s32
write_to_dnnl_memory(zp_B.data(), zp_B_mem);
}
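The transpose_s4 helper called above is defined outside the shown hunks. As a rough sketch of the operation's intent (transposing s4 data packed 8-per-int32), one nibble-level implementation could be the following; note the real helper may instead work directly on packed int32 words, so treat the signature and semantics here as assumptions:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Read / write the idx-th s4 element of a buffer packed 8 nibbles per int32.
    static int8_t get_s4(const std::vector<int32_t> &buf, size_t idx) {
        uint32_t nib = (static_cast<uint32_t>(buf[idx / 8]) >> (4 * (idx % 8))) & 0xFu;
        return static_cast<int8_t>(nib >= 8 ? int(nib) - 16 : int(nib)); // sign-extend
    }

    static void put_s4(std::vector<int32_t> &buf, size_t idx, int8_t val) {
        uint32_t word = static_cast<uint32_t>(buf[idx / 8]);
        uint32_t shift = 4 * (idx % 8);
        word &= ~(0xFu << shift); // clear the target nibble
        word |= (static_cast<uint32_t>(val) & 0xFu) << shift;
        buf[idx / 8] = static_cast<int32_t>(word);
    }

    // Transpose an R x C matrix of s4 elements (row-major, packed) into C x R.
    // dst must hold R * C nibbles, i.e. R * C / 8 int32 words.
    void transpose_s4_sketch(const std::vector<int32_t> &src,
            std::vector<int32_t> &dst, int64_t R, int64_t C) {
        for (int64_t r = 0; r < R; ++r)
            for (int64_t c = 0; c < C; ++c)
                put_s4(dst, size_t(c) * R + r, get_s4(src, size_t(r) * C + c));
    }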
@@ -144,12 +146,10 @@ void infer(const matmul &matmul_p, int64_t M, int64_t N, int64_t K, int64_t G_SC
int64_t G_ZP, const memory &B_s4_mem, const engine &eng) {
// input of the current layer / operation
memory A_f16_mem({{M, K}, memory::data_type::f16, {K, 1}}, eng);
- // De-quantization parameters (eg. Scale and Shift)
- const int64_t n_sc_groups = K / G_SC;
- memory sc_B_mem({{N, n_sc_groups}, memory::data_type::f16, {1, N}}, eng);
- // number of groups for zero points
- const int64_t n_zp_groups = K / G_ZP;
- memory zp_B_mem({{1, n_zp_groups}, memory::data_type::s4, {1, 1}}, eng);
+ // scales are grouped in the K direction: [K/G_SC, N]
+ memory sc_B_mem({{N, K / G_SC}, memory::data_type::f16, {1, N}}, eng);
+ // zero-points are grouped in the K direction and packed into INT32 in the N direction: [K/G_ZP, N/8]
+ memory zp_B_mem({{N, K / G_ZP}, memory::data_type::s4, {1, N}}, eng);

// the function below fills dnnl::memory with some values
// these memories, typically, come from the previous layers / operations
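A note on the descriptors created above: dims {N, K/G} with explicit strides {1, N} lay the buffer out with the K-group index outermost, which is exactly the [K/G, N] shape the comments quote. A minimal equivalence sketch (illustrative helper, not from the commit):

    #include "oneapi/dnnl/dnnl.hpp"
    using namespace dnnl;

    // dims {N, K/G} + format_tag::ba gives strides {1, N}: the same physical
    // layout as the explicit-stride descriptor used for sc_B_mem above.
    memory::desc make_grouped_scale_md(memory::dim N, memory::dim K, memory::dim G) {
        return memory::desc({N, K / G}, memory::data_type::f16,
                memory::format_tag::ba);
    }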
@@ -176,26 +176,26 @@ void int4_weights_decompression_matmul(engine::kind engine_kind) {
const int64_t K = 96;
const int64_t N = 1000;
const int64_t M = 100;
- // Quantization Group size for scales
+ // Quantization group size for scales in the K direction
const int64_t G_SC = K / 2;
- // Quantization Group size for zero points
+ // Quantization group size for zero points in the K direction
const int64_t G_ZP = K / 4;

auto matmul_pd = matmul_pd_create(M, N, K, G_SC, G_ZP, eng);

- // Original weights stored by packing 8 INT4 values as INT32 in a format tag::ba.
+ // Original weights are stored by packing 8 INT4 values in the N direction into one INT32, in format tag::ba.
// oneDNN doesn't have a notion of format for zero-points; they are always treated as tag::ab.
// An example memory::desc for transposed weights with format tag::ba:
- // memory::desc B_s32_trans_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ba);
+ // memory::desc B_s32_trans_md({K, N / 8}, memory::data_type::s32, memory::format_tag::ba);
// An example memory::desc for weights with format tag::ab:
- // memory::desc B_s32_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ab);
+ // memory::desc B_s32_md({K, N / 8}, memory::data_type::s32, memory::format_tag::ab);
// In this example, we transpose the weights data to match the format tag::ab of the zero-points.
- std::vector<int32_t> B_s32_trans_data(N * K / 8);
+ std::vector<int32_t> B_s32_trans_data(K * N / 8);
init_vector(B_s32_trans_data);

// Transpose the s4 data to match the format tag::ab
- std::vector<int32_t> B_s32_data(K / 8 * N);
- transpose_s4(B_s32_trans_data, B_s32_data, K / 8, N);
+ std::vector<int32_t> B_s32_data(K * N / 8);
+ transpose_s4(B_s32_trans_data, B_s32_data, K, N / 8);

// This way of constructing memory causes a segfault on GPU:
// Fill B_s4_mem data using handle from B_s32 data filled as INT32
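The construction pattern the truncated comment warns about is not shown here. One common pattern in the oneDNN examples that avoids wrapping a raw host pointer on a GPU engine is to create engine-owned memory and copy into it with the examples' write_to_dnnl_memory utility (from example_utils.hpp); the memory::desc below is an assumption about how B_s4_mem is created, sketched for illustration:

    // Allocate engine-owned s4 storage for the full K x N weights grid, then
    // copy the packed host data in, instead of passing the host pointer as the
    // memory handle at construction time.
    memory B_s4_mem({{K, N}, memory::data_type::s4, memory::format_tag::ab}, eng);
    write_to_dnnl_memory(B_s32_data.data(), B_s4_mem);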