Matmul + CubeCL Update (#2551)

nathanielsimard authored Nov 29, 2024
1 parent a5624c1 commit 3dc4b43
Showing 18 changed files with 137 additions and 291 deletions.
26 changes: 14 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
2 changes: 2 additions & 0 deletions backend-comparison/Cargo.toml
@@ -54,6 +54,8 @@ strum_macros = { workspace = true }
 sysinfo = { workspace = true, features = ["serde"] }
 wgpu = { workspace = true }
 wsl = { workspace = true }
+tracing-subscriber = { workspace = true }
+log = { workspace = true }

 [dev-dependencies]
 rstest = { workspace = true }
24 changes: 14 additions & 10 deletions backend-comparison/benches/matmul.rs
@@ -44,16 +44,20 @@ fn bench<B: Backend>(
     url: Option<&str>,
     token: Option<&str>,
 ) {
-    let benchmarks = [(2, 4096, 4096, 4096), (8, 2048, 2048, 2048)]
-        .into_iter()
-        .map(|(b, m, n, k)| {
-            let shape_lhs = [b, m, k].into();
-            let shape_rhs = [b, k, n].into();
-
-            MatmulBenchmark::<B, 3>::new(shape_lhs, shape_rhs, device.clone())
-        })
-        .map(run_benchmark)
-        .collect();
+    let benchmarks = [
+        (3, 4096, 4096, 4096),
+        (8, 2048, 2048, 2048),
+        (2, 4096, 4096, 512),
+    ]
+    .into_iter()
+    .map(|(b, m, n, k)| {
+        let shape_lhs = [b, m, k].into();
+        let shape_rhs = [b, k, n].into();
+
+        MatmulBenchmark::<B, 3>::new(shape_lhs, shape_rhs, device.clone())
+    })
+    .map(run_benchmark)
+    .collect();

     save::<B>(benchmarks, device, feature_name, url, token).unwrap();
 }
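
For reference, each (b, m, n, k) tuple above describes a batched matmul of shapes [b, m, k] x [b, k, n] -> [b, m, n]. The new (2, 4096, 4096, 512) entry therefore runs 2 * (2 * 4096 * 4096 * 512) ≈ 34.4 GFLOP per iteration, probing a shallow-k regime that the two square cases do not cover.
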
27 changes: 27 additions & 0 deletions backend-comparison/src/lib.rs
@@ -1,3 +1,7 @@
+use std::error::Error;
+
+use tracing_subscriber::filter::LevelFilter;
+
 pub mod burnbenchapp;
 pub mod persistence;

@@ -26,10 +30,33 @@ pub fn get_sharing_url(args: &[String]) -> Option<&str> {
     get_argument(args, "--sharing-url")
 }

+pub fn init_log() -> Result<(), Box<dyn Error + Send + Sync>> {
+    let result = tracing_subscriber::fmt()
+        .with_max_level(LevelFilter::DEBUG)
+        .without_time()
+        .try_init();
+
+    if result.is_ok() {
+        update_panic_hook();
+    }
+    result
+}
+
+fn update_panic_hook() {
+    let hook = std::panic::take_hook();
+
+    std::panic::set_hook(Box::new(move |info| {
+        log::error!("PANIC => {}", info.to_string());
+        hook(info);
+    }));
+}
+
 #[macro_export]
 macro_rules! bench_on_backend {
     () => {
         use std::env;
+        backend_comparison::init_log().unwrap();
+
         let args: Vec<String> = env::args().collect();
         let url = backend_comparison::get_sharing_url(&args);
         let token = backend_comparison::get_sharing_token(&args);
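
Taken together, init_log installs a DEBUG-level tracing subscriber and, when that succeeds, chains a panic hook so any panic is written through the logger (as "PANIC => ...") before the previously installed hook runs. A minimal sketch of a benchmark binary that picks this up — illustrative only, assuming the usual shape of the benches in this crate:

    // Each bench defines a `bench` function and ends with this macro call;
    // the expansion now invokes `backend_comparison::init_log().unwrap()`
    // before reading the sharing URL/token from the CLI arguments.
    fn main() {
        backend_comparison::bench_on_backend!();
    }
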
18 changes: 9 additions & 9 deletions crates/burn-jit/Cargo.toml
@@ -16,13 +16,13 @@ autotune = []
 default = ["autotune", "std", "fusion", "cubecl/default"]
 doc = ["default"]
 export_tests = [
-  "burn-tensor-testgen",
-  "serial_test",
-  "burn-autodiff/export_tests",
-  "burn-tensor/export_tests",
-  "burn-ndarray",
-  "fusion",
-  "paste",
+    "burn-tensor-testgen",
+    "serial_test",
+    "burn-autodiff/export_tests",
+    "burn-tensor/export_tests",
+    "burn-ndarray",
+    "fusion",
+    "paste",
 ]
 fusion = ["burn-fusion"]
 std = ["cubecl/std"]

@@ -32,8 +32,8 @@ template = []
 burn-common = { path = "../burn-common", version = "0.16.0" }
 burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true }
 burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [
-  "cubecl",
-  "repr",
+    "cubecl",
+    "repr",
 ] }
 cubecl = { workspace = true, features = ["linalg"] }
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/col2im.rs
@@ -145,7 +145,7 @@ fn execute<R: JitRuntime, E: FloatElement>(
     let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]);
     let input = reshape(input, input_shape);

-    let columns = matmul::<R, E>(weight, input, MatmulStrategy::default());
+    let columns = matmul::<R, E>(weight, input, None, MatmulStrategy::default());
     let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));

     col2im::<R, E>(
14 changes: 8 additions & 6 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs
@@ -3,7 +3,7 @@ use cubecl::{
     components::{
         global::{
             self,
-            homogeneous::{self, CyclicLoading, RhsLoader},
+            full_load::{self, CyclicLoading, RhsLoader},
             unloader::Unloader,
             AccumulatorLoader, Config as _, Loader,
         },

@@ -93,9 +93,11 @@ where
         for _ in 0..num_loops {
             sync_units();

-            let lhs_stage_reader = &Self::LhsLoader::fill_stage(&mut lhs_loader, config);
-            let rhs_stage_reader =
-                &Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config());
+            Self::LhsLoader::fill_stage(&mut lhs_loader, config);
+            Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config());
+
+            let lhs_stage_reader = &Self::LhsLoader::as_stage_reader(&lhs_loader);
+            let rhs_stage_reader = &Self::RhsLoader::as_stage_reader(&rhs_loader);

             sync_units();

@@ -172,7 +174,7 @@ where
     Acc: Numeric,
     SMM: stage::Matmul<ES, EG, Acc>,
 {
-    type Config = config::Config<homogeneous::Config<SMM::Config>>;
+    type Config = config::Config<full_load::Config<SMM::Config>>;

     fn check_config(config: Self::Config) {
         SMM::check_config(config.to_smm_config());

@@ -198,7 +200,7 @@
     );

     config::Config::new(
-        homogeneous::Config::new(
+        full_load::Config::new(
             smm_config,
             problem.m as u32 % SMM::M != 0,
             problem.n as u32 % SMM::N != 0,
7 changes: 5 additions & 2 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs
@@ -25,19 +25,22 @@ pub struct SimpleIm2colLoader<EG: Numeric, ES: Numeric, G: Config> {
 impl<EG: Numeric, ES: Numeric, G: Config> Loader<EG, ES, G> for SimpleIm2colLoader<EG, ES, G> {
     type StageReader = LhsReader<ES>;

-    fn fill_stage(this: &mut Self, #[comptime] config: G) -> Self::StageReader {
+    fn fill_stage(this: &mut Self, #[comptime] config: G) {
         SimpleIm2col::load_to_slice::<EG, ES, G>(
             &this.tensor_view,
             &mut this.stage.as_slice_mut(),
             Ident::Lhs,
             config,
         );
-        LhsReader::new(this.stage)
     }

     fn advance_view(this: &mut Self, k_offset: u32) {
         this.tensor_view.update_view(k_offset);
     }
+
+    fn as_stage_reader(this: &Self) -> Self::StageReader {
+        LhsReader::new(this.stage)
+    }
 }

 #[cube]
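
Together with the homogeneous -> full_load rename above, these hunks track a CubeCL loader-trait change: filling a stage and obtaining its reader are now two separate steps rather than one call. A rough sketch of the revised trait shape, as inferred from these call sites (simplified; the real cubecl trait is generic over element types and takes #[comptime] configs):

    // Inferred, simplified shape — not the exact cubecl definition.
    trait Loader {
        type StageReader;
        // Writes the next tile into the stage (shared memory); returns nothing.
        fn fill_stage(this: &mut Self);
        // Obtaining a reader is now a separate, read-only step.
        fn as_stage_reader(this: &Self) -> Self::StageReader;
        // Moves the tensor view forward along k for the next iteration.
        fn advance_view(this: &mut Self, k_offset: u32);
    }
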
15 changes: 4 additions & 11 deletions crates/burn-jit/src/kernel/conv/conv2d/im2col.rs
@@ -2,13 +2,13 @@ use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions},
     Shape,
 };
-use cubecl::{calculate_cube_count_elemwise, linalg::matmul, prelude::*};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};

 use crate::{
     kernel::{
         conv::index,
         into_contiguous, launch_binop,
-        matmul::{cube_strategy, matmul, MatmulStrategy},
+        matmul::{matmul, MatmulStrategy},
         AddOp,
     },
     ops::{numeric::empty_device, reshape, swap_dims},

@@ -271,7 +271,7 @@ fn execute_1x1_kernel<R: JitRuntime, E: FloatElement>(
     let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp]));
     let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]);
     let input = reshape(input, in_shape);
-    let out = matmul::<R, E>(weight, input, MatmulStrategy::default());
+    let out = matmul::<R, E>(weight, input, None, MatmulStrategy::default());
     let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width]));

     if let Some(bias) = bias {

@@ -290,7 +290,6 @@ fn execute<R: JitRuntime, E: FloatElement>(
     out_h: usize,
     out_w: usize,
 ) {
-    let client = input.client.clone();
     let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims();
     let groups = options.groups;

@@ -302,11 +301,5 @@
     let columns = reshape(columns, Shape::new([groups, col_shape_0, col_shape_1]));
     let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0]));

-    matmul::launch_ref::<R, E>(
-        &cube_strategy::<R>(&client),
-        &client,
-        &weight.as_handle_ref(),
-        &columns.as_handle_ref(),
-        &out.as_handle_ref(),
-    );
+    matmul::<R, E>(weight, columns, Some(out), Default::default());
 }
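
The call sites in this file (and in the conv kernels below) track a signature change in burn-jit's matmul entry point: it now accepts an optional pre-allocated output, so execute can write straight into out instead of going through matmul::launch_ref. Roughly, as inferred from these call sites — a sketch with the tensor generics elided, not the exact definition:

    // Inferred new signature (simplified):
    //
    //     pub fn matmul<R: JitRuntime, E: FloatElement>(
    //         lhs: JitTensor<...>,
    //         rhs: JitTensor<...>,
    //         out: Option<JitTensor<...>>,
    //         strategy: MatmulStrategy,
    //     ) -> JitTensor<...>
    //
    // None => allocate a fresh output tensor; Some(out) => write the product
    // into the caller-provided buffer.
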
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/deform_conv2d.rs
@@ -298,7 +298,7 @@ pub(crate) fn deform_conv2d<R: JitRuntime, E: FloatElement>(

     let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0]));
     let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
-    let out = matmul::<R, E>(weight, columns, MatmulStrategy::default());
+    let out = matmul::<R, E>(weight, columns, None, MatmulStrategy::default());

     let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w]));
     let out = swap_dims(out, 0, 1);
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs
@@ -108,7 +108,7 @@ fn compute_weight_grad<R: JitRuntime, E: FloatElement>(
     let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
     let columns = swap_dims(columns, 1, 2);

-    let grad_weight = matmul::<R, E>(out_grad, columns, MatmulStrategy::default());
+    let grad_weight = matmul::<R, E>(out_grad, columns, None, MatmulStrategy::default());

     reshape(
         grad_weight,

@@ -150,7 +150,7 @@ fn backward_gradient_inputs<R: JitRuntime, E: FloatElement>(
     for group in 0..groups {
         let weight = swap_dims(index::<R, E>(weight.clone(), group), 0, 1);
         let out_grad = index::<R, E>(out_grad.clone(), group);
-        let values = matmul::<R, E>(weight, out_grad, MatmulStrategy::default());
+        let values = matmul::<R, E>(weight, out_grad, None, MatmulStrategy::default());
         let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1]));
         columns = slice_assign::<R, E>(
             columns,
(The remaining changed files in this commit were not loaded on this page.)