diff --git a/examples/tutorials/matmul/shu_int4_weights_decompression_1126.cpp b/examples/tutorials/matmul/shu_int4_weights_decompression_1126.cpp
index 92822c8b007..90f4c1f51eb 100644
--- a/examples/tutorials/matmul/shu_int4_weights_decompression_1126.cpp
+++ b/examples/tutorials/matmul/shu_int4_weights_decompression_1126.cpp
@@ -45,15 +45,16 @@ void init_vector(std::vector<int32_t> &v) {
         e = u(gen);
 }
 
-// Transpose the INT4 data from the format tag::ba to tag::ab
+// Transpose the INT4 data of the src vector, which packs 8 INT4 values into each INT32;
+// for example, the data babababa is transposed to abababab.
 void transpose_s4(const std::vector<int32_t> &src, std::vector<int32_t> &dst, int32_t K, int32_t N) {
     // Ensure dst is the correct size
     dst.resize(src.size());
 
-    // Transpose the data
+    // Iterate over the src vector to transpose the INT4 data
     for (int32_t k = 0; k < K; ++k) {
         for (int32_t n = 0; n < N; ++n) {
-            // Extract int4 values from src
+            // Extract INT4 values from src
             int32_t src_byte = src[k * N + n];
             for (int32_t i = 0; i < 8; ++i) {
                 int8_t src_int4 = (src_byte >> (i * 4)) & 0x0F;
@@ -62,9 +63,12 @@ void transpose_s4(const std::vector<int32_t> &src, std::vector<int32_t> &dst, int32_t K, int32_t N)
-                int32_t dst_index = (n * K + k) / 8;
-                int32_t dst_offset = (n * K + k) % 8;
+                // The destination must depend on i: element (8 * k + i, n)
+                // goes to its row-major (tag::ab) nibble position
+                int32_t dst_pos = (8 * k + i) * N + n;
+                int32_t dst_index = dst_pos / 8;
+                int32_t dst_offset = dst_pos % 8;
 
-                // Pack int4 values into dst
-                dst[dst_index] &= ~(0x0F << (dst_offset * 4)); // Clear the destination int4
-                dst[dst_index] |= (src_int4 << (dst_offset * 4)); // Set the destination int4
+                // Pack INT4 values into dst
+                dst[dst_index] &= ~(0x0F << (dst_offset * 4)); // Clear the destination INT4
+                dst[dst_index] |= (src_int4 << (dst_offset * 4)); // Set the destination INT4
             }
         }
     }
@@ -178,24 +182,25 @@ void int4_weights_decompression_matmul(engine::kind engine_kind) {
     auto matmul_pd = matmul_pd_create(M, N, K, G, eng);
 
-    // Original weights stored by packing 8 INT4 values as INT32 in a format tag::ba
-    memory::desc B_s32_trans_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ba);
+    // Original weights are stored by packing 8 INT4 values into each INT32 in the format tag::ba.
+    // oneDNN has no notion of format for zero-points; they are always treated as tag::ab.
+    // An example memory::desc for the transposed weights in the format tag::ba:
+    // memory::desc B_s32_trans_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ba);
+    // An example memory::desc for the weights in the format tag::ab:
+    // memory::desc B_s32_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ab);
+    // In this example, we transpose the weights data to match the format tag::ab of the zero-points.
     std::vector<int32_t> B_s32_trans_data(N * K / 8);
     init_vector(B_s32_trans_data);
 
-    // oneDNN doesn't have a notion of format for zero-points and it's always considered as tag::ab
-    // In this example, we transpose the weights format to match the format tag::ab of the zero-points
-    memory::desc B_s32_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ab);
-    memory B_s32_mem(B_s32_md, eng);
-    {
-        stream s(eng);
-        // Use handle from data B_s32_trans
-        memory B_s32_trans_mem(B_s32_trans_md, eng);
-        write_to_dnnl_memory(B_s32_trans_data.data(), B_s32_trans_mem);
-        reorder(B_s32_trans_mem, B_s32_mem).execute(s, B_s32_trans_mem, B_s32_mem);
-        s.wait();
-    }
+    // Transpose the s4 data to match the format tag::ab
+    std::vector<int32_t> B_s32_data(K / 8 * N);
+    transpose_s4(B_s32_trans_data, B_s32_data, K / 8, N);
 
+    // This way of constructing the memory causes a segfault on GPU:
     // Fill B_s4_mem data using handle from B_s32 data filled as INT32
-    memory B_s4_mem(matmul_pd.weights_desc(), eng, B_s32_mem.get_data_handle());
+    // memory B_s4_mem(matmul_pd.weights_desc(), eng, B_s32_mem.get_data_handle());
+
+    // Fill B_s4_mem data using write_to_dnnl_memory() with the B_s32 data packed as INT32
+    memory B_s4_mem(matmul_pd.weights_desc(), eng);
+    write_to_dnnl_memory(B_s32_data.data(), B_s4_mem);
 
     matmul matmul_p(matmul_pd);
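Not part of the patch: a minimal standalone check of the fixed transpose_s4() above, compilable without oneDNN. It assumes the mapping the fix implements: nibble i of source word (k, n) holds element (8 * k + i, n) of the INT4 matrix, and tag::ab stores that element at nibble position (8 * k + i) * N + n. The masking is done in unsigned arithmetic so the shift by up to 28 bits stays well-defined; otherwise the body matches the patch.

#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

// transpose_s4() as in the patch above, with unsigned masks
void transpose_s4(const std::vector<int32_t> &src, std::vector<int32_t> &dst,
        int32_t K, int32_t N) {
    dst.assign(src.size(), 0); // assign() also zero-initializes
    for (int32_t k = 0; k < K; ++k) {
        for (int32_t n = 0; n < N; ++n) {
            int32_t src_byte = src[k * N + n];
            for (int32_t i = 0; i < 8; ++i) {
                uint32_t src_int4
                        = (static_cast<uint32_t>(src_byte) >> (i * 4)) & 0x0Fu;
                int32_t dst_pos = (8 * k + i) * N + n; // tag::ab nibble position
                int32_t dst_index = dst_pos / 8;
                int32_t dst_offset = dst_pos % 8;
                uint32_t word = static_cast<uint32_t>(dst[dst_index]);
                word &= ~(0x0Fu << (dst_offset * 4)); // Clear the destination INT4
                word |= src_int4 << (dst_offset * 4); // Set the destination INT4
                dst[dst_index] = static_cast<int32_t>(word);
            }
        }
    }
}

// Read the INT4 value at nibble position pos of a packed INT32 buffer
static uint32_t get_int4(const std::vector<int32_t> &buf, int32_t pos) {
    return (static_cast<uint32_t>(buf[pos / 8]) >> ((pos % 8) * 4)) & 0x0Fu;
}

int main() {
    const int32_t K = 64 / 8, N = 16; // K counts INT32 words, as in the patch
    std::vector<int32_t> src(K * N);
    std::mt19937 gen(42);
    for (auto &w : src)
        w = static_cast<int32_t>(gen());

    std::vector<int32_t> dst;
    transpose_s4(src, dst, K, N);

    // Nibble i of source word (k, n) must land at nibble position
    // (8 * k + i) * N + n of the transposed buffer
    for (int32_t k = 0; k < K; ++k)
        for (int32_t n = 0; n < N; ++n)
            for (int32_t i = 0; i < 8; ++i)
                assert(get_int4(src, (k * N + n) * 8 + i)
                        == get_int4(dst, (8 * k + i) * N + n));
    return 0;
}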
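Also not part of the patch: a sketch of the two ways to populate a dnnl::memory that the last hunk contrasts, reduced to a plain f32 buffer so it does not depend on the INT4 weights descriptor. The engine kind and sizes are arbitrary; write_to_dnnl_memory() is the helper from the oneDNN examples' example_utils.hpp that the patch itself uses.

#include <vector>

#include "dnnl.hpp"
#include "example_utils.hpp" // provides write_to_dnnl_memory()

using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0); // switch to engine::kind::gpu to reproduce
    memory::desc md({8, 8}, memory::data_type::f32, memory::format_tag::ab);
    std::vector<float> data(8 * 8, 1.0f);

    // Path 1: wrap the user pointer as the memory handle. On a GPU engine the
    // handle must be device-accessible, which a plain host std::vector is not;
    // a plausible cause of the segfault noted in the patch:
    // memory mem_wrapped(md, eng, data.data());

    // Path 2: let oneDNN allocate the buffer, then copy the user data into it.
    // write_to_dnnl_memory() performs the host-to-device transfer when needed.
    memory mem_copied(md, eng);
    write_to_dnnl_memory(data.data(), mem_copied);
    return 0;
}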