Perf/slice (#2252)
nathanielsimard authored Sep 9, 2024
1 parent 3d91b40 commit 94cd8a2
Showing 10 changed files with 145 additions and 46 deletions.
26 changes: 14 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -151,8 +151,8 @@ systemstat = "0.2.3"
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
- cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" }
- cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" }
+ cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ea9e6ba3e338aa0c528e885fb17f763b9366b799" }
+ cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ea9e6ba3e338aa0c528e885fb17f763b9366b799" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
3 changes: 1 addition & 2 deletions backend-comparison/src/burnbenchapp/auth.rs
@@ -156,9 +156,8 @@ fn refresh_tokens(tokens: &Tokens) -> Option<Tokens> {
// reqwest won't send the request in release build
.body(reqwest::blocking::Body::from(""))
.send();
-     response.ok()?.json::<Tokens>().ok().map(|new_tokens| {
+     response.ok()?.json::<Tokens>().ok().inspect(|_new_tokens| {
          println!("✅ Token refreshed!");
-         new_tokens
      })
}

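The change above swaps `Option::map` for `Option::inspect` (stable since Rust 1.76): `inspect` runs a closure on a shared reference to the contained value and passes the `Option` through unchanged, so the closure no longer needs to return `new_tokens` just to keep it. A minimal sketch of the equivalence (values and messages are illustrative):

```rust
fn main() {
    let tokens = Some(42);

    // Before: with `map`, the closure must hand the value back or it is lost.
    let kept_with_map = tokens.map(|t| {
        println!("✅ Token refreshed!");
        t
    });

    // After: `inspect` only observes the value; the Option flows through as-is.
    let kept_with_inspect = tokens.inspect(|_t| println!("✅ Token refreshed!"));

    assert_eq!(kept_with_map, kept_with_inspect);
}
```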
18 changes: 3 additions & 15 deletions crates/burn-jit/src/fusion/base.rs
@@ -1,4 +1,5 @@
use super::{ElementWise, ElementWiseState};
+ use crate::tensor::is_contiguous;
use crate::{
element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement,
IntElement, JitBackend, JitRuntime,
@@ -263,7 +264,7 @@ fn dynamic_inplace<R: JitRuntime>(

let handle = &handles_inputs[pos];

-     if handle.handle.can_mut() && is_contiguous(&handle.strides) {
+     if handle.handle.can_mut() && is_contiguous(&desc.shape, &handle.strides) {
Some((pos, desc, input))
} else {
None
@@ -318,24 +319,11 @@ fn dynamic_reading_strategy<R: JitRuntime>(
continue;
}

-         if is_contiguous(&handle.strides) {
+         if is_contiguous(&description_input.shape, &handle.strides) {
settings
.reading_strategy
.push((input_id, ReadingStrategy::Plain));
}
}
settings
}

- fn is_contiguous(strides: &[usize]) -> bool {
-     let mut current = 0;
-
-     for stride in strides.iter().rev() {
-         if current > *stride {
-             return false;
-         }
-         current = *stride;
-     }
-
-     true
- }
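The deleted helper judged contiguity from strides alone, which produces false positives for slices: a sliced tensor can keep monotonically decreasing strides while its elements are no longer packed back-to-back. The shape-aware replacement now lives in `crates/burn-jit/src/tensor/base.rs` (shown further down). A small sketch of the false positive, reusing the removed function under an illustrative name:

```rust
// The removed check, verbatim: contiguity inferred from strides alone.
fn stride_only_is_contiguous(strides: &[usize]) -> bool {
    let mut current = 0;

    for stride in strides.iter().rev() {
        if current > *stride {
            return false;
        }
        current = *stride;
    }

    true
}

fn main() {
    // A row-major 8x8 tensor has strides [8, 1]. Slicing columns 0..4 keeps
    // strides [8, 1] but shrinks the shape to [8, 4], so consecutive rows are
    // 8 elements apart even though each row holds only 4 -- not contiguous.
    let sliced_strides = [8usize, 1];
    assert!(stride_only_is_contiguous(&sliced_strides)); // false positive
}
```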
30 changes: 27 additions & 3 deletions crates/burn-jit/src/kernel/index/slice.rs
@@ -110,12 +110,36 @@ pub(crate) fn slice<R: JitRuntime, E: JitElement, const D1: usize, const D2: usi
indices: [Range<usize>; D2],
) -> JitTensor<R, E, D1> {
let mut dims = tensor.shape.dims;
let mut offset_start = 0;
let mut offset_end = 0;

for i in 0..D2 {
offset_start += tensor.strides[i] * indices[i].start;
offset_end += tensor.strides[i] * (dims[i] - indices[i].end);
dims[i] = indices[i].end - indices[i].start;
}
-     let shape_output = Shape::new(dims);
-     let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
-     slice_on_output(tensor, output, indices)

let offset_start = offset_start * E::cube_elem().size();
let offset_end = offset_end * E::cube_elem().size();

let memory_offset_alignment = tensor.client.properties().memory_offset_alignment as usize;

if offset_start % memory_offset_alignment == 0 && offset_end % memory_offset_alignment == 0 {
JitTensor::new(
tensor.client,
tensor
.handle
.offset_start(offset_start)
.offset_end(offset_end),
Shape::new(dims),
tensor.device,
tensor.strides,
)
} else {
let shape_output = Shape::new(dims);
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
slice_on_output(tensor, output, indices)
}
}

pub(crate) fn slice_on_output<R: JitRuntime, E: JitElement, const D1: usize, const D2: usize>(
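The intent of the new branch in `slice`: when the requested ranges only trim whole leading and trailing chunks whose byte offsets respect the client's `memory_offset_alignment`, the result reuses the existing buffer by moving the handle's start/end offsets, and no copy kernel is launched; otherwise it falls back to the old `slice_on_output` path. A rough sketch of the host-side offset arithmetic with made-up numbers (shape, ranges, and the alignment value are illustrative):

```rust
fn main() {
    // Hypothetical tensor: shape [8, 8], row-major strides [8, 1], f32 elements.
    let dims = [8usize, 8];
    let strides = [8usize, 1];
    let ranges = [3..7usize, 0..8]; // narrow(0, 3, 4): rows 3..7, all columns
    let elem_size = core::mem::size_of::<f32>();

    let mut offset_start = 0;
    let mut offset_end = 0;
    for i in 0..2 {
        offset_start += strides[i] * ranges[i].start;
        offset_end += strides[i] * (dims[i] - ranges[i].end);
    }

    // 3 full rows skipped at the front, 1 full row dropped at the back.
    assert_eq!(offset_start * elem_size, 96); // 3 * 8 * 4 bytes
    assert_eq!(offset_end * elem_size, 32);   // 1 * 8 * 4 bytes

    // Both byte offsets are multiples of the (illustrative) alignment, so the
    // zero-copy branch would be taken; otherwise the slice falls back to a copy.
    let memory_offset_alignment = 32;
    assert!(offset_start * elem_size % memory_offset_alignment == 0);
    assert!(offset_end * elem_size % memory_offset_alignment == 0);
}
```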
4 changes: 3 additions & 1 deletion crates/burn-jit/src/lib.rs
@@ -21,7 +21,7 @@ pub mod element;
use burn_tensor::backend::{DeviceId, DeviceOps};
use cubecl::{
compute::{CubeCount, CubeTask},
-     Runtime,
+     FeatureSet, Properties, Runtime,
};
pub use element::{FloatElement, IntElement, JitElement};

@@ -55,6 +55,8 @@ pub trait JitRuntime: Runtime<Device = Self::JitDevice, Server = Self::JitServer
type JitServer: cubecl::server::ComputeServer<
Kernel = Box<dyn CubeTask>,
DispatchOptions = CubeCount<Self::JitServer>,
+         Properties = Properties,
+         FeatureSet = FeatureSet,
>;
}

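A plausible reading of these two new bounds: by pinning the server's `Properties` and `FeatureSet` associated types to the concrete cubecl types, backend-generic code such as the slice path above can call `tensor.client.properties().memory_offset_alignment` without knowing which runtime it is compiled against.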
64 changes: 60 additions & 4 deletions crates/burn-jit/src/tensor/base.rs
@@ -4,7 +4,7 @@ use crate::JitRuntime;
use burn_tensor::Shape;
use cubecl::client::ComputeClient;
use cubecl::frontend::Numeric;
- use cubecl::linalg::tensor::{matrix_layout, MatrixLayout, TensorHandle};
+ use cubecl::linalg::tensor::TensorHandle;
use cubecl::prelude::{TensorHandleRef, *};
use cubecl::server::Handle;
use std::marker::PhantomData;
@@ -192,10 +192,66 @@ where

/// Check if the current tensor is contiguous.
pub fn is_contiguous(&self) -> bool {
-         self.matrix_layout() == MatrixLayout::Contiguous
+         is_contiguous(&self.shape.dims, &self.strides)
}
}

pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
if shape.is_empty() {
return true;
}

if shape.len() == 1 {
return strides[0] == 1;
}

let mut prev_stride = 1;
let mut current_num_elems_shape = 1;

for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
if i > 0 {
if current_num_elems_shape != *stride {
return false;
}

if prev_stride >= *stride {
return false;
}
}

current_num_elems_shape *= shape;
prev_stride = *stride;
}

true
}

#[cfg(test)]
mod tests {
use crate::tensor::base::is_contiguous;

#[test]
fn is_contiguous_basic() {
assert!(is_contiguous(&[32, 32], &[32, 1]));
}

#[test]
fn is_contiguous_permuted() {
assert!(!is_contiguous(&[32, 32], &[1, 32]));
}

#[test]
fn is_contiguous_slice() {
assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
}

#[test]
fn is_contiguous_4d_positive() {
assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
}

-     pub(crate) fn matrix_layout(&self) -> MatrixLayout {
-         matrix_layout(&self.strides)
#[test]
fn is_contiguous_4d_negative() {
assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
}
}
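The two lines marked `-` above are the old `matrix_layout` delegation, which is no longer needed. Spelled out, the property these tests pin down is: walking dimensions from innermost to outermost, each stride must equal the number of elements spanned by the dimensions to its right, so strides strictly grow as you move outward. A hedged restatement of that invariant (slightly stricter than the committed function, which does not constrain the innermost stride for rank ≥ 2, but it agrees on every test vector above):

```rust
// Illustrative reformulation, not the committed implementation.
fn contiguous_by_invariant(shape: &[usize], strides: &[usize]) -> bool {
    let mut expected = 1;
    // Walk (size, stride) pairs from the innermost dimension outwards.
    for (size, stride) in shape.iter().zip(strides).rev() {
        if *stride != expected {
            return false;
        }
        expected *= size;
    }
    true
}

fn main() {
    assert!(contiguous_by_invariant(&[32, 32], &[32, 1]));
    assert!(!contiguous_by_invariant(&[32, 32], &[1, 32])); // permuted
    assert!(!contiguous_by_invariant(&[32, 1, 64], &[32, 64, 1])); // slice
    assert!(contiguous_by_invariant(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    assert!(!contiguous_by_invariant(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
}
```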
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tensor/data.rs
@@ -66,6 +66,7 @@ impl TensorData {
// Ensure shape is valid
let shape = shape.into();
let shape_numel = Self::numel(&shape);
+         value.truncate(shape_numel);
let numel = value.len();
assert_eq!(
shape_numel, numel,
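In practice this one-liner means a value buffer longer than the shape requires is now silently truncated to `shape_numel` elements before the length assertion instead of failing it; presumably this accommodates backends that hand back slightly oversized buffers, such as slice views over a larger allocation.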
29 changes: 28 additions & 1 deletion crates/burn-tensor/src/tests/ops/narrow.rs
@@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{Shape, Tensor, TensorData};

#[test]
-     fn test_narrow_1() {
+     fn test_narrow_1() {
let tensor: Tensor<TestBackend, 2> = Tensor::from_data(
TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]),
&Default::default(),
@@ -15,13 +15,40 @@

assert_eq!(output.shape(), Shape::from([2, 3]));
output.into_data().assert_approx_eq(&expected, 3);
}

#[test]
fn test_narrow_2() {
let tensor: Tensor<TestBackend, 2> = Tensor::from_data(
TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]),
&Default::default(),
);

let output = tensor.clone().narrow(1, 1, 2);
let expected = TensorData::from([[2., 3.], [5., 6.], [8., 9.]]);
assert_eq!(output.shape(), Shape::from([3, 2]));
output.into_data().assert_approx_eq(&expected, 3);
}

#[test]
fn test_narrow_3() {
let device = &Default::default();
let shape = Shape::new([8, 8]);
let tensor: Tensor<TestBackend, 2> =
TestTensorInt::arange(0..shape.num_elements() as i64, &device)
.reshape(shape)
.float();

let output = tensor.clone().narrow(0, 3, 4);
let expected = TensorData::from([
[24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
[32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
[40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
[48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0],
]);
output.into_data().assert_approx_eq(&expected, 3);
}

#[test]
#[should_panic]
fn test_narrow_invalid_dim() {
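The new `test_narrow_3` above appears to have been added to cover the new slice path: narrowing along the outermost dimension of a row-major 8×8 float tensor trims whole rows at the front and back, exactly the case the offset-based branch in `slice.rs` targets (and, where the computed byte offsets meet the device's alignment requirement, the zero-copy branch is taken).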
12 changes: 6 additions & 6 deletions crates/onnx-ir/src/dim_inference.rs
@@ -28,23 +28,23 @@ pub fn dim_inference(node: &mut Node) {
NodeType::Cos => same_as_input(node),
NodeType::Div => same_as_input_broadcast(node),
NodeType::Dropout => same_as_input(node),
-         NodeType::Equal => elementwise_comparsion_outputs(node),
+         NodeType::Equal => elementwise_comparison_outputs(node),
NodeType::Erf => same_as_input(node),
NodeType::Exp => same_as_input(node),
NodeType::Expand => expand_update_outputs(node),
NodeType::Flatten => flatten_update_outputs(node),
NodeType::Gelu => same_as_input(node),
NodeType::Gather => gather_update_outputs(node),
NodeType::GatherElements => same_as_input(node),
-         NodeType::Greater => elementwise_comparsion_outputs(node),
-         NodeType::GreaterOrEqual => elementwise_comparsion_outputs(node),
+         NodeType::Greater => elementwise_comparison_outputs(node),
+         NodeType::GreaterOrEqual => elementwise_comparison_outputs(node),
NodeType::HardSigmoid => same_as_input(node),
NodeType::GlobalAveragePool => same_as_input(node),
NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
NodeType::LayerNormalization => same_as_input(node),
NodeType::LeakyRelu => same_as_input(node),
-         NodeType::Less => elementwise_comparsion_outputs(node),
-         NodeType::LessOrEqual => elementwise_comparsion_outputs(node),
+         NodeType::Less => elementwise_comparison_outputs(node),
+         NodeType::LessOrEqual => elementwise_comparison_outputs(node),
NodeType::Linear => linear_update_outputs(node),
NodeType::Log => same_as_input(node),
NodeType::LogSoftmax => same_as_input(node),
@@ -488,7 +488,7 @@ fn temporary_pass_through_stub(node: &mut Node) {
/// i.e., comparison operators like Equal, Greater, Less, etc.
///
/// Support for broadcasting is assumed
- fn elementwise_comparsion_outputs(node: &mut Node) {
+ fn elementwise_comparison_outputs(node: &mut Node) {
let input1_type = &node.inputs[0].ty;
let input2_type = &node.inputs[1].ty;
