Skip to content

Commit

Permalink
Select kernel from CPA to CubeCL (#2168)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: louisfd <[email protected]>
  • Loading branch information
mepatrick73 and louisfd authored Aug 27, 2024
1 parent a600a7b commit 795201d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 123 deletions.
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod slice_assign;

pub use flip::*;
pub use repeat_dim::*;
pub use select::*;
pub(crate) use select::*;
pub(crate) use select_assign::*;
pub use slice::*;
pub use slice_assign::*;
Expand Down
158 changes: 36 additions & 122 deletions crates/burn-jit/src/kernel/index/select.rs
Original file line number Diff line number Diff line change
@@ -1,118 +1,33 @@
use crate::{
element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
};
use cubecl::{
cpa,
frontend::TensorHandleRef,
ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
use std::marker::PhantomData;

#[derive(new)]
struct SelectEagerKernel<R: JitRuntime, E: JitElement> {
dim: usize,
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}

pub struct SelectComputeShader {
input: Variable,
indices: Variable,
output: Variable,
dim: usize,
}

impl SelectComputeShader {
pub fn expand(self, scope: &mut Scope) {
let input = self.input;
let indices = self.indices;
let output = self.output;
let id = Variable::AbsolutePos;
let offset_input = scope.zero(Elem::UInt);

cpa!(
scope,
range(0u32, Variable::Rank).for_each(|i, scope| {
let stride_input = scope.create_local(Elem::UInt);
let stride_output = scope.create_local(Elem::UInt);
let shape_output = scope.create_local(Elem::UInt);

cpa!(scope, stride_input = stride(input, i));
cpa!(scope, stride_output = stride(output, i));
cpa!(scope, shape_output = shape(output, i));

let offset_local = scope.create_local(Elem::UInt);
cpa!(scope, offset_local = id / stride_output);
cpa!(scope, offset_local = offset_local % shape_output);

let dim_index = scope.create_local(Elem::Bool);
cpa!(scope, dim_index = i == self.dim);

cpa!(scope, if(dim_index).then(|scope| {
cpa!(scope, offset_local = indices[offset_local]);
cpa!(scope, offset_local = offset_local * stride_input);
}).else(|scope| {
cpa!(scope, offset_local = offset_local * stride_input);
}));

cpa!(scope, offset_input += offset_local);
})
);

let value = scope.create_local(input.item());
cpa!(scope, value = input[offset_input]);
cpa!(scope, output[id] = value);
use cubecl::prelude::*;
use cubecl::{calculate_cube_count_elemwise, CubeDim};

#[cube(launch_unchecked)]
fn select_kernel<T: Numeric, I: Numeric>(
input: &Tensor<T>,
indices: &Tensor<I>,
output: &mut Tensor<T>,
dim: UInt,
) {
if ABSOLUTE_POS >= output.len() {
return;
}
}

impl<R: JitRuntime, E: JitElement> Kernel for SelectEagerKernel<R, E> {
fn define(&self) -> KernelDefinition {
let mut scope = Scope::root();
let item = E::cube_elem().into();
let item_indices: Item = Elem::Int(IntKind::I32).into();
let mut offset_input = UInt::new(0);

let input = Variable::GlobalInputArray { id: 0, item };
let indices = Variable::GlobalInputArray {
id: 1,
item: item_indices,
};
let output = Variable::GlobalOutputArray { id: 0, item };
for i in range(0u32, output.rank(), Comptime::new(false)) {
let mut offset_local = ABSOLUTE_POS / output.stride(i) % output.shape(i);

scope.write_global_custom(output);

SelectComputeShader {
input,
indices,
output,
dim: self.dim,
if i == dim {
offset_local = UInt::cast_from(indices[offset_local]);
}
.expand(&mut scope);

let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let indices = InputInfo::Array {
item: item_indices,
visibility: Visibility::Read,
};
let output = OutputInfo::Array { item };

let info = KernelExpansion {
inputs: vec![input, indices],
outputs: vec![output],
scope,
};

let settings = KernelSettings::default();
KernelIntegrator::new(info).integrate(settings)
offset_input += offset_local * input.stride(i);
}

fn id(&self) -> cubecl::KernelId {
cubecl::KernelId::new::<Self>().info(self.dim)
}
output[ABSOLUTE_POS] = input[offset_input];
}

pub(crate) fn select<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
Expand All @@ -122,26 +37,25 @@ pub(crate) fn select<R: JitRuntime, E: JitElement, I: JitElement, const D: usize
) -> JitTensor<R, E, D> {
let mut shape_output = tensor.shape.clone();
shape_output.dims[dim] = indices.shape.dims[0];
let total_elem = shape_output.num_elements();

let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
let kernel = SelectEagerKernel::<R, E>::new(dim);

let num_elems = indices.shape.dims[0];
let mut shapes = [1; D];
let mut strides = [num_elems; D];
shapes[D - 1] = num_elems;
strides[D - 1] = 1;
Execution::start(kernel, tensor.client.clone())
.inputs(&[
tensor.as_handle_ref(),
// This is a current hacks because the info buffer that contains the strides and shapes is
// hardcoded to only contains information about tensors of the same rank. However, since
// we don't rely on the shape and stride of the indices tensors, it doesn't matter
// which value we put, it just needs to be of the same rank.
unsafe { TensorHandleRef::from_raw_parts(&indices.handle, &strides, &shapes) },
])
.outputs(&[output.as_handle_ref()])
.execute(CubeCountSettings::Output { pos: 0 });

let dummy_array = [1; D];
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);

unsafe {
select_kernel::launch_unchecked::<E::Primitive, I::Primitive, R>(
&tensor.client,
cube_count,
cube_dim,
tensor.as_tensor_arg(1),
// Ignore shape and stride
TensorArg::from_raw_parts(&indices.handle, &dummy_array, &dummy_array, 1),
output.as_tensor_arg(1),
ScalarArg::new(dim as u32),
)
};
output
}

0 comments on commit 795201d

Please sign in to comment.