From 58ce502498adb7accca9f121122d60012c9f0374 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 10 Sep 2024 13:36:00 -0400 Subject: [PATCH] Fix (#2269) --- Cargo.lock | 16 ++++++++-------- Cargo.toml | 4 ++-- crates/burn-autodiff/src/tests/flip.rs | 8 ++++---- crates/burn-autodiff/src/tests/permute.rs | 8 ++++---- .../burn-jit/src/kernel/index/select_assign.rs | 10 +++++----- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 988a4bf4e1..4ed0dccc7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1397,7 +1397,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b#1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" +source = "git+https://github.com/tracel-ai/cubecl?rev=7a86f9a86e376fedb09f096f2b548e501a130883#7a86f9a86e376fedb09f096f2b548e501a130883" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1408,7 +1408,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b#1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" +source = "git+https://github.com/tracel-ai/cubecl?rev=7a86f9a86e376fedb09f096f2b548e501a130883#7a86f9a86e376fedb09f096f2b548e501a130883" dependencies = [ "derive-new", "getrandom", @@ -1423,7 +1423,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b#1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" +source = "git+https://github.com/tracel-ai/cubecl?rev=7a86f9a86e376fedb09f096f2b548e501a130883#7a86f9a86e376fedb09f096f2b548e501a130883" dependencies = [ "bytemuck", "cubecl-common", @@ -1439,7 +1439,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b#1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" +source = "git+https://github.com/tracel-ai/cubecl?rev=7a86f9a86e376fedb09f096f2b548e501a130883#7a86f9a86e376fedb09f096f2b548e501a130883" dependencies = [ "bytemuck", "cubecl-common", @@ -1454,7 +1454,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b#1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" +source = "git+https://github.com/tracel-ai/cubecl?rev=7a86f9a86e376fedb09f096f2b548e501a130883#7a86f9a86e376fedb09f096f2b548e501a130883" dependencies = [ "bytemuck", "cubecl-core", @@ -1465,7 +1465,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b#1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" +source = "git+https://github.com/tracel-ai/cubecl?rev=7a86f9a86e376fedb09f096f2b548e501a130883#7a86f9a86e376fedb09f096f2b548e501a130883" dependencies = [ "cubecl-common", "darling", @@ -1480,7 +1480,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b#1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" +source = "git+https://github.com/tracel-ai/cubecl?rev=7a86f9a86e376fedb09f096f2b548e501a130883#7a86f9a86e376fedb09f096f2b548e501a130883" dependencies = [ "async-channel", "cfg_aliases 0.2.1", @@ -1500,7 +1500,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b#1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" +source = "git+https://github.com/tracel-ai/cubecl?rev=7a86f9a86e376fedb09f096f2b548e501a130883#7a86f9a86e376fedb09f096f2b548e501a130883" dependencies = [ "async-channel", "bytemuck", diff --git a/Cargo.toml b/Cargo.toml index a264bcd530..52079ea290 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -151,8 +151,8 @@ tch = "0.15.0" portable-atomic-util = { version = "0.2.2", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7a86f9a86e376fedb09f096f2b548e501a130883" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7a86f9a86e376fedb09f096f2b548e501a130883" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl" } # cubecl-common = { path = "../cubecl/crates/cubecl-common" } diff --git a/crates/burn-autodiff/src/tests/flip.rs b/crates/burn-autodiff/src/tests/flip.rs index bfa6b2870d..d63336e9a8 100644 --- a/crates/burn-autodiff/src/tests/flip.rs +++ b/crates/burn-autodiff/src/tests/flip.rs @@ -20,11 +20,11 @@ mod tests { let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 - .to_data() - .assert_eq(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), false); // 1x2x2 - grad_2.to_data().assert_eq( + .into_data() + .assert_approx_eq(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), 3); // 1x2x2 + grad_2.into_data().assert_approx_eq( &TensorData::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]), - false, + 3, ); // 1x2x3 } } diff --git a/crates/burn-autodiff/src/tests/permute.rs b/crates/burn-autodiff/src/tests/permute.rs index 4d785c4ab8..14a5cca065 100644 --- a/crates/burn-autodiff/src/tests/permute.rs +++ b/crates/burn-autodiff/src/tests/permute.rs @@ -20,11 +20,11 @@ mod tests { let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 - .to_data() - .assert_eq(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), false); // 1x2x2 - grad_2.to_data().assert_eq( + .into_data() + .assert_approx_eq(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), 3); // 1x2x2 + grad_2.into_data().assert_approx_eq( &TensorData::from([[[3.0, 10.0], [3.0, 10.0], [3.0, 10.0]]]), - false, + 3, ); // 1x3x2 } } diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index 0a271a4cd2..184848da31 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -9,14 +9,14 @@ fn select_assign_kernel( value: &Tensor, dim: &u32, ) { - let dim2 = *dim; + let dim = *dim; let mut offset_tensor = 0u32; let mut offset_value = 0u32; let mut num_elems = 1u32; // Calculate offsets and num_elems for i in 0..tensor.rank() { - if i != dim2 { + if i != dim { let shape_tensor = tensor.shape(i); num_elems *= shape_tensor; @@ -32,11 +32,11 @@ fn select_assign_kernel( return; } - let strides_tensor_dim = tensor.stride(dim2); - let strides_value_dim = value.stride(dim2); + let strides_tensor_dim = tensor.stride(dim); + let strides_value_dim = value.stride(dim); // Main operation - for i in 0..value.shape(dim2) { + for i in 0..value.shape(dim) { let index_tensor = u32::cast_from(indices[i]) * strides_tensor_dim + offset_tensor; let index_value = i * strides_value_dim + offset_value;