Skip to content

Commit

Permalink
Code generation bug fix for ONNX import (#2708)
Browse files Browse the repository at this point in the history
  • Loading branch information
antimora authored Jan 16, 2025
1 parent 05925f1 commit 6750fd6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions crates/burn-import/src/burn/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use burn::nn::PaddingConfig1d;
use burn::nn::PaddingConfig2d;
use burn::nn::PaddingConfig3d;

fn convert_primitive<T: ToString>(primitive: T) -> TokenStream {
let value = primitive.to_string();
fn convert_primitive<T: core::fmt::Debug>(primitive: T) -> TokenStream {
let value = format!("{:?}", primitive);

value.parse().unwrap()
}

Expand Down
4 changes: 2 additions & 2 deletions crates/burn-import/src/burn/node/resize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ mod tests {
TensorType::new_float("tensor1", 3),
TensorType::new_float("tensor2", 3),
"cubic".to_string(),
vec![],
vec![2.0],
vec![20],
));

Expand All @@ -253,7 +253,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
let resize = Interpolate1dConfig::new()
.with_output_size(Some(20))
.with_scale_factor(None)
.with_scale_factor(Some(2.0))
.with_mode(InterpolateMode::Cubic)
.init();
Self {
Expand Down

0 comments on commit 6750fd6

Please sign in to comment.