Skip to content

Commit

Permalink
update the example
Browse files Browse the repository at this point in the history
  • Loading branch information
shu1chen committed Nov 26, 2024
1 parent 080c5ac commit edeeccd
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions examples/tutorials/matmul/shu_int4_weights_decompression_1126.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,16 @@ void init_vector(std::vector<int32_t> &v) {
e = u(gen);
}

// Transpose the INT4 data from the format tag::ba to tag::ab
// Transpose the INT4 data for the src vector which packs 8 INT4 values as INT32,
// for example, the data babababa is transposed to abababab.
void transpose_s4(const std::vector<int32_t> &src, std::vector<int32_t> &dst, int32_t K, int32_t N) {
// Ensure dst is the correct size
dst.resize(src.size());

// Transpose the data
// Iterate over the src vector to transpose the INT4 data
for (int32_t k = 0; k < K; ++k) {
for (int32_t n = 0; n < N; ++n) {
// Extract int4 values from src
// Extract INT4 values from src
int32_t src_byte = src[k * N + n];
for (int32_t i = 0; i < 8; ++i) {
int8_t src_int4 = (src_byte >> (i * 4)) & 0x0F;
Expand All @@ -62,9 +63,9 @@ void transpose_s4(const std::vector<int32_t> &src, std::vector<int32_t> &dst, in
int32_t dst_index = (n * K + k) / 8;
int32_t dst_offset = (n * K + k) % 8;

// Pack int4 values into dst
dst[dst_index] &= ~(0x0F << (dst_offset * 4)); // Clear the destination int4
dst[dst_index] |= (src_int4 << (dst_offset * 4)); // Set the destination int4
// Pack INT4 values into dst
dst[dst_index] &= ~(0x0F << (dst_offset * 4)); // Clear the destination INT4
dst[dst_index] |= (src_int4 << (dst_offset * 4)); // Set the destination INT4
}
}
}
Expand Down Expand Up @@ -178,26 +179,27 @@ void int4_weights_decompression_matmul(engine::kind engine_kind) {

auto matmul_pd = matmul_pd_create(M, N, K, G, eng);

// Original weights stored by packing 8 INT4 values as INT32 in a format tag::ba
memory::desc B_s32_trans_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ba);
// Original weights stored by packing 8 INT4 values as INT32 in a format tag::ba.
// oneDNN doesn't have a notion of format for zero-points and it's always considered as tag::ab.
// The example of memory::desc for transposed weights with format tag::ba is below.
// memory::desc B_s32_trans_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ba);
// The example of memory::desc for weights with format tag::ab is below.
// memory::desc B_s32_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ab);
// In this example, we transpose the weights data to match the format tag::ab of the zero-points.
std::vector<int32_t> B_s32_trans_data(N * K / 8);
init_vector(B_s32_trans_data);

// oneDNN doesn't have a notion of format for zero-points and it's always considered as tag::ab
// In this example, we transpose the weights format to match the format tag::ab of the zero-points
memory::desc B_s32_md({K / 8, N}, memory::data_type::s32, memory::format_tag::ab);
memory B_s32_mem(B_s32_md, eng);
{
stream s(eng);
// Use handle from data B_s32_trans
memory B_s32_trans_mem(B_s32_trans_md, eng);
write_to_dnnl_memory(B_s32_trans_data.data(), B_s32_trans_mem);
reorder(B_s32_trans_mem, B_s32_mem).execute(s, B_s32_trans_mem, B_s32_mem);
s.wait();
}
// Transpose the s4 data to match the format tag::ab
std::vector<int32_t> B_s32_data(K / 8 * N);
transpose_s4(B_s32_trans_data, B_s32_data, K / 8, N);

// This way of constrcuting memory causes segfault on GPU:
// Fill B_s4_mem data using handle from B_s32 data filled as INT32
memory B_s4_mem(matmul_pd.weights_desc(), eng, B_s32_mem.get_data_handle());
// memory B_s4_mem(matmul_pd.weights_desc(), eng, B_s32_mem.get_data_handle());

// Fill B_s4_mem data using the write_to_dnnl_memory function from B_s32 data filled as INT32
memory B_s4_mem(matmul_pd.weights_desc(), eng);
write_to_dnnl_memory(B_s32_data.data(), B_s4_mem);

matmul matmul_p(matmul_pd);

Expand Down

0 comments on commit edeeccd

Please sign in to comment.