Refactor execute_dynamic with Execution struct (#1550)
louisfd authored Mar 28, 2024
1 parent efc3b2d commit edcd92f
Showing 30 changed files with 304 additions and 441 deletions.
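The recurring call-site change, distilled here from the `cast` hunk further down (an illustrative before/after excerpt, not a standalone-compilable snippet):

    // Before: one free function with a long positional argument list; even
    // scalar-free kernels name a scalar element type in the turbofish.
    execute_dynamic::<R, CastEagerKernel<R, EI, EO>, u32>(
        &[EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims)],
        &[EagerHandle::new(&output.handle, &output.strides, &output.shape.dims)],
        None, // scalars
        kernel,
        WorkgroupLaunch::Output { pos: 0 },
        tensor.client,
    );

    // After: a builder-style `Execution` where each argument is named by the
    // method that supplies it, and unused scalar slots never appear.
    Execution::start(kernel, tensor.client)
        .inputs(&[EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims)])
        .outputs(&[EagerHandle::new(&output.handle, &output.strides, &output.shape.dims)])
        .execute(WorkgroupLaunch::Output { pos: 0 });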
2 changes: 1 addition & 1 deletion backend-comparison/src/burnbenchapp/base.rs
@@ -220,7 +220,7 @@ pub(crate) fn run_backend_comparison_benchmarks(
let status = run_cargo("bench", &args).unwrap();
if !status.success() {
println!(
"Benchmark {} didn't ran successfully on the backend {}",
"Benchmark {} didn't run successfully on the backend {}",
bench_str, backend_str
);
continue;
1 change: 0 additions & 1 deletion crates/burn-jit/README.md
@@ -1,4 +1,3 @@
# Burn JIT Backend

Generic backend that can be compiled just-in-time (JIT) to any shader language target
In progress: At the moment, only WGSL compilation is supported, and some kernels still rely on pure WGSL
59 changes: 8 additions & 51 deletions crates/burn-jit/src/codegen/kernel.rs
@@ -34,27 +34,8 @@ pub fn execute_static<R, K, E>(
R: Runtime,
E: JitElement,
{
execute_static_::<R, K, E, E, E>(inputs, outputs, scalar_elems, None, None, launch, client)
}

fn execute_static_<R, K, E1, E2, E3>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
scalars_1: Option<&[E1]>,
scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>,
launch: WorkgroupLaunch,
client: ComputeClient<R::Server, R::Channel>,
) where
K: StaticKernelSource + 'static,
R: Runtime,
E1: JitElement,
E2: JitElement,
E3: JitElement,
{
let settings = execute_settings(
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
);
let settings =
execute_settings::<R, E, E, E>(inputs, outputs, scalar_elems, None, None, launch, &client);
let mut handles = settings.handles_tensors;
let workgroup = settings.workgroup;

@@ -64,6 +45,7 @@ fn execute_static_<R, K, E1, E2, E3>(
}

let kernel = Box::new(StaticKernel::<K>::new(workgroup));

client.execute(kernel, &handles);
}

@@ -128,7 +110,7 @@ where
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: WorkgroupLaunch) {
execute_dynamic_::<R, K, f32, f32, f32>(
execute_dynamic::<R, K, f32, f32, f32>(
self.inputs,
self.outputs,
None,
@@ -163,7 +145,7 @@ where
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: WorkgroupLaunch) {
execute_dynamic_::<R, K, E, f32, f32>(
execute_dynamic::<R, K, E, f32, f32>(
self.inputs,
self.outputs,
Some(self.scalars.0),
@@ -203,7 +185,7 @@ where
K: DynamicKernelSource + 'static,
R: Runtime,
{
execute_dynamic_::<R, K, E1, E2, f32>(
execute_dynamic::<R, K, E1, E2, f32>(
self.inputs,
self.outputs,
Some(self.scalars.0),
@@ -227,7 +209,7 @@ where
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: WorkgroupLaunch) {
execute_dynamic_::<R, K, E1, E2, E3>(
execute_dynamic::<R, K, E1, E2, E3>(
self.inputs,
self.outputs,
Some(self.scalars.0),
@@ -240,33 +222,8 @@ where
}
}

/// Execute a dynamic kernel.
pub fn execute_dynamic<R, K, E>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
scalar_elems: Option<&[E]>,
kernel: K,
launch: WorkgroupLaunch,
client: ComputeClient<R::Server, R::Channel>,
) where
K: DynamicKernelSource + 'static,
R: Runtime,
E: JitElement,
{
execute_dynamic_::<R, K, E, E, E>(
inputs,
outputs,
scalar_elems,
None,
None,
kernel,
launch,
client,
)
}

#[allow(clippy::too_many_arguments)]
fn execute_dynamic_<R, K, E1, E2, E3>(
fn execute_dynamic<R, K, E1, E2, E3>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
scalars_1: Option<&[E1]>,
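The four `execute` impls above (zero to three scalar element types) suggest a typestate-style builder. A minimal, self-contained sketch of that shape, for readers who want the mechanics in isolation — `Handle`, `Launch`, and the zero-argument `start` are simplified stand-ins, not the real burn-jit types, which also carry the kernel, client, and runtime:

    struct Handle;

    enum Launch {
        Output { pos: usize },
    }

    struct Execution<'a, Scalars = ()> {
        inputs: &'a [Handle],
        outputs: &'a [Handle],
        scalars: Scalars,
    }

    impl<'a> Execution<'a> {
        fn start() -> Self {
            Execution { inputs: &[], outputs: &[], scalars: () }
        }

        // Moves the builder into a typed state; kernels launched without
        // scalars never name a scalar element type at all.
        fn with_scalars<E>(self, scalars: &'a [E]) -> Execution<'a, &'a [E]> {
            Execution { inputs: self.inputs, outputs: self.outputs, scalars }
        }
    }

    impl<'a, Scalars> Execution<'a, Scalars> {
        fn inputs(mut self, inputs: &'a [Handle]) -> Self {
            self.inputs = inputs;
            self
        }

        fn outputs(mut self, outputs: &'a [Handle]) -> Self {
            self.outputs = outputs;
            self
        }

        fn execute(self, launch: Launch) {
            // The real method builds the execute settings and dispatches the
            // compiled kernel; this stand-in only unpacks the launch mode.
            let Launch::Output { pos } = launch;
            let _ = (self.inputs.len(), self.outputs.len(), pos);
        }
    }

    fn main() {
        let (input, output) = ([Handle], [Handle]);
        Execution::start()
            .inputs(&input)
            .outputs(&output)
            .with_scalars(&[2u32, 2, 1])
            .execute(Launch::Output { pos: 0 });
    }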
20 changes: 8 additions & 12 deletions crates/burn-jit/src/kernel/cast/base.rs
@@ -2,7 +2,7 @@ use std::{any::TypeId, marker::PhantomData};

use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Scope, Variable, Visibility},
@@ -21,7 +21,7 @@ pub fn cast<R: Runtime, EI: JitElement, EO: JitElement, const D: usize>(
return JitTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
}

let kernel = CastEagerKernel::new();
let kernel = CastEagerKernel::<R, EI, EO>::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
let output = JitTensor::new(
@@ -31,22 +31,18 @@ pub fn cast<R: Runtime, EI: JitElement, EO: JitElement, const D: usize>(
buffer,
);

execute_dynamic::<R, CastEagerKernel<R, EI, EO>, u32>(
&[EagerHandle::new(
Execution::start(kernel, tensor.client)
.inputs(&[EagerHandle::<R>::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });

output
}
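A side effect visible at each call site: the old turbofish on `execute_dynamic` pinned the kernel's type parameters, so `CastEagerKernel::new()` was inferable; `Execution::start` accepts any kernel value, hence the explicit `CastEagerKernel::<R, EI, EO>::new()` above. A toy reproduction of the inference difference (all names here are hypothetical stand-ins):

    use std::marker::PhantomData;

    struct Kernel<E> {
        _elem: PhantomData<E>,
    }

    impl<E> Kernel<E> {
        fn new() -> Self {
            Kernel { _elem: PhantomData }
        }
    }

    // Old shape: the call names `E`, which flows backwards into `new()`.
    fn execute_dynamic<E>(_kernel: Kernel<E>) {}

    // New shape: generic over any `K`, so nothing constrains `E` for us.
    fn start<K>(_kernel: K) {}

    fn main() {
        execute_dynamic::<f32>(Kernel::new());

        // `start(Kernel::new());` fails with "type annotations needed";
        // the kernel spells out its own parameters instead.
        start(Kernel::<f32>::new());
    }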
20 changes: 8 additions & 12 deletions crates/burn-jit/src/kernel/cast/bool_cast.rs
@@ -2,7 +2,7 @@ use std::marker::PhantomData;

use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
@@ -20,7 +20,7 @@ use crate::{
pub fn bool_cast<R: Runtime, EO: JitElement, const D: usize>(
tensor: JitTensor<R, u32, D>,
) -> JitTensor<R, EO, D> {
let kernel = BoolCastEagerKernel::new();
let kernel = BoolCastEagerKernel::<R, EO>::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
let output = JitTensor::new(
@@ -30,22 +30,18 @@ pub fn bool_cast<R: Runtime, EO: JitElement, const D: usize>(
buffer,
);

execute_dynamic::<R, BoolCastEagerKernel<R, EO>, u32>(
&[EagerHandle::new(
Execution::start(kernel, tensor.client)
.inputs(&[EagerHandle::<R>::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });

output
}
24 changes: 10 additions & 14 deletions crates/burn-jit/src/kernel/contiguous.rs
@@ -2,7 +2,7 @@ use std::marker::PhantomData;

use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Elem, IndexOffsetGlobalWithLayout, Scope, Variable, Visibility},
@@ -18,9 +18,9 @@ pub(crate) struct IntoContiguousShader {
}

#[derive(new)]
pub(crate) struct IntoContiguousEagerKernel<R: Runtime, EO: JitElement> {
pub(crate) struct IntoContiguousEagerKernel<R: Runtime, E: JitElement> {
_runtime: PhantomData<R>,
_elem_out: PhantomData<EO>,
_elem_out: PhantomData<E>,
}

/// Make a jit tensor contiguous.
@@ -31,7 +31,7 @@ pub fn into_contiguous<R: Runtime, E: JitElement, const D: usize>(
return tensor;
}

let kernel = IntoContiguousEagerKernel::new();
let kernel = IntoContiguousEagerKernel::<R, E>::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
let output = JitTensor::new(
@@ -41,22 +41,18 @@ pub fn into_contiguous<R: Runtime, E: JitElement, const D: usize>(
buffer,
);

execute_dynamic::<R, IntoContiguousEagerKernel<R, E>, u32>(
&[EagerHandle::new(
Execution::start(kernel, tensor.client)
.inputs(&[EagerHandle::<R>::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });

output
}
43 changes: 20 additions & 23 deletions crates/burn-jit/src/kernel/conv/conv2d.rs
@@ -1,13 +1,13 @@
use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions},
ElementConversion, Shape,
Shape,
};
use std::marker::PhantomData;

use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
@@ -17,7 +17,7 @@ use crate::{
reshape,
},
tensor::JitTensor,
Runtime, RuntimeInt,
Runtime,
};

#[derive(new)]
@@ -335,32 +335,29 @@ pub(crate) fn conv2d<R: Runtime, E: JitElement>(
}
};

let kernel = Conv2dEagerKernel::new();
let kernel = Conv2dEagerKernel::<R, E>::new();

execute_dynamic::<R, Conv2dEagerKernel<R, E>, RuntimeInt<R>>(
&[
EagerHandle::new(&input.handle, &input.strides, &input.shape.dims),
Execution::start(kernel, input.client)
.inputs(&[
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
EagerHandle::new(&weight.handle, &weight.strides, &weight.shape.dims),
EagerHandle::new(&bias.handle, &bias.strides, &bias.shape.dims),
],
&[EagerHandle::new(
])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(options.stride[0] as u32).elem(),
(options.stride[1] as u32).elem(),
(options.dilation[0] as u32).elem(),
(options.dilation[1] as u32).elem(),
(options.padding[0] as u32).elem(),
(options.padding[1] as u32).elem(),
(options.groups as u32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);
)])
.with_scalars(&[
options.stride[0] as u32,
options.stride[1] as u32,
options.dilation[0] as u32,
options.dilation[1] as u32,
options.padding[0] as u32,
options.padding[1] as u32,
options.groups as u32,
])
.execute(WorkgroupLaunch::Output { pos: 0 });

output
}
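The other recurring change, seen in the scalars above and again in `conv_transpose2d` below: `with_scalars` is generic over its element type, so plain `u32` values pass through and the per-value `.elem()` conversions (along with the `ElementConversion` and `RuntimeInt` imports) drop out. A self-contained toy contrasting the two shapes — `RuntimeInt` here is a stand-in for the backend-chosen integer element, not the real alias:

    #[derive(Clone, Copy)]
    struct RuntimeInt(u32);

    trait Elem {
        fn elem(self) -> RuntimeInt;
    }

    impl Elem for u32 {
        fn elem(self) -> RuntimeInt {
            RuntimeInt(self)
        }
    }

    // Before-shape: a fixed scalar element type forces a conversion per value.
    fn execute_fixed(scalars: Option<&[RuntimeInt]>) {
        let _ = scalars.map(|s| s.len());
    }

    // After-shape: generic over the element, so `&[u32]` is accepted as-is.
    fn with_scalars<E: Copy>(scalars: &[E]) {
        let _ = scalars.len();
    }

    fn main() {
        let stride = [2u32, 2];
        execute_fixed(Some(&[stride[0].elem(), stride[1].elem()]));
        with_scalars(&[stride[0], stride[1]]);
    }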
16 changes: 8 additions & 8 deletions crates/burn-jit/src/kernel/conv/conv_transpose2d.rs
@@ -15,7 +15,7 @@ use crate::{
tensor::JitTensor,
Compiler, Runtime,
};
use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};
use burn_tensor::{ops::ConvTransposeOptions, Element, Shape};

#[derive(new)]
struct Conv2dTransposeEagerKernel<R, E> {
@@ -399,13 +399,13 @@ pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
&output.shape.dims,
)])
.with_scalars(&[
(options.stride[0] as u32).elem::<u32>(),
(options.stride[1] as u32).elem(),
(options.dilation[0] as u32).elem(),
(options.dilation[1] as u32).elem(),
(options.padding[0] as u32).elem(),
(options.padding[1] as u32).elem(),
(options.groups as u32).elem(),
options.stride[0] as u32,
options.stride[1] as u32,
options.dilation[0] as u32,
options.dilation[1] as u32,
options.padding[0] as u32,
options.padding[1] as u32,
options.groups as u32,
])
.execute(WorkgroupLaunch::Output { pos: 0 });
