Skip to content

Commit

Permalink
Cube: Topology constants (#1838)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: nathaniel <[email protected]>
  • Loading branch information
louisfd and nathanielsimard authored May 30, 2024
1 parent 0d4374c commit de0b49e
Show file tree
Hide file tree
Showing 16 changed files with 312 additions and 34 deletions.
23 changes: 22 additions & 1 deletion crates/burn-cube-macros/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,28 @@ use syn::{PathArguments, Stmt};

use crate::VariableKey;

pub const KEYWORDS: [&str; 1] = ["ABSOLUTE_INDEX"];
pub const KEYWORDS: [&str; 20] = [
"ABSOLUTE_POS",
"ABSOLUTE_POS_X",
"ABSOLUTE_POS_Y",
"ABSOLUTE_POS_Z",
"UNIT_POS",
"UNIT_POS_X",
"UNIT_POS_Y",
"UNIT_POS_Z",
"CUBE_POS",
"CUBE_POS_X",
"CUBE_POS_Y",
"CUBE_POS_Z",
"CUBE_DIM",
"CUBE_DIM_X",
"CUBE_DIM_Y",
"CUBE_DIM_Z",
"CUBE_COUNT",
"CUBE_COUNT_X",
"CUBE_COUNT_Y",
"CUBE_COUNT_Z",
];

#[derive(Debug)]
/// Information about a single variable's use in Cube code
Expand Down
32 changes: 23 additions & 9 deletions crates/burn-cube-macros/src/codegen/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,33 @@ pub(crate) fn codegen_for_loop(
};

if &func_name.to_string() == "range" {
let mut args = quote::quote! {
context,
};

for argument in call.args.iter() {
let arg = codegen_expr(argument, loop_level, variable_analyses);
args.extend(quote::quote! { #arg, });
}
let mut args = call.args.clone();

let unroll = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_analyses,
);
let end = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_analyses,
);
let start = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_analyses,
);

let block = codegen_block(&for_loop.body, loop_level + 1, variable_analyses);

quote::quote! {
burn_cube::branch::range_expand(#args |context, #i| #block);
{
let _start = #start;
let _end = #end;
let _unroll = #unroll;
burn_cube::branch::range_expand(context, _start, _end, _unroll, |context, #i| #block);
}
}
} else {
todo!("Codegen: Only range is supported")
Expand Down
9 changes: 9 additions & 0 deletions crates/burn-cube/src/codegen/dialect/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@ pub enum Variable {
LocalInvocationIdX,
LocalInvocationIdY,
LocalInvocationIdZ,
WorkgroupId,
WorkgroupIdX,
WorkgroupIdY,
WorkgroupIdZ,
GlobalInvocationIdX,
GlobalInvocationIdY,
GlobalInvocationIdZ,
Rank,
WorkgroupSize,
WorkgroupSizeX,
WorkgroupSizeY,
WorkgroupSizeZ,
NumWorkgroups,
NumWorkgroupsX,
NumWorkgroupsY,
NumWorkgroupsZ,
Expand Down Expand Up @@ -61,6 +64,9 @@ impl Variable {
Variable::NumWorkgroupsX => None,
Variable::NumWorkgroupsY => None,
Variable::NumWorkgroupsZ => None,
Variable::WorkgroupId => None,
Variable::NumWorkgroups => None,
Variable::WorkgroupSize => None,
}
}

Expand Down Expand Up @@ -93,6 +99,9 @@ impl Variable {
Variable::NumWorkgroupsX => Item::new(Elem::UInt),
Variable::NumWorkgroupsY => Item::new(Elem::UInt),
Variable::NumWorkgroupsZ => Item::new(Elem::UInt),
Variable::WorkgroupId => Item::new(Elem::UInt),
Variable::NumWorkgroups => Item::new(Elem::UInt),
Variable::WorkgroupSize => Item::new(Elem::UInt),
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-cube/src/codegen/dialect/vectorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ impl Variable {
Variable::NumWorkgroupsX => *self,
Variable::NumWorkgroupsY => *self,
Variable::NumWorkgroupsZ => *self,
Variable::WorkgroupId => *self,
Variable::NumWorkgroups => *self,
Variable::WorkgroupSize => *self,
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion crates/burn-cube/src/language/element/uint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::dialect::{Elem, Variable, Vectorization};
use crate::language::{CubeContext, CubeElem, CubeType, ExpandElement, Numeric};
use crate::{ArgSettings, KernelLauncher, LaunchArg, Runtime};
use crate::{ArgSettings, Comptime, KernelLauncher, LaunchArg, Runtime};

#[derive(Clone, Copy, Debug)]
/// An unsigned int.
Expand Down Expand Up @@ -68,6 +68,12 @@ impl From<u32> for UInt {
}
}

impl From<Comptime<u32>> for UInt {
fn from(value: Comptime<u32>) -> Self {
UInt::new(value.inner)
}
}

impl From<usize> for UInt {
fn from(value: usize) -> Self {
UInt::new(value as u32)
Expand Down
1 change: 1 addition & 0 deletions crates/burn-cube/src/language/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod context;
mod element;
mod indexation;
mod operation;
pub mod synchronization;
mod topology;

pub use comptime::*;
Expand Down
7 changes: 7 additions & 0 deletions crates/burn-cube/src/language/synchronization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use crate::{dialect::Synchronization, CubeContext};

pub fn sync_units() {}

pub fn sync_units_expand(context: &mut CubeContext) {
context.register(Synchronization::WorkgroupBarrier)
}
188 changes: 176 additions & 12 deletions crates/burn-cube/src/language/topology.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,181 @@
use crate::UInt;
//! In this file we use a trick where the constant has the same name as the module containing
//! the expand function, so that a user implicitly imports the expand function when importing the constant.
/// In this file we use a trick where the constant has the same name as the module containing
/// the expand function, so that a user implicitly imports the expand function when importing the constant.
use crate::UInt;

/// The index of the working unit in the whole cube kernel, without regards to blocks.
pub const ABSOLUTE_INDEX: UInt = UInt::new(0u32);
macro_rules! constant {
($ident:ident, $var:expr, $doc:expr) => {
#[doc = $doc]
pub const $ident: UInt = UInt::new(0u32);

#[allow(non_snake_case)]
pub mod ABSOLUTE_INDEX {
use crate::{CubeContext, ExpandElement};
#[allow(non_snake_case)]
#[doc = $doc]
pub mod $ident {
use crate::{CubeContext, ExpandElement};

/// Expanded version of ABSOLUTE_INDEX
pub fn expand(_context: &mut CubeContext) -> ExpandElement {
ExpandElement::Plain(crate::dialect::Variable::Id)
}
/// Expansion of the constant variable.
pub fn expand(_context: &mut CubeContext) -> ExpandElement {
ExpandElement::Plain($var)
}
}
};
}

constant!(
UNIT_POS,
crate::dialect::Variable::LocalInvocationIndex,
r"
The position of the working unit inside the cube, without regards to axis.
"
);

constant!(
UNIT_POS_X,
crate::dialect::Variable::LocalInvocationIdX,
r"
The position of the working unit inside the cube along the X axis.
"
);

constant!(
UNIT_POS_Y,
crate::dialect::Variable::LocalInvocationIdY,
r"
The position of the working unit inside the cube along the Y axis.
"
);

constant!(
UNIT_POS_Z,
crate::dialect::Variable::LocalInvocationIdZ,
r"
The position of the working unit inside the cube along the Z axis.
"
);

constant!(
CUBE_DIM,
crate::dialect::Variable::WorkgroupSize,
r"
The total amount of working units in a cube.
"
);

constant!(
CUBE_DIM_X,
crate::dialect::Variable::WorkgroupSizeX,
r"
The dimension of the cube along the X axis.
"
);

constant!(
CUBE_DIM_Y,
crate::dialect::Variable::WorkgroupSizeY,
r"
The dimension of the cube along the Y axis.
"
);

constant!(
CUBE_DIM_Z,
crate::dialect::Variable::WorkgroupSizeZ,
r"
The dimension of the cube along the Z axis.
"
);

constant!(
CUBE_POS,
crate::dialect::Variable::WorkgroupId,
r"
The cube position, without regards to axis.
"
);

constant!(
CUBE_POS_X,
crate::dialect::Variable::WorkgroupIdX,
r"
The cube position along the X axis.
"
);

constant!(
CUBE_POS_Y,
crate::dialect::Variable::WorkgroupIdY,
r"
The cube position along the Y axis.
"
);

constant!(
CUBE_POS_Z,
crate::dialect::Variable::WorkgroupIdZ,
r"
The cube position along the Z axis.
"
);
constant!(
CUBE_COUNT,
crate::dialect::Variable::NumWorkgroups,
r"
The number of cubes launched.
"
);

constant!(
CUBE_COUNT_X,
crate::dialect::Variable::NumWorkgroupsX,
r"
The number of cubes launched along the X axis.
"
);

constant!(
CUBE_COUNT_Y,
crate::dialect::Variable::NumWorkgroupsY,
r"
The number of cubes launched along the Y axis.
"
);

constant!(
CUBE_COUNT_Z,
crate::dialect::Variable::NumWorkgroupsZ,
r"
The number of cubes launched along the Z axis.
"
);

constant!(
ABSOLUTE_POS,
crate::dialect::Variable::Id,
r"
The position of the working unit in the whole cube kernel, without regards to cubes and axis.
"
);

constant!(
ABSOLUTE_POS_X,
crate::dialect::Variable::GlobalInvocationIdX,
r"
The index of the working unit in the whole cube kernel along the X axis, without regards to cubes.
"
);

constant!(
ABSOLUTE_POS_Y,
crate::dialect::Variable::GlobalInvocationIdY,
r"
The index of the working unit in the whole cube kernel along the Y axis, without regards to cubes.
"
);

constant!(
ABSOLUTE_POS_Z,
crate::dialect::Variable::GlobalInvocationIdZ,
r"
The index of the working unit in the whole cube kernel along the Z axis, without regards to cubes.
"
);
4 changes: 2 additions & 2 deletions crates/burn-cube/tests/language/topology.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use burn_cube::{cube, Numeric, Tensor, UInt, ABSOLUTE_INDEX};
use burn_cube::{cube, Numeric, Tensor, UInt, ABSOLUTE_POS};

#[cube]
fn topology_kernel<T: Numeric>(input: Tensor<T>) {
let x = ABSOLUTE_INDEX + UInt::new(4);
let x = ABSOLUTE_POS + UInt::new(4);
let _ = input[x];
}

Expand Down
3 changes: 3 additions & 0 deletions crates/burn-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,9 @@ impl CudaCompiler {
}
super::Variable::LocalArray(id, item, depth, size)
}
gpu::Variable::WorkgroupId => todo!(),
gpu::Variable::WorkgroupSize => todo!(),
gpu::Variable::NumWorkgroups => todo!(),
}
}

Expand Down
Loading

0 comments on commit de0b49e

Please sign in to comment.