Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Jan 19, 2025
1 parent d7dedde commit 1e12738
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
32 changes: 27 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
#cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" }
#cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" }
### For local development. ###
cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
# cubecl = { version = "0.4.0", default-features = false }
# cubecl-common = { version = "0.4.0", default-features = false }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use burn_tensor::{ops::ConvTransposeOptions, ElementConversion, Shape};
use cubecl::tune::{local_tuner, tune_with, LocalTuner, TunableSet};
use cubecl::tune::{local_tuner, LocalTuner, TunableSet};

use crate::{
kernel::{
Expand Down Expand Up @@ -64,7 +64,7 @@ pub fn create_transpose2d_input<R: JitRuntime, E: FloatElement>(
let bias = key
.has_bias
.then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1));
tune_with!(input, weights, bias, options.clone())
(input, weights, bias, options.clone())
}

fn create_key<R: JitRuntime, E: FloatElement>(
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/matmul/tune/base.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use burn_tensor::{Element, ElementConversion};
use cubecl::{
linalg::matmul::{kernels::tiling2d::Tiling2dConfig, Strategy},
tune::{local_tuner, tune_with, LocalTuner, TunableSet},
tune::{local_tuner, LocalTuner, TunableSet},
};

use crate::{
Expand All @@ -27,7 +27,7 @@ fn matmul_input_gen<R: JitRuntime, E: FloatElement>(

let out = empty_device::<R, E>(out.client.clone(), out.device.clone(), out.shape.clone());

tune_with!(lhs, rhs, out)
(lhs, rhs, out)
}

/// Executes autotune on matmul operations
Expand Down

0 comments on commit 1e12738

Please sign in to comment.