Skip to content

Commit

Permalink
Update cubecl (#2680)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Jan 12, 2025
1 parent 5b3079a commit 51b742f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
24 changes: 12 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8244dbb4660e373ff1ffb780feb73a5b899e5977" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8244dbb4660e373ff1ffb780feb73a5b899e5977" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
8 changes: 4 additions & 4 deletions crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ fn deform_col2img_kernel<F: Float>(
offset: &Tensor<F>,
mask: &Tensor<F>,
columns: &Tensor<F>,
grad_input: &mut Tensor<AtomicU32>,
grad_input: &mut Tensor<Atomic<u32>>,
args: &DeformConv2dCol2ImgArgs,
#[comptime] use_mask: bool,
) {
Expand Down Expand Up @@ -589,14 +589,14 @@ fn deform_col2img_kernel<F: Float>(
}

#[cube]
fn float_atomic_add(ptr: &mut AtomicU32, value: f32) {
fn float_atomic_add(ptr: &mut Atomic<u32>, value: f32) {
if value != 0.0 {
let mut v = AtomicU32::load(ptr);
let mut v = Atomic::<u32>::load(ptr);
loop {
let prev = v;
let v_float = f32::bitcast_from(v);
let new = u32::bitcast_from(v_float + value);
v = AtomicU32::compare_and_swap(ptr, v, new);
v = Atomic::<u32>::compare_and_swap(ptr, v, new);
if prev == v {
break;
}
Expand Down

0 comments on commit 51b742f

Please sign in to comment.