Skip to content

Commit

Permalink
benchdnn: graph: fix the dt setting for input displacer
Browse files Browse the repository at this point in the history
  • Loading branch information
wzt1997 committed Jan 8, 2025
1 parent d65e846 commit 50714b2
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions tests/benchdnn/graph/input_displacer.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2023-2024 Intel Corporation
* Copyright 2023-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -324,7 +324,6 @@ int partition_data_displacer_t::gen_quantize_filling(
// clone a deserialized op object and modify to specified data type
::graph::deserialized_op op = main_op;
auto driver = opkind2driver(opstr2kind(op.kind_));
bool is_f8_quantization = (dt == "f8_e5m2" || dt == "f8_e4m3");

op.in_lts_[0].data_type_ = dt;
if (op.in_lts_.size() > 1) {
Expand All @@ -341,23 +340,20 @@ int partition_data_displacer_t::gen_quantize_filling(
}
}
}

if (driver == dnnl_driver_t::pool || driver == dnnl_driver_t::binary) {
// pool does not support x8f32 on cpu
// binary does not support x8x8bf16 on gpu
// replace output with x8
op.out_lts_[0].data_type_ = dt;
} else if (op.out_lts_[0].data_type_ != "bf16") {
if (op.in_lts_.size() > 1 && op.in_lts_[1].data_type_ == "s8") {
// Use u8 as output data type for two-input operations to avoid
// data overflow due to the specific driver logic.
op.out_lts_[0].data_type_ = "u8";
} else if (is_f8_quantization) {
op.out_lts_[0].data_type_ = "f8_e5m2";
} else {
// Use f32 as output data type since not all primitives support
// different data types for input and output.
op.out_lts_[0].data_type_ = "f32";
}
} else if (op.in_lts_.size() > 1 && op.in_lts_[1].data_type_ == "s8") {
// Use u8 as output data type for two-input operations to avoid
// data overflow due to the specific driver logic.
op.out_lts_[0].data_type_ = "u8";
} else {
// Use f32 as output data type since not all primitives support
// different data types for input and output.
op.out_lts_[0].data_type_ = "f32";
}

::std::unordered_set<size_t> empty_set;
Expand Down

0 comments on commit 50714b2

Please sign in to comment.