diff --git a/Cargo.lock b/Cargo.lock index 223f515b83..7e7cad83d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1396,7 +1396,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" +source = "git+https://github.com/tracel-ai/cubecl?rev=ea9e6ba3e338aa0c528e885fb17f763b9366b799#ea9e6ba3e338aa0c528e885fb17f763b9366b799" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1407,7 +1407,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" +source = "git+https://github.com/tracel-ai/cubecl?rev=ea9e6ba3e338aa0c528e885fb17f763b9366b799#ea9e6ba3e338aa0c528e885fb17f763b9366b799" dependencies = [ "derive-new", "getrandom", @@ -1422,7 +1422,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" +source = "git+https://github.com/tracel-ai/cubecl?rev=ea9e6ba3e338aa0c528e885fb17f763b9366b799#ea9e6ba3e338aa0c528e885fb17f763b9366b799" dependencies = [ "bytemuck", "cubecl-macros", @@ -1437,7 +1437,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" +source = "git+https://github.com/tracel-ai/cubecl?rev=ea9e6ba3e338aa0c528e885fb17f763b9366b799#ea9e6ba3e338aa0c528e885fb17f763b9366b799" dependencies = [ "bytemuck", "cubecl-common", @@ -1452,7 +1452,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" +source = "git+https://github.com/tracel-ai/cubecl?rev=ea9e6ba3e338aa0c528e885fb17f763b9366b799#ea9e6ba3e338aa0c528e885fb17f763b9366b799" dependencies = [ "bytemuck", "cubecl-core", @@ -1463,7 +1463,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" +source = "git+https://github.com/tracel-ai/cubecl?rev=ea9e6ba3e338aa0c528e885fb17f763b9366b799#ea9e6ba3e338aa0c528e885fb17f763b9366b799" dependencies = [ "derive-new", "proc-macro2", @@ -1474,7 +1474,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" +source = "git+https://github.com/tracel-ai/cubecl?rev=ea9e6ba3e338aa0c528e885fb17f763b9366b799#ea9e6ba3e338aa0c528e885fb17f763b9366b799" dependencies = [ "async-channel", "cfg_aliases 0.2.1", @@ -1494,7 +1494,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" +source = "git+https://github.com/tracel-ai/cubecl?rev=ea9e6ba3e338aa0c528e885fb17f763b9366b799#ea9e6ba3e338aa0c528e885fb17f763b9366b799" dependencies = [ "async-channel", "bytemuck", @@ -6665,8 +6665,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracel-xtask" -version = "1.0.0" -source = "git+https://github.com/tracel-ai/xtask?rev=921408bc16e74d3ef8ae59356d928fb6706fb8f4#921408bc16e74d3ef8ae59356d928fb6706fb8f4" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63f307c8a22d3c67bb2a0678290243e4917d235c507f0c8b35e8612919978a08" dependencies = [ "anyhow", "clap 4.5.16", @@ -6683,8 +6684,9 @@ dependencies = [ [[package]] name = "tracel-xtask-macros" -version = "1.0.0" -source = "git+https://github.com/tracel-ai/xtask?rev=921408bc16e74d3ef8ae59356d928fb6706fb8f4#921408bc16e74d3ef8ae59356d928fb6706fb8f4" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9b7ee23c050536c8c932ca7daaebbf45ff6c1d57f1bd65fc084833ac6a8d419" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index adef4cc1f9..532b9240f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -151,8 +151,8 @@ systemstat = "0.2.3" portable-atomic-util = { version = "0.2.2", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ea9e6ba3e338aa0c528e885fb17f763b9366b799" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ea9e6ba3e338aa0c528e885fb17f763b9366b799" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl" } # cubecl-common = { path = "../cubecl/crates/cubecl-common" } diff --git a/backend-comparison/src/burnbenchapp/auth.rs b/backend-comparison/src/burnbenchapp/auth.rs index 8f0577dfec..3e9470b2bd 100644 --- a/backend-comparison/src/burnbenchapp/auth.rs +++ b/backend-comparison/src/burnbenchapp/auth.rs @@ -156,9 +156,8 @@ fn refresh_tokens(tokens: &Tokens) -> Option { // reqwest won't send the request in release build .body(reqwest::blocking::Body::from("")) .send(); - response.ok()?.json::().ok().map(|new_tokens| { + response.ok()?.json::().ok().inspect(|_new_tokens| { println!("✅ Token refreshed!"); - new_tokens }) } diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index df79f1fd44..9e5ff39882 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,4 +1,5 @@ use super::{ElementWise, ElementWiseState}; +use crate::tensor::is_contiguous; use crate::{ element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime, @@ -263,7 +264,7 @@ fn dynamic_inplace( let handle = &handles_inputs[pos]; - if handle.handle.can_mut() && is_contiguous(&handle.strides) { + if handle.handle.can_mut() && is_contiguous(&desc.shape, &handle.strides) { Some((pos, desc, input)) } else { None @@ -318,7 +319,7 @@ fn dynamic_reading_strategy( continue; } - if is_contiguous(&handle.strides) { + if is_contiguous(&description_input.shape, &handle.strides) { settings .reading_strategy .push((input_id, ReadingStrategy::Plain)); @@ -326,16 +327,3 @@ fn dynamic_reading_strategy( } settings } - -fn is_contiguous(strides: &[usize]) -> bool { - let mut current = 0; - - for stride in strides.iter().rev() { - if current > *stride { - return false; - } - current = *stride; - } - - true -} diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index b010d99010..7be63d4af6 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -110,12 +110,36 @@ pub(crate) fn slice; D2], ) -> JitTensor { let mut dims = tensor.shape.dims; + let mut offset_start = 0; + let mut offset_end = 0; + for i in 0..D2 { + offset_start += tensor.strides[i] * indices[i].start; + offset_end += tensor.strides[i] * (dims[i] - indices[i].end); dims[i] = indices[i].end - indices[i].start; } - let shape_output = Shape::new(dims); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - slice_on_output(tensor, output, indices) + + let offset_start = offset_start * E::cube_elem().size(); + let offset_end = offset_end * E::cube_elem().size(); + + let memory_offset_alignment = tensor.client.properties().memory_offset_alignment as usize; + + if offset_start % memory_offset_alignment == 0 && offset_end % memory_offset_alignment == 0 { + JitTensor::new( + tensor.client, + tensor + .handle + .offset_start(offset_start) + .offset_end(offset_end), + Shape::new(dims), + tensor.device, + tensor.strides, + ) + } else { + let shape_output = Shape::new(dims); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); + slice_on_output(tensor, output, indices) + } } pub(crate) fn slice_on_output( diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index 0778a083f9..292505845d 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -21,7 +21,7 @@ pub mod element; use burn_tensor::backend::{DeviceId, DeviceOps}; use cubecl::{ compute::{CubeCount, CubeTask}, - Runtime, + FeatureSet, Properties, Runtime, }; pub use element::{FloatElement, IntElement, JitElement}; @@ -55,6 +55,8 @@ pub trait JitRuntime: Runtime, DispatchOptions = CubeCount, + Properties = Properties, + FeatureSet = FeatureSet, >; } diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index f98de3a6ab..a77d9db2b3 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -4,7 +4,7 @@ use crate::JitRuntime; use burn_tensor::Shape; use cubecl::client::ComputeClient; use cubecl::frontend::Numeric; -use cubecl::linalg::tensor::{matrix_layout, MatrixLayout, TensorHandle}; +use cubecl::linalg::tensor::TensorHandle; use cubecl::prelude::{TensorHandleRef, *}; use cubecl::server::Handle; use std::marker::PhantomData; @@ -192,10 +192,66 @@ where /// Check if the current tensor is contiguous. pub fn is_contiguous(&self) -> bool { - self.matrix_layout() == MatrixLayout::Contiguous + is_contiguous(&self.shape.dims, &self.strides) + } +} + +pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { + if shape.is_empty() { + return true; + } + + if shape.len() == 1 { + return strides[0] == 1; + } + + let mut prev_stride = 1; + let mut current_num_elems_shape = 1; + + for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() { + if i > 0 { + if current_num_elems_shape != *stride { + return false; + } + + if prev_stride >= *stride { + return false; + } + } + + current_num_elems_shape *= shape; + prev_stride = *stride; + } + + true +} + +#[cfg(test)] +mod tests { + use crate::tensor::base::is_contiguous; + + #[test] + fn is_contiguous_basic() { + assert!(is_contiguous(&[32, 32], &[32, 1])); + } + + #[test] + fn is_contiguous_permuted() { + assert!(!is_contiguous(&[32, 32], &[1, 32])); + } + + #[test] + fn is_contiguous_slice() { + assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1])); + } + + #[test] + fn is_contiguous_4d_positive() { + assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1])); } - pub(crate) fn matrix_layout(&self) -> MatrixLayout { - matrix_layout(&self.strides) + #[test] + fn is_contiguous_4d_negative() { + assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1])); } } diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index ff3998efe2..0c952c3ac4 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -66,6 +66,7 @@ impl TensorData { // Ensure shape is valid let shape = shape.into(); let shape_numel = Self::numel(&shape); + value.truncate(shape_numel); let numel = value.len(); assert_eq!( shape_numel, numel, diff --git a/crates/burn-tensor/src/tests/ops/narrow.rs b/crates/burn-tensor/src/tests/ops/narrow.rs index 4e06ff57fb..6d89e997a5 100644 --- a/crates/burn-tensor/src/tests/ops/narrow.rs +++ b/crates/burn-tensor/src/tests/ops/narrow.rs @@ -4,7 +4,7 @@ mod tests { use burn_tensor::{Shape, Tensor, TensorData}; #[test] - fn test_narrow() { + fn test_narrow_1() { let tensor: Tensor = Tensor::from_data( TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), &Default::default(), @@ -15,6 +15,14 @@ mod tests { assert_eq!(output.shape(), Shape::from([2, 3])); output.into_data().assert_approx_eq(&expected, 3); + } + + #[test] + fn test_narrow_2() { + let tensor: Tensor = Tensor::from_data( + TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), + &Default::default(), + ); let output = tensor.clone().narrow(1, 1, 2); let expected = TensorData::from([[2., 3.], [5., 6.], [8., 9.]]); @@ -22,6 +30,25 @@ mod tests { output.into_data().assert_approx_eq(&expected, 3); } + #[test] + fn test_narrow_3() { + let device = &Default::default(); + let shape = Shape::new([8, 8]); + let tensor: Tensor = + TestTensorInt::arange(0..shape.num_elements() as i64, &device) + .reshape(shape) + .float(); + + let output = tensor.clone().narrow(0, 3, 4); + let expected = TensorData::from([ + [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0], + [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0], + [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0], + [48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0], + ]); + output.into_data().assert_approx_eq(&expected, 3); + } + #[test] #[should_panic] fn test_narrow_invalid_dim() { diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 6275b36ed8..ff580b37aa 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -28,7 +28,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Cos => same_as_input(node), NodeType::Div => same_as_input_broadcast(node), NodeType::Dropout => same_as_input(node), - NodeType::Equal => elementwise_comparsion_outputs(node), + NodeType::Equal => elementwise_comparison_outputs(node), NodeType::Erf => same_as_input(node), NodeType::Exp => same_as_input(node), NodeType::Expand => expand_update_outputs(node), @@ -36,15 +36,15 @@ pub fn dim_inference(node: &mut Node) { NodeType::Gelu => same_as_input(node), NodeType::Gather => gather_update_outputs(node), NodeType::GatherElements => same_as_input(node), - NodeType::Greater => elementwise_comparsion_outputs(node), - NodeType::GreaterOrEqual => elementwise_comparsion_outputs(node), + NodeType::Greater => elementwise_comparison_outputs(node), + NodeType::GreaterOrEqual => elementwise_comparison_outputs(node), NodeType::HardSigmoid => same_as_input(node), NodeType::GlobalAveragePool => same_as_input(node), NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node), NodeType::LayerNormalization => same_as_input(node), NodeType::LeakyRelu => same_as_input(node), - NodeType::Less => elementwise_comparsion_outputs(node), - NodeType::LessOrEqual => elementwise_comparsion_outputs(node), + NodeType::Less => elementwise_comparison_outputs(node), + NodeType::LessOrEqual => elementwise_comparison_outputs(node), NodeType::Linear => linear_update_outputs(node), NodeType::Log => same_as_input(node), NodeType::LogSoftmax => same_as_input(node), @@ -488,7 +488,7 @@ fn temporary_pass_through_stub(node: &mut Node) { /// i.e., comparison operators like Equal, Greater, Less, etc. /// /// Support for broadcasting is assumed -fn elementwise_comparsion_outputs(node: &mut Node) { +fn elementwise_comparison_outputs(node: &mut Node) { let input1_type = &node.inputs[0].ty; let input2_type = &node.inputs[1].ty;