
Commit

Migrate to type magic autotune
wingertge committed Jan 16, 2025
1 parent 3990a8a commit 5ab59ce
Showing 11 changed files with 111 additions and 388 deletions.
52 changes: 26 additions & 26 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" }
+#cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" }
+#cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" }
### For local development. ###
-# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
-# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
+cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
+cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
# cubecl = { version = "0.3.0", default-features = false }
# cubecl-common = { version = "0.3.0", default-features = false }
31 changes: 1 addition & 30 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs
@@ -5,15 +5,14 @@ use cubecl::{
tile::{accelerated::Accelerated, TileMatmulFamily},
InvalidConfigError,
},
-        kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
+        kernels::matmul::AdvancedConfig,
},
prelude::*,
};

use super::{
base::{ConvolutionConfigFactory, ConvolutionFamily, ConvolutionProblem},
homogeneous::base::ImplicitGemmConvolutionFamily,
-    precision::ConvPrecision,
selection::ConvSelection,
};

@@ -47,34 +46,6 @@ pub trait Algorithm {
Self::GlobalConvolution::check_config(&config)?;
Ok(config)
}

-    /// Check availability of the matmul algorithm
-    fn check_availability<R: Runtime, CS: ConvPrecision>(
-        client: &ComputeClient<R::Server, R::Channel>,
-        config: &<Self::GlobalConvolution as ConvolutionConfigFactory>::Config,
-    ) -> Result<(), MatmulAvailabilityError> {
-        Self::GlobalConvolution::check_availability::<R, CS>(client, config)
-    }
-
-    /// Determine whether the given convolution problem is valid to launch (within hardware limits)
-    fn can_launch<R: Runtime, CS: ConvPrecision>(
-        client: &ComputeClient<R::Server, R::Channel>,
-        problem: &ConvolutionProblem,
-        config: &<Self::GlobalConvolution as ConvolutionConfigFactory>::Config,
-        selection: &Self::Selection,
-    ) -> bool {
-        if problem.options.groups > 1 || Self::check_availability::<R, CS>(client, config).is_err()
-        {
-            return false;
-        }
-
-        let cube_count = Self::cube_count(selection, problem);
-        let (max_x, max_y, max_z) = R::max_cube_count();
-        match cube_count {
-            CubeCount::Static(x, y, z) => x <= max_x && y <= max_y && z <= max_z,
-            _ => true,
-        }
-    }
}

/// Cmma convolution
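Note on the removed `check_availability`/`can_launch` methods above: the deleted guard combined a kernel-availability check with a hardware cube-count limit. Below is a minimal standalone sketch of that launchability test, using hypothetical plain types (`CubeCount`, `HardwareLimits`) in place of the cubecl generics and compute client; it illustrates the removed logic and is not the library API.

```rust
/// Hypothetical stand-ins for the cubecl types used by the removed guard.
#[derive(Clone, Copy)]
enum CubeCount {
    Static(u32, u32, u32),
    Dynamic,
}

struct HardwareLimits {
    max_x: u32,
    max_y: u32,
    max_z: u32,
}

/// Sketch of the removed `can_launch` logic: reject grouped convolutions and
/// unavailable kernels, then check a static cube count against device limits.
fn can_launch(
    groups: usize,
    kernel_available: bool,
    cube_count: CubeCount,
    limits: &HardwareLimits,
) -> bool {
    if groups > 1 || !kernel_available {
        return false;
    }
    match cube_count {
        CubeCount::Static(x, y, z) => x <= limits.max_x && y <= limits.max_y && z <= limits.max_z,
        // Dynamically resolved cube counts cannot be validated ahead of time.
        CubeCount::Dynamic => true,
    }
}

fn main() {
    // Example limits; real values come from the runtime's device properties.
    let limits = HardwareLimits { max_x: 65_535, max_y: 65_535, max_z: 65_535 };
    let ok = can_launch(1, true, CubeCount::Static(128, 1, 1), &limits);
    println!("launchable: {ok}");
}
```

With the migration to the type-based autotune, these rejections presumably happen during tuner selection instead, which is why the `Algorithm` trait no longer carries them.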
8 changes: 1 addition & 7 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs
@@ -6,7 +6,7 @@ use cubecl::linalg::{
stage::{StageMatmul, StageMatmulFamily},
InvalidConfigError, MatmulProblem, MatrixLayout,
},
-        kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
+        kernels::matmul::AdvancedConfig,
},
tensor::{ReadWrite, VirtualTensor},
};
@@ -91,12 +91,6 @@ pub trait ConvolutionConfigFactory: Send + Sync + 'static {
/// Asserts that the configuration for this matmul will lead to a valid computation
fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError>;

-    /// Checks if the client can handle the features used in this computation
-    fn check_availability<R: Runtime, CS: ConvPrecision>(
-        client: &ComputeClient<R::Server, R::Channel>,
-        config: &Self::Config,
-    ) -> Result<(), MatmulAvailabilityError>;

fn make_config(
input: Self::Input,
problem: &ConvolutionProblem,
9 changes: 1 addition & 8 deletions
@@ -16,7 +16,7 @@ use cubecl::{
},
Ident, InvalidConfigError, MatrixLayout, StageDim,
},
-        kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
+        kernels::matmul::AdvancedConfig,
},
tensor::{ReadWrite, VirtualTensor},
},
@@ -194,13 +194,6 @@
SMM::check_config(&config.to_smm_config())
}

-    fn check_availability<R: Runtime, CS: ConvPrecision>(
-        client: &ComputeClient<R::Server, R::Channel>,
-        config: &Self::Config,
-    ) -> Result<(), MatmulAvailabilityError> {
-        SMM::check_availability::<R, (CS::EG, CS::ES, CS::EA)>(client, &config.to_smm_config())
-    }

fn make_config(
input: Self::Input,
problem: &ConvolutionProblem,
56 changes: 2 additions & 54 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs
@@ -7,7 +7,7 @@ use burn_tensor::{
use cubecl::{
flex32,
ir::{Elem, FloatKind},
-    linalg::matmul::{self, components::MatrixLayout},
+    linalg::matmul::{self},
tensor_line_size, tf32, Feature,
};
use half::{bf16, f16};
@@ -23,7 +23,7 @@ use crate::{
algorithm::{Algorithm, ImplicitCmmaConv},
base::{ConvolutionLaunch, ConvolutionProblem},
},
-        nchw_to_nhwc, Conv2dAutotuneKey, ConvLaunchError,
+        nchw_to_nhwc, ConvLaunchError,
},
into_contiguous,
},
@@ -226,58 +226,6 @@
Ok(permute(out, &[0, 3, 1, 2]))
}

-pub fn problem_from_key<R: JitRuntime, F: FloatElement>(
-    key: &Conv2dAutotuneKey,
-    out_h: usize,
-    out_w: usize,
-) -> ConvolutionProblem {
-    let in_stride_2 = key.in_channels;
-    let in_stride_1 = key.width * in_stride_2;
-    let in_stride_0 = key.height * in_stride_1;
-
-    let m = key.batch_size * out_h * out_w;
-    let n = key.out_channels;
-    let k = key.kernel_size[0] * key.kernel_size[1] * key.in_channels;
-
-    let options = ConvOptions {
-        stride: key.stride,
-        padding: key.padding,
-        dilation: key.dilation,
-        groups: key.groups,
-    };
-
-    // Target 128 bit accesses
-    let available_vectorizations = R::supported_line_sizes()
-        .iter()
-        .copied()
-        .filter(|it| *it as usize * size_of::<F>() <= 16)
-        .collect::<Vec<_>>();
-    let lhs_line_size = tensor_line_size(
-        &available_vectorizations,
-        &[key.batch_size, key.height, key.width, key.in_channels],
-        &[in_stride_0, in_stride_1, in_stride_2, 1],
-        3,
-    );
-    let rhs_line_size = tensor_line_size(&available_vectorizations, &[k, n], &[n, 1], 1);
-    let out_line_size = tensor_line_size(&available_vectorizations, &[m, n], &[n, 1], 1);
-
-    ConvolutionProblem {
-        m,
-        n,
-        k,
-        lhs_layout: MatrixLayout::RowMajor,
-        rhs_layout: MatrixLayout::RowMajor,
-        lhs_line_size,
-        rhs_line_size,
-        out_line_size,
-        kernel_size: (key.kernel_size[0] as u32, key.kernel_size[1] as u32),
-        options,
-        out_shape_y: out_h,
-        out_shape_x: out_w,
-        has_bias: key.has_bias,
-    }
-}

pub(crate) fn has_tf32<R: JitRuntime>(c: &JitTensor<R>) -> bool {
c.client
.properties()
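Note on the removed `problem_from_key` helper above: it derived the implicit-GEMM problem from the autotune key. Its core arithmetic — NHWC strides, the GEMM dimensions m/n/k, and the line-size filter targeting 128-bit accesses — is sketched below with illustrative types; `ConvShape` and the free functions are stand-ins, not burn or cubecl APIs.

```rust
/// Illustrative mirror of the fields the removed helper read from Conv2dAutotuneKey.
struct ConvShape {
    batch_size: usize,
    height: usize,
    width: usize,
    in_channels: usize,
    out_channels: usize,
    kernel_size: [usize; 2],
}

/// NHWC strides as derived in the removed helper (innermost dimension is contiguous).
fn nhwc_strides(shape: &ConvShape) -> [usize; 4] {
    let s2 = shape.in_channels;
    let s1 = shape.width * s2;
    let s0 = shape.height * s1;
    [s0, s1, s2, 1]
}

/// GEMM dimensions of the implicit-GEMM convolution: every output pixel becomes a
/// row (m), every output channel a column (n), and the reduction axis (k) spans the
/// kernel window times the input channels.
fn gemm_dims(shape: &ConvShape, out_h: usize, out_w: usize) -> (usize, usize, usize) {
    let m = shape.batch_size * out_h * out_w;
    let n = shape.out_channels;
    let k = shape.kernel_size[0] * shape.kernel_size[1] * shape.in_channels;
    (m, n, k)
}

/// Line-size selection targeting 128-bit (16-byte) accesses: keep only the
/// hardware-supported widths whose total byte size stays within 16 bytes.
fn available_line_sizes(supported: &[u8], elem_bytes: usize) -> Vec<u8> {
    supported
        .iter()
        .copied()
        .filter(|w| *w as usize * elem_bytes <= 16)
        .collect()
}

fn main() {
    // Example shape: 3x3 kernel with stride 1 and padding 1 keeps 64x64 output.
    let shape = ConvShape {
        batch_size: 8,
        height: 64,
        width: 64,
        in_channels: 32,
        out_channels: 64,
        kernel_size: [3, 3],
    };
    let strides = nhwc_strides(&shape);
    let (m, n, k) = gemm_dims(&shape, 64, 64);
    // For f32 (4 bytes), widths up to 4 keep accesses at or under 128 bits.
    let lines = available_line_sizes(&[1, 2, 4, 8], 4);
    println!("strides={strides:?}, m={m}, n={n}, k={k}, line sizes={lines:?}");
}
```

The real helper additionally fed the strides and supported widths through cubecl's `tensor_line_size` to pick per-tensor line sizes; the sketch stops at the filtering step.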
