diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index c0129dac17..f80c1d5343 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -4,7 +4,28 @@ use syn::{PathArguments, Stmt}; use crate::VariableKey; -pub const KEYWORDS: [&str; 1] = ["ABSOLUTE_INDEX"]; +pub const KEYWORDS: [&str; 20] = [ + "ABSOLUTE_POS", + "ABSOLUTE_POS_X", + "ABSOLUTE_POS_Y", + "ABSOLUTE_POS_Z", + "UNIT_POS", + "UNIT_POS_X", + "UNIT_POS_Y", + "UNIT_POS_Z", + "CUBE_POS", + "CUBE_POS_X", + "CUBE_POS_Y", + "CUBE_POS_Z", + "CUBE_DIM", + "CUBE_DIM_X", + "CUBE_DIM_Y", + "CUBE_DIM_Z", + "CUBE_COUNT", + "CUBE_COUNT_X", + "CUBE_COUNT_Y", + "CUBE_COUNT_Z", +]; #[derive(Debug)] /// Information about a single variable's use in Cube code diff --git a/crates/burn-cube-macros/src/codegen/branch.rs b/crates/burn-cube-macros/src/codegen/branch.rs index 950c05e6e7..5533f60900 100644 --- a/crates/burn-cube-macros/src/codegen/branch.rs +++ b/crates/burn-cube-macros/src/codegen/branch.rs @@ -30,19 +30,33 @@ pub(crate) fn codegen_for_loop( }; if &func_name.to_string() == "range" { - let mut args = quote::quote! { - context, - }; - - for argument in call.args.iter() { - let arg = codegen_expr(argument, loop_level, variable_analyses); - args.extend(quote::quote! { #arg, }); - } + let mut args = call.args.clone(); + + let unroll = codegen_expr( + &args.pop().unwrap().into_value(), + loop_level, + variable_analyses, + ); + let end = codegen_expr( + &args.pop().unwrap().into_value(), + loop_level, + variable_analyses, + ); + let start = codegen_expr( + &args.pop().unwrap().into_value(), + loop_level, + variable_analyses, + ); let block = codegen_block(&for_loop.body, loop_level + 1, variable_analyses); quote::quote! { - burn_cube::branch::range_expand(#args |context, #i| #block); + { + let _start = #start; + let _end = #end; + let _unroll = #unroll; + burn_cube::branch::range_expand(context, _start, _end, _unroll, |context, #i| #block); + } } } else { todo!("Codegen: Only range is supported") diff --git a/crates/burn-cube/src/codegen/dialect/variable.rs b/crates/burn-cube/src/codegen/dialect/variable.rs index 5b748b974b..a22401907e 100644 --- a/crates/burn-cube/src/codegen/dialect/variable.rs +++ b/crates/burn-cube/src/codegen/dialect/variable.rs @@ -17,6 +17,7 @@ pub enum Variable { LocalInvocationIdX, LocalInvocationIdY, LocalInvocationIdZ, + WorkgroupId, WorkgroupIdX, WorkgroupIdY, WorkgroupIdZ, @@ -24,9 +25,11 @@ pub enum Variable { GlobalInvocationIdY, GlobalInvocationIdZ, Rank, + WorkgroupSize, WorkgroupSizeX, WorkgroupSizeY, WorkgroupSizeZ, + NumWorkgroups, NumWorkgroupsX, NumWorkgroupsY, NumWorkgroupsZ, @@ -61,6 +64,9 @@ impl Variable { Variable::NumWorkgroupsX => None, Variable::NumWorkgroupsY => None, Variable::NumWorkgroupsZ => None, + Variable::WorkgroupId => None, + Variable::NumWorkgroups => None, + Variable::WorkgroupSize => None, } } @@ -93,6 +99,9 @@ impl Variable { Variable::NumWorkgroupsX => Item::new(Elem::UInt), Variable::NumWorkgroupsY => Item::new(Elem::UInt), Variable::NumWorkgroupsZ => Item::new(Elem::UInt), + Variable::WorkgroupId => Item::new(Elem::UInt), + Variable::NumWorkgroups => Item::new(Elem::UInt), + Variable::WorkgroupSize => Item::new(Elem::UInt), } } } diff --git a/crates/burn-cube/src/codegen/dialect/vectorization.rs b/crates/burn-cube/src/codegen/dialect/vectorization.rs index 5aa00aac03..a7b3c9291f 100644 --- a/crates/burn-cube/src/codegen/dialect/vectorization.rs +++ b/crates/burn-cube/src/codegen/dialect/vectorization.rs @@ -151,6 +151,9 @@ impl Variable { Variable::NumWorkgroupsX => *self, Variable::NumWorkgroupsY => *self, Variable::NumWorkgroupsZ => *self, + Variable::WorkgroupId => *self, + Variable::NumWorkgroups => *self, + Variable::WorkgroupSize => *self, } } } diff --git a/crates/burn-cube/src/language/element/uint.rs b/crates/burn-cube/src/language/element/uint.rs index 4f8b09a926..8593d11bbb 100644 --- a/crates/burn-cube/src/language/element/uint.rs +++ b/crates/burn-cube/src/language/element/uint.rs @@ -1,6 +1,6 @@ use crate::dialect::{Elem, Variable, Vectorization}; use crate::language::{CubeContext, CubeElem, CubeType, ExpandElement, Numeric}; -use crate::{ArgSettings, KernelLauncher, LaunchArg, Runtime}; +use crate::{ArgSettings, Comptime, KernelLauncher, LaunchArg, Runtime}; #[derive(Clone, Copy, Debug)] /// An unsigned int. @@ -68,6 +68,12 @@ impl From for UInt { } } +impl From> for UInt { + fn from(value: Comptime) -> Self { + UInt::new(value.inner) + } +} + impl From for UInt { fn from(value: usize) -> Self { UInt::new(value as u32) diff --git a/crates/burn-cube/src/language/mod.rs b/crates/burn-cube/src/language/mod.rs index 8efdef1feb..0b653ec53a 100644 --- a/crates/burn-cube/src/language/mod.rs +++ b/crates/burn-cube/src/language/mod.rs @@ -6,6 +6,7 @@ mod context; mod element; mod indexation; mod operation; +pub mod synchronization; mod topology; pub use comptime::*; diff --git a/crates/burn-cube/src/language/synchronization.rs b/crates/burn-cube/src/language/synchronization.rs new file mode 100644 index 0000000000..06e5377fda --- /dev/null +++ b/crates/burn-cube/src/language/synchronization.rs @@ -0,0 +1,7 @@ +use crate::{dialect::Synchronization, CubeContext}; + +pub fn sync_units() {} + +pub fn sync_units_expand(context: &mut CubeContext) { + context.register(Synchronization::WorkgroupBarrier) +} diff --git a/crates/burn-cube/src/language/topology.rs b/crates/burn-cube/src/language/topology.rs index 0ea04cfc95..c5b1fd0592 100644 --- a/crates/burn-cube/src/language/topology.rs +++ b/crates/burn-cube/src/language/topology.rs @@ -1,17 +1,181 @@ -use crate::UInt; +//! In this file we use a trick where the constant has the same name as the module containing +//! the expand function, so that a user implicitly imports the expand function when importing the constant. -/// In this file we use a trick where the constant has the same name as the module containing -/// the expand function, so that a user implicitly imports the expand function when importing the constant. +use crate::UInt; -/// The index of the working unit in the whole cube kernel, without regards to blocks. -pub const ABSOLUTE_INDEX: UInt = UInt::new(0u32); +macro_rules! constant { + ($ident:ident, $var:expr, $doc:expr) => { + #[doc = $doc] + pub const $ident: UInt = UInt::new(0u32); -#[allow(non_snake_case)] -pub mod ABSOLUTE_INDEX { - use crate::{CubeContext, ExpandElement}; + #[allow(non_snake_case)] + #[doc = $doc] + pub mod $ident { + use crate::{CubeContext, ExpandElement}; - /// Expanded version of ABSOLUTE_INDEX - pub fn expand(_context: &mut CubeContext) -> ExpandElement { - ExpandElement::Plain(crate::dialect::Variable::Id) - } + /// Expansion of the constant variable. + pub fn expand(_context: &mut CubeContext) -> ExpandElement { + ExpandElement::Plain($var) + } + } + }; } + +constant!( + UNIT_POS, + crate::dialect::Variable::LocalInvocationIndex, + r" +The position of the working unit inside the cube, without regards to axis. +" +); + +constant!( + UNIT_POS_X, + crate::dialect::Variable::LocalInvocationIdX, + r" +The position of the working unit inside the cube along the X axis. +" +); + +constant!( + UNIT_POS_Y, + crate::dialect::Variable::LocalInvocationIdY, + r" +The position of the working unit inside the cube along the Y axis. +" +); + +constant!( + UNIT_POS_Z, + crate::dialect::Variable::LocalInvocationIdZ, + r" +The position of the working unit inside the cube along the Z axis. +" +); + +constant!( + CUBE_DIM, + crate::dialect::Variable::WorkgroupSize, + r" +The total amount of working units in a cube. +" +); + +constant!( + CUBE_DIM_X, + crate::dialect::Variable::WorkgroupSizeX, + r" +The dimension of the cube along the X axis. +" +); + +constant!( + CUBE_DIM_Y, + crate::dialect::Variable::WorkgroupSizeY, + r" +The dimension of the cube along the Y axis. +" +); + +constant!( + CUBE_DIM_Z, + crate::dialect::Variable::WorkgroupSizeZ, + r" +The dimension of the cube along the Z axis. +" +); + +constant!( + CUBE_POS, + crate::dialect::Variable::WorkgroupId, + r" +The cube position, without regards to axis. +" +); + +constant!( + CUBE_POS_X, + crate::dialect::Variable::WorkgroupIdX, + r" +The cube position along the X axis. +" +); + +constant!( + CUBE_POS_Y, + crate::dialect::Variable::WorkgroupIdY, + r" +The cube position along the Y axis. +" +); + +constant!( + CUBE_POS_Z, + crate::dialect::Variable::WorkgroupIdZ, + r" +The cube position along the Z axis. +" +); +constant!( + CUBE_COUNT, + crate::dialect::Variable::NumWorkgroups, + r" +The number of cubes launched. +" +); + +constant!( + CUBE_COUNT_X, + crate::dialect::Variable::NumWorkgroupsX, + r" +The number of cubes launched along the X axis. +" +); + +constant!( + CUBE_COUNT_Y, + crate::dialect::Variable::NumWorkgroupsY, + r" +The number of cubes launched along the Y axis. +" +); + +constant!( + CUBE_COUNT_Z, + crate::dialect::Variable::NumWorkgroupsZ, + r" +The number of cubes launched along the Z axis. +" +); + +constant!( + ABSOLUTE_POS, + crate::dialect::Variable::Id, + r" +The position of the working unit in the whole cube kernel, without regards to cubes and axis. +" +); + +constant!( + ABSOLUTE_POS_X, + crate::dialect::Variable::GlobalInvocationIdX, + r" +The index of the working unit in the whole cube kernel along the X axis, without regards to cubes. +" +); + +constant!( + ABSOLUTE_POS_Y, + crate::dialect::Variable::GlobalInvocationIdY, + r" +The index of the working unit in the whole cube kernel along the Y axis, without regards to cubes. +" +); + +constant!( + ABSOLUTE_POS_Z, + crate::dialect::Variable::GlobalInvocationIdZ, + r" +The index of the working unit in the whole cube kernel along the Z axis, without regards to cubes. +" +); diff --git a/crates/burn-cube/tests/language/topology.rs b/crates/burn-cube/tests/language/topology.rs index 35a5b7c668..c2d9a7ddec 100644 --- a/crates/burn-cube/tests/language/topology.rs +++ b/crates/burn-cube/tests/language/topology.rs @@ -1,8 +1,8 @@ -use burn_cube::{cube, Numeric, Tensor, UInt, ABSOLUTE_INDEX}; +use burn_cube::{cube, Numeric, Tensor, UInt, ABSOLUTE_POS}; #[cube] fn topology_kernel(input: Tensor) { - let x = ABSOLUTE_INDEX + UInt::new(4); + let x = ABSOLUTE_POS + UInt::new(4); let _ = input[x]; } diff --git a/crates/burn-cuda/src/compiler/base.rs b/crates/burn-cuda/src/compiler/base.rs index a9959bb5e9..9b5ffff549 100644 --- a/crates/burn-cuda/src/compiler/base.rs +++ b/crates/burn-cuda/src/compiler/base.rs @@ -368,6 +368,9 @@ impl CudaCompiler { } super::Variable::LocalArray(id, item, depth, size) } + gpu::Variable::WorkgroupId => todo!(), + gpu::Variable::WorkgroupSize => todo!(), + gpu::Variable::NumWorkgroups => todo!(), } } diff --git a/crates/burn-jit/src/kernel/conv/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d.rs index b218d920c5..7345fd0388 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d.rs @@ -31,7 +31,7 @@ fn kernel( kernel_size_0_unroll: Comptime>, kernel_size_1_unroll: Comptime>, ) { - if ABSOLUTE_INDEX >= output.len() { + if ABSOLUTE_POS >= output.len() { return; } @@ -42,10 +42,10 @@ fn kernel( let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3)); let unroll_1 = Comptime::is_some(kernel_size_1_unroll); - let b = ABSOLUTE_INDEX / output.stride(0) % output.shape(0); - let oc = ABSOLUTE_INDEX / output.stride(1) % output.shape(1); - let oh = ABSOLUTE_INDEX / output.stride(2) % output.shape(2); - let ow = ABSOLUTE_INDEX / output.stride(3) % output.shape(3); + let b = ABSOLUTE_POS / output.stride(0) % output.shape(0); + let oc = ABSOLUTE_POS / output.stride(1) % output.shape(1); + let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2); + let ow = ABSOLUTE_POS / output.stride(3) % output.shape(3); let g = (weight.shape(0) + oc) % groups; let ic_start = in_channels * g; @@ -107,7 +107,7 @@ fn kernel( } } - output[ABSOLUTE_INDEX] = sum; + output[ABSOLUTE_POS] = sum; } pub(crate) fn conv2d( diff --git a/crates/burn-jit/src/tests/conv2d.rs b/crates/burn-jit/src/tests/conv2d.rs index 1430b73d8d..91a4c5c643 100644 --- a/crates/burn-jit/src/tests/conv2d.rs +++ b/crates/burn-jit/src/tests/conv2d.rs @@ -12,9 +12,11 @@ mod tests { Tensor::::random([12, 8, 3, 3], Distribution::Default, &test_device); let bias = Tensor::::random([12], Distribution::Default, &test_device); let ref_device = Default::default(); + let input_ref = Tensor::::from_data(input.to_data(), &ref_device); let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); + let options = burn_tensor::ops::ConvOptions::new([2, 3], [2, 3], [2, 3], 2); let output = module::conv2d(input, weight, Some(bias), options.clone()); diff --git a/crates/burn-wgpu/src/compiler/wgsl/base.rs b/crates/burn-wgpu/src/compiler/wgsl/base.rs index 27e6bb2723..d069d2d802 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/base.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/base.rs @@ -25,15 +25,18 @@ pub enum Variable { LocalInvocationIdY, LocalInvocationIdZ, Rank, + WorkgroupId, WorkgroupIdX, WorkgroupIdY, WorkgroupIdZ, GlobalInvocationIdX, GlobalInvocationIdY, GlobalInvocationIdZ, + WorkgroupSize, WorkgroupSizeX, WorkgroupSizeY, WorkgroupSizeZ, + NumWorkgroups, NumWorkgroupsX, NumWorkgroupsY, NumWorkgroupsZ, @@ -98,6 +101,9 @@ impl Variable { Variable::NumWorkgroupsX => true, Variable::NumWorkgroupsY => true, Variable::NumWorkgroupsZ => true, + Variable::WorkgroupId => true, + Variable::WorkgroupSize => true, + Variable::NumWorkgroups => true, } } pub fn index(&self, index: usize) -> IndexedVariable { @@ -131,15 +137,18 @@ impl Variable { elem, scope_depth: _, } => Item::Scalar(*elem), + Self::WorkgroupId => Item::Scalar(Elem::U32), Self::WorkgroupIdX => Item::Scalar(Elem::U32), Self::WorkgroupIdY => Item::Scalar(Elem::U32), Self::WorkgroupIdZ => Item::Scalar(Elem::U32), Self::GlobalInvocationIdX => Item::Scalar(Elem::U32), Self::GlobalInvocationIdY => Item::Scalar(Elem::U32), Self::GlobalInvocationIdZ => Item::Scalar(Elem::U32), + Self::WorkgroupSize => Item::Scalar(Elem::U32), Self::WorkgroupSizeX => Item::Scalar(Elem::U32), Self::WorkgroupSizeY => Item::Scalar(Elem::U32), Self::WorkgroupSizeZ => Item::Scalar(Elem::U32), + Self::NumWorkgroups => Item::Scalar(Elem::U32), Self::NumWorkgroupsX => Item::Scalar(Elem::U32), Self::NumWorkgroupsY => Item::Scalar(Elem::U32), Self::NumWorkgroupsZ => Item::Scalar(Elem::U32), @@ -234,6 +243,7 @@ impl Display for Variable { Variable::LocalInvocationIdY => f.write_str("local_invocation_id.y"), Variable::LocalInvocationIdZ => f.write_str("local_invocation_id.z"), Variable::Rank => f.write_str("rank"), + Variable::WorkgroupId => f.write_str("workgroup_id_no_axis"), Variable::WorkgroupIdX => f.write_str("workgroup_id.x"), Variable::WorkgroupIdY => f.write_str("workgroup_id.y"), Variable::WorkgroupIdZ => f.write_str("workgroup_id.z"), @@ -246,6 +256,8 @@ impl Display for Variable { Variable::NumWorkgroupsX => f.write_str("num_workgroups.x"), Variable::NumWorkgroupsY => f.write_str("num_workgroups.y"), Variable::NumWorkgroupsZ => f.write_str("num_workgroups.z"), + Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"), + Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"), } } } diff --git a/crates/burn-wgpu/src/compiler/wgsl/body.rs b/crates/burn-wgpu/src/compiler/wgsl/body.rs index dfa11638c1..debb734045 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/body.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/body.rs @@ -18,7 +18,7 @@ impl Display for Body { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if self.id { f.write_str( - "let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n", + "let id = (global_id.z * num_workgroups.x * WORKGROUP_SIZE_X * num_workgroups.y * WORKGROUP_SIZE_Y) + (global_id.y * num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n", )?; } if self.rank || self.stride || self.shape { diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index af2a43a994..0b97b02f5a 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -17,6 +17,9 @@ pub struct WgslCompiler { stride: bool, shape: bool, num_workgroups: bool, + workgroup_id_no_axis: bool, + workgroup_size_no_axis: bool, + num_workgroup_no_axis: bool, shared_memories: Vec, local_arrays: Vec, } @@ -81,10 +84,16 @@ impl WgslCompiler { global_invocation_id: self.global_invocation_id || self.id, local_invocation_index: self.local_invocation_index, local_invocation_id: self.local_invocation_id, - num_workgroups: self.id || self.num_workgroups, - workgroup_id: self.workgroup_id, + num_workgroups: self.id + || self.num_workgroups + || self.num_workgroup_no_axis + || self.workgroup_id_no_axis, + workgroup_id: self.workgroup_id || self.workgroup_id_no_axis, body, extensions, + num_workgroups_no_axis: self.num_workgroup_no_axis, + workgroup_id_no_axis: self.workgroup_id_no_axis, + workgroup_size_no_axis: self.workgroup_size_no_axis, } } @@ -219,6 +228,18 @@ impl WgslCompiler { self.num_workgroups = true; wgsl::Variable::NumWorkgroupsZ } + cube::Variable::WorkgroupId => { + self.workgroup_id_no_axis = true; + wgsl::Variable::WorkgroupId + } + cube::Variable::WorkgroupSize => { + self.workgroup_size_no_axis = true; + wgsl::Variable::WorkgroupSize + } + cube::Variable::NumWorkgroups => { + self.num_workgroup_no_axis = true; + wgsl::Variable::NumWorkgroups + } } } diff --git a/crates/burn-wgpu/src/compiler/wgsl/shader.rs b/crates/burn-wgpu/src/compiler/wgsl/shader.rs index 5bd7847b2f..5e8b461cd4 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/shader.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/shader.rs @@ -73,6 +73,9 @@ pub struct ComputeShader { pub local_invocation_id: bool, pub num_workgroups: bool, pub workgroup_id: bool, + pub num_workgroups_no_axis: bool, + pub workgroup_id_no_axis: bool, + pub workgroup_size_no_axis: bool, pub body: Body, pub extensions: Vec, } @@ -146,6 +149,18 @@ fn main( } // Body + if self.workgroup_id_no_axis { + f.write_str("let workgroup_id_no_axis = (num_workgroups.y * num_workgroups.x * workgroup_id.z) + (num_workgroups.x * workgroup_id.y) + workgroup_id.x;\n")?; + } + + if self.workgroup_size_no_axis { + f.write_str("let workgroup_size_no_axis = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y * WORKGROUP_SIZE_Z;\n")?; + } + + if self.num_workgroups_no_axis { + f.write_str("let num_workgroups_no_axis = num_workgroups.x * num_workgroups.y * num_workgroups.z;\n")?; + } + f.write_fmt(format_args!("{}", self.body))?; // Close body