From 3dc4b43e92ec5683ca39c194acf274b131d44dfa Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Fri, 29 Nov 2024 15:08:46 -0500 Subject: [PATCH] Matmul + CubeCL Update (#2551) --- Cargo.lock | 26 ++-- Cargo.toml | 4 +- backend-comparison/Cargo.toml | 2 + backend-comparison/benches/matmul.rs | 24 +-- backend-comparison/src/lib.rs | 27 ++++ crates/burn-jit/Cargo.toml | 18 +-- .../burn-jit/src/kernel/conv/conv2d/col2im.rs | 2 +- .../conv/conv2d/gemm/homogeneous/base.rs | 14 +- .../kernel/conv/conv2d/gemm/loader/im2col.rs | 7 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 15 +- .../burn-jit/src/kernel/conv/deform_conv2d.rs | 2 +- .../kernel/conv/deform_conv_transpose2d.rs | 4 +- crates/burn-jit/src/kernel/matmul/base.rs | 69 +-------- crates/burn-jit/src/kernel/matmul/mod.rs | 2 - crates/burn-jit/src/kernel/matmul/simple.rs | 143 ------------------ .../burn-jit/src/kernel/matmul/tune/base.rs | 63 +++++--- crates/burn-jit/src/ops/float_ops.rs | 2 +- crates/burn-wgpu/Cargo.toml | 4 +- 18 files changed, 137 insertions(+), 291 deletions(-) delete mode 100644 crates/burn-jit/src/kernel/matmul/simple.rs diff --git a/Cargo.lock b/Cargo.lock index 8897110b75..e7f9b7ea4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -424,6 +424,7 @@ dependencies = [ "github-device-flow", "half", "indicatif", + "log", "os_info", "percent-encoding", "rand", @@ -435,6 +436,7 @@ dependencies = [ "strum", "strum_macros", "sysinfo 0.32.1", + "tracing-subscriber", "wgpu", "wsl", ] @@ -1666,7 +1668,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1698,7 +1700,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1715,7 +1717,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1733,7 +1735,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1747,7 +1749,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1763,7 +1765,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1789,7 +1791,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-core", @@ -1800,7 +1802,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1815,7 +1817,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1852,7 +1854,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "async-channel", "async-lock", @@ -1873,7 +1875,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1887,7 +1889,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index ca45f967bb..1a22a867a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 1d1a62cf49..39ddd0c6f5 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -54,6 +54,8 @@ strum_macros = { workspace = true } sysinfo = { workspace = true, features = ["serde"] } wgpu = { workspace = true } wsl = { workspace = true } +tracing-subscriber = { workspace = true } +log = { workspace = true } [dev-dependencies] rstest = { workspace = true } diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index d31e7cc954..0e9f3622b1 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -44,16 +44,20 @@ fn bench( url: Option<&str>, token: Option<&str>, ) { - let benchmarks = [(2, 4096, 4096, 4096), (8, 2048, 2048, 2048)] - .into_iter() - .map(|(b, m, n, k)| { - let shape_lhs = [b, m, k].into(); - let shape_rhs = [b, k, n].into(); - - MatmulBenchmark::::new(shape_lhs, shape_rhs, device.clone()) - }) - .map(run_benchmark) - .collect(); + let benchmarks = [ + (3, 4096, 4096, 4096), + (8, 2048, 2048, 2048), + (2, 4096, 4096, 512), + ] + .into_iter() + .map(|(b, m, n, k)| { + let shape_lhs = [b, m, k].into(); + let shape_rhs = [b, k, n].into(); + + MatmulBenchmark::::new(shape_lhs, shape_rhs, device.clone()) + }) + .map(run_benchmark) + .collect(); save::(benchmarks, device, feature_name, url, token).unwrap(); } diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index a1eb8f16a1..03e2d70444 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -1,3 +1,7 @@ +use std::error::Error; + +use tracing_subscriber::filter::LevelFilter; + pub mod burnbenchapp; pub mod persistence; @@ -26,10 +30,33 @@ pub fn get_sharing_url(args: &[String]) -> Option<&str> { get_argument(args, "--sharing-url") } +pub fn init_log() -> Result<(), Box> { + let result = tracing_subscriber::fmt() + .with_max_level(LevelFilter::DEBUG) + .without_time() + .try_init(); + + if result.is_ok() { + update_panic_hook(); + } + result +} + +fn update_panic_hook() { + let hook = std::panic::take_hook(); + + std::panic::set_hook(Box::new(move |info| { + log::error!("PANIC => {}", info.to_string()); + hook(info); + })); +} + #[macro_export] macro_rules! bench_on_backend { () => { use std::env; + backend_comparison::init_log().unwrap(); + let args: Vec = env::args().collect(); let url = backend_comparison::get_sharing_url(&args); let token = backend_comparison::get_sharing_token(&args); diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index ce5b22b8ac..0811374fd1 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -16,13 +16,13 @@ autotune = [] default = ["autotune", "std", "fusion", "cubecl/default"] doc = ["default"] export_tests = [ - "burn-tensor-testgen", - "serial_test", - "burn-autodiff/export_tests", - "burn-tensor/export_tests", - "burn-ndarray", - "fusion", - "paste", + "burn-tensor-testgen", + "serial_test", + "burn-autodiff/export_tests", + "burn-tensor/export_tests", + "burn-ndarray", + "fusion", + "paste", ] fusion = ["burn-fusion"] std = ["cubecl/std"] @@ -32,8 +32,8 @@ template = [] burn-common = { path = "../burn-common", version = "0.16.0" } burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ - "cubecl", - "repr", + "cubecl", + "repr", ] } cubecl = { workspace = true, features = ["linalg"] } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 0659561805..0d9c48dc30 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -145,7 +145,7 @@ fn execute( let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]); let input = reshape(input, input_shape); - let columns = matmul::(weight, input, MatmulStrategy::default()); + let columns = matmul::(weight, input, None, MatmulStrategy::default()); let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); col2im::( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs index ca7399e72d..582b1e59af 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs @@ -3,7 +3,7 @@ use cubecl::{ components::{ global::{ self, - homogeneous::{self, CyclicLoading, RhsLoader}, + full_load::{self, CyclicLoading, RhsLoader}, unloader::Unloader, AccumulatorLoader, Config as _, Loader, }, @@ -93,9 +93,11 @@ where for _ in 0..num_loops { sync_units(); - let lhs_stage_reader = &Self::LhsLoader::fill_stage(&mut lhs_loader, config); - let rhs_stage_reader = - &Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config()); + Self::LhsLoader::fill_stage(&mut lhs_loader, config); + Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config()); + + let lhs_stage_reader = &Self::LhsLoader::as_stage_reader(&lhs_loader); + let rhs_stage_reader = &Self::RhsLoader::as_stage_reader(&rhs_loader); sync_units(); @@ -172,7 +174,7 @@ where Acc: Numeric, SMM: stage::Matmul, { - type Config = config::Config>; + type Config = config::Config>; fn check_config(config: Self::Config) { SMM::check_config(config.to_smm_config()); @@ -198,7 +200,7 @@ where ); config::Config::new( - homogeneous::Config::new( + full_load::Config::new( smm_config, problem.m as u32 % SMM::M != 0, problem.n as u32 % SMM::N != 0, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs index 0a1ed8728c..11ee03d83e 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs @@ -25,19 +25,22 @@ pub struct SimpleIm2colLoader { impl Loader for SimpleIm2colLoader { type StageReader = LhsReader; - fn fill_stage(this: &mut Self, #[comptime] config: G) -> Self::StageReader { + fn fill_stage(this: &mut Self, #[comptime] config: G) { SimpleIm2col::load_to_slice::( &this.tensor_view, &mut this.stage.as_slice_mut(), Ident::Lhs, config, ); - LhsReader::new(this.stage) } fn advance_view(this: &mut Self, k_offset: u32) { this.tensor_view.update_view(k_offset); } + + fn as_stage_reader(this: &Self) -> Self::StageReader { + LhsReader::new(this.stage) + } } #[cube] diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 3a7b23df76..a65c29466c 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -2,13 +2,13 @@ use burn_tensor::{ ops::{conv::calculate_conv_output_size, ConvOptions}, Shape, }; -use cubecl::{calculate_cube_count_elemwise, linalg::matmul, prelude::*}; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ kernel::{ conv::index, into_contiguous, launch_binop, - matmul::{cube_strategy, matmul, MatmulStrategy}, + matmul::{matmul, MatmulStrategy}, AddOp, }, ops::{numeric::empty_device, reshape, swap_dims}, @@ -271,7 +271,7 @@ fn execute_1x1_kernel( let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp])); let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]); let input = reshape(input, in_shape); - let out = matmul::(weight, input, MatmulStrategy::default()); + let out = matmul::(weight, input, None, MatmulStrategy::default()); let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width])); if let Some(bias) = bias { @@ -290,7 +290,6 @@ fn execute( out_h: usize, out_w: usize, ) { - let client = input.client.clone(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -302,11 +301,5 @@ fn execute( let columns = reshape(columns, Shape::new([groups, col_shape_0, col_shape_1])); let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); - matmul::launch_ref::( - &cube_strategy::(&client), - &client, - &weight.as_handle_ref(), - &columns.as_handle_ref(), - &out.as_handle_ref(), - ); + matmul::(weight, columns, Some(out), Default::default()); } diff --git a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs index 438850fe72..b22821aef1 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs @@ -298,7 +298,7 @@ pub(crate) fn deform_conv2d( let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0])); let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); - let out = matmul::(weight, columns, MatmulStrategy::default()); + let out = matmul::(weight, columns, None, MatmulStrategy::default()); let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); let out = swap_dims(out, 0, 1); diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index 163e4796e4..b75ac43182 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -108,7 +108,7 @@ fn compute_weight_grad( let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); let columns = swap_dims(columns, 1, 2); - let grad_weight = matmul::(out_grad, columns, MatmulStrategy::default()); + let grad_weight = matmul::(out_grad, columns, None, MatmulStrategy::default()); reshape( grad_weight, @@ -150,7 +150,7 @@ fn backward_gradient_inputs( for group in 0..groups { let weight = swap_dims(index::(weight.clone(), group), 0, 1); let out_grad = index::(out_grad.clone(), group); - let values = matmul::(weight, out_grad, MatmulStrategy::default()); + let values = matmul::(weight, out_grad, None, MatmulStrategy::default()); let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1])); columns = slice_assign::( columns, diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 60c7cbbd1c..e0b87e8931 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,25 +1,11 @@ -use super::{init_matmul_output, matmul_simple}; +use super::init_matmul_output; use crate::{tensor::JitTensor, FloatElement, JitRuntime}; -use burn_tensor::Shape; -use cubecl::{ - ir::{Elem, FloatKind}, - linalg::matmul::Strategy, - prelude::*, - Feature, -}; #[cfg(feature = "autotune")] use super::matmul_autotune; /// The strategy to be used when launching a matmul kernel. pub enum MatmulStrategy { - /// A simple kernel will be used with memory coalescing optimization. - Simple { - /// Number of invocations in x - grid_x: usize, - /// Number of invocations in y - grid_y: usize, - }, #[cfg(feature = "autotune")] /// Using autotune to choose the best kernel based on runtime information. Autotune, @@ -42,21 +28,17 @@ impl Default for MatmulStrategy { pub fn matmul( lhs: JitTensor, rhs: JitTensor, + out: Option>, strategy: MatmulStrategy, ) -> JitTensor { match strategy { - MatmulStrategy::Simple { grid_x, grid_y } => { - let out = init_matmul_output::(&lhs, &rhs); - - matmul_simple::(lhs, rhs, out, grid_x, grid_y) - } MatmulStrategy::Cube => { - let out = init_matmul_output::(&lhs, &rhs); + let out = out.unwrap_or_else(|| init_matmul_output::(&lhs, &rhs)); let client = &lhs.client; cubecl::linalg::matmul::launch_ref::( - &cube_strategy::(client), + &Default::default(), client, &lhs.as_handle_ref(), &rhs.as_handle_ref(), @@ -65,47 +47,6 @@ pub fn matmul( out } #[cfg(feature = "autotune")] - MatmulStrategy::Autotune => matmul_autotune::(lhs, rhs), - } -} - -pub(crate) fn cube_strategy( - client: &ComputeClient, -) -> Strategy { - // TODO: Replace with auto option once cubecl has one - let cmma_available = client.properties().feature_enabled(Feature::Cmma { - a: Elem::Float(FloatKind::F16), - b: Elem::Float(FloatKind::F16), - c: Elem::Float(FloatKind::F32), - m: 16, - k: 16, - n: 16, - }); - let plane_available = client.properties().feature_enabled(Feature::Plane); - match (cmma_available, plane_available) { - (true, _) => Strategy::Accelerated, - (false, true) => Strategy::PlaneMma, - _ => Strategy::Tiling2D(Default::default()), + MatmulStrategy::Autotune => matmul_autotune::(lhs, rhs, out), } } - -pub(crate) fn simple_cube_count( - lhs_shape: &Shape, - rhs_shape: &Shape, - output_shape: &Shape, - cube_dim_x: usize, - cube_dim_y: usize, -) -> CubeCount { - let ndims = lhs_shape.num_dims(); - let num_rows = lhs_shape.dims[ndims - 2]; - let num_cols = rhs_shape.dims[ndims - 1]; - - let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32; - let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32; - let mut num_iter = 1; - for i in 0..ndims - 2 { - num_iter *= output_shape.dims[i]; - } - - CubeCount::Static(cubes_x, cubes_y, num_iter as u32) -} diff --git a/crates/burn-jit/src/kernel/matmul/mod.rs b/crates/burn-jit/src/kernel/matmul/mod.rs index 633743564b..80fa8ed82c 100644 --- a/crates/burn-jit/src/kernel/matmul/mod.rs +++ b/crates/burn-jit/src/kernel/matmul/mod.rs @@ -1,11 +1,9 @@ mod base; -mod simple; mod tune; /// Contains utilitary for matmul operation pub mod utils; pub use base::*; -pub use simple::*; pub use tune::*; pub use utils::*; diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs deleted file mode 100644 index 7d75b30395..0000000000 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ /dev/null @@ -1,143 +0,0 @@ -//! Naive matmul kernel implementation -//! -//! Each local unit will compute a single element of the output matrix. -use crate::{ - kernel::{into_contiguous, PLANE_DIM_APPROX}, - ops::swap_dims, - tensor::JitTensor, - FloatElement, JitRuntime, -}; - -use super::simple_cube_count; -use cubecl::prelude::*; - -#[cube(launch_unchecked)] -fn matmul_kernel( - lhs: &Tensor, - rhs: &Tensor, - out: &mut Tensor, - // number of dimensions not involved in the matmul - #[comptime] num_batches: Option, -) { - let rank = out.rank(); - let end = num_batches.unwrap_or_else(|| rank - 2); - let unroll = num_batches.is_some(); - - let n_rows = lhs.shape(rank - 2); - let n_cols = rhs.shape(rank - 1); - let mut k = rhs.shape(rank - 2); - - let batch_pos = ABSOLUTE_POS_Z; - let row = CUBE_DIM_X * CUBE_POS_X + UNIT_POS_X; - let col = CUBE_DIM_Y * CUBE_POS_Y + UNIT_POS_Y; - - if row >= n_rows || col >= n_cols { - return; - } - - let vectorization_factor = vectorization_of(lhs); - - let mut offset_lhs = 0; - let mut offset_rhs = 0; - let offset_out = n_rows * n_cols * batch_pos; - - #[unroll(unroll)] - for i in 0..end { - let ogwl = offset_out / out.stride(i); - - offset_lhs += ogwl % lhs.shape(i) * lhs.stride(i); - offset_rhs += ogwl % rhs.shape(i) * rhs.stride(i); - } - - offset_lhs /= vectorization_factor; - offset_rhs /= vectorization_factor; - - let mut sum = F::vectorized(0., vectorization_factor); - - k /= vectorization_factor; - - for i in 0..k { - let lhs_index = row * k + i + offset_lhs; - let rhs_index = col * k + i + offset_rhs; - - sum += lhs[lhs_index] * rhs[rhs_index]; - } - - let mut out_index = row * n_cols + col; - out_index += offset_out; - - let unroll_sum = vectorization_factor != 1; - if unroll_sum { - let mut accum = F::new(0.); - // we unroll the loop to sum `vectorization_factor` elements at once, which lets us - // use SIMD instructions to speed up the computation - #[unroll] - for v in 0..vectorization_factor { - accum += sum[v]; - } - - out[out_index] = accum; - } else { - out[out_index] = sum; - } -} - -/// Matrix multiplication using memory coalescing algorithm with cube dimensions of size 16 -pub fn matmul_mem_coalescing_default( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, -) -> JitTensor { - matmul_simple::(lhs, rhs, out, PLANE_DIM_APPROX, PLANE_DIM_APPROX) -} - -/// Matrix multiplication using memory coalescing algorithm with custom cube dimensions -pub fn matmul_simple( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, - cube_dim_x: usize, - cube_dim_y: usize, -) -> JitTensor { - lhs.assert_is_on_same_device(&rhs); - let ndims = lhs.shape.num_dims(); - let lhs = into_contiguous(lhs); - - let rhs_original_shape = rhs.shape.clone(); - // we swap the dimensions to achieve memory-coalescing: - // consecutive elements of a column in the original rhs tensor will now be stored - // consecutively in memory, which allows to fetch them with fewer memory instructions - let rhs = into_contiguous(swap_dims(rhs, ndims - 1, ndims - 2)); - - let cube_count = simple_cube_count( - &lhs.shape, - &rhs_original_shape, - &out.shape, - cube_dim_x, - cube_dim_y, - ); - - let vectorization_factor = match lhs.shape.dims[ndims - 1] % 4 == 0 { - true => 4, - false => 1, - }; - - unsafe { - matmul_kernel::launch_unchecked::( - &lhs.client, - cube_count, - CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1), - lhs.as_tensor_arg::(vectorization_factor), - TensorArg::from_raw_parts::( - &rhs.handle, - &rhs.strides, - &rhs_original_shape.dims, // We need the original shape. - vectorization_factor, - ), - out.as_tensor_arg::(1), - Some(ndims as u32 - 2), - ); - }; - - out -} diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index e49ea3154c..2b8050dc07 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -1,14 +1,14 @@ use core::marker::PhantomData; use burn_tensor::{Element, ElementConversion}; -use cubecl::tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTuner}; +use cubecl::{ + linalg::matmul::{kernels::tiling2d::Tiling2dConfig, Strategy}, + tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTuner}, +}; use crate::{ element::FloatElement, - kernel::{ - matmul::{cube_strategy, utils::init_matmul_output}, - prng::random_like_uniform, - }, + kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, ops::numeric::empty_device, tensor::JitTensor, tune_key::JitAutotuneKey, @@ -57,17 +57,17 @@ impl AutotuneOperationSet ); vec![ - Box::new(SimpleMatmul::::new( + Box::new(MatmulTiling2d::::new( lhs.clone(), rhs.clone(), out.clone(), )), - Box::new(SimpleMatmul16x16::::new( + Box::new(MatmulAccelerated::::new( lhs.clone(), rhs.clone(), out.clone(), )), - Box::new(MatmulCube::::new( + Box::new(MatmulSimple::::new( lhs.clone(), rhs.clone(), out.clone(), @@ -77,9 +77,9 @@ impl AutotuneOperationSet fn fastest(self: Box, fastest_index: usize) -> Box { match fastest_index { - 0 => Box::new(SimpleMatmul::::new(self.lhs, self.rhs, self.out)), - 1 => Box::new(SimpleMatmul16x16::::new(self.lhs, self.rhs, self.out)), - 2 => Box::new(MatmulCube::::new(self.lhs, self.rhs, self.out)), + 0 => Box::new(MatmulTiling2d::::new(self.lhs, self.rhs, self.out)), + 1 => Box::new(MatmulAccelerated::::new(self.lhs, self.rhs, self.out)), + 2 => Box::new(MatmulSimple::::new(self.lhs, self.rhs, self.out)), _ => panic!("Fastest index is out of bound"), } } @@ -89,8 +89,9 @@ impl AutotuneOperationSet pub fn matmul_autotune( lhs: JitTensor, rhs: JitTensor, + out: Option>, ) -> JitTensor { - let output = init_matmul_output::(&lhs, &rhs); + let output = out.unwrap_or_else(|| init_matmul_output::(&lhs, &rhs)); let client = lhs.client.clone(); @@ -137,24 +138,40 @@ macro_rules! matmul_tune_ops { }; } -// Potentially better for small matrices. +// Probably the fastest in the general case. matmul_tune_ops!( - SimpleMatmul, - crate::kernel::matmul::matmul_mem_coalescing_default:: + MatmulAccelerated, + |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { + cubecl::linalg::matmul::launch_ref::( + &Strategy::Accelerated, + &lhs.client, + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), + ); + } ); -// Potentially better for small matrices. -matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { - crate::kernel::matmul::matmul_simple::(lhs, rhs, out, 16, 16) -}); +// Probably the fastest when tensor cores are not available. +matmul_tune_ops!( + MatmulTiling2d, + |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { + cubecl::linalg::matmul::launch_ref::( + &Strategy::Tiling2D(Tiling2dConfig::default()), + &lhs.client, + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), + ); + } +); -// Probably the fastest in the general case, without loop unrolling +// Probably the fastest for small matrices. matmul_tune_ops!( - MatmulCube, + MatmulSimple, |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { - let strategy = cube_strategy::(&lhs.client); cubecl::linalg::matmul::launch_ref::( - &strategy, + &Strategy::Simple, &lhs.client, &lhs.as_handle_ref(), &rhs.as_handle_ref(), diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index f97b1609ff..2dc8a4a6f2 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -162,7 +162,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - matmul::(lhs, rhs, MatmulStrategy::default()) + matmul::(lhs, rhs, None, MatmulStrategy::default()) ) } diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index ea629af3f5..d3975faad3 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -27,13 +27,13 @@ cubecl = { workspace = true, features = ["wgpu"] } burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ - "cubecl-wgpu", + "cubecl-wgpu", ] } [dev-dependencies] burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ - "export_tests", + "export_tests", ] } half = { workspace = true } paste = { workspace = true }