Migrate/jit/prod (#1474)

louisfd authored Mar 15, 2024
1 parent cfc0a4d commit 41d01b8
Showing 10 changed files with 166 additions and 105 deletions.
6 changes: 5 additions & 1 deletion crates/burn-jit/src/kernel/prng/uniform.rs
@@ -51,10 +51,14 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
         gpu!(scope, int_random = int_random ^ state_3);
 
         let float_random = scope.create_local(Elem::Float);
+        let float_scale = scope.create_local(Elem::Float);
         cast_uint_to_float(scope, int_random, float_random);
+        gpu!(scope, float_scale = cast(scale));
 
+        let uniform_float = scope.create_local(Elem::Float);
         let uniform = scope.create_local(item);
-        gpu!(scope, uniform = float_random * scale);
+        gpu!(scope, uniform_float = float_random * float_scale);
+        gpu!(scope, uniform = cast(uniform_float));
         gpu!(scope, uniform += lower_bound);
 
         let write_index = scope.create_local(Elem::UInt);
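The rewritten PRNG body keeps the scaling step in float space: scale is first cast to Elem::Float, the product is formed in float, and only then cast back to the output item before the lower bound is added, presumably so the kernel stays correct when the element type is not f32. A minimal sketch of the per-element math (the function and its bounds are mine, not the kernel's; treating cast_uint_to_float as a map of the 32-bit state into [0, 1) is an assumption of this sketch):

fn uniform_sample<E>(int_random: u32, scale: E, lower_bound: E) -> E
where
    E: Into<f32> + From<f32> + std::ops::AddAssign,
{
    let float_random = int_random as f32 / u32::MAX as f32; // stand-in for cast_uint_to_float
    let float_scale: f32 = scale.into();                    // gpu!(float_scale = cast(scale))
    let uniform_float = float_random * float_scale;         // gpu!(uniform_float = float_random * float_scale)
    let mut uniform = E::from(uniform_float);               // gpu!(uniform = cast(uniform_float))
    uniform += lower_bound;                                  // gpu!(uniform += lower_bound)
    uniform
}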
3 changes: 2 additions & 1 deletion crates/burn-jit/src/kernel/reduce/base.rs
@@ -7,7 +7,7 @@ use crate::{
     Runtime,
 };
 
-use super::{reduce_dim_naive, reduce_dim_shared, ArgMax, ArgMin, MeanDim, SumDim};
+use super::{reduce_dim_naive, reduce_dim_shared, ArgMax, ArgMin, MeanDim, ProdDim, SumDim};
 
 /// Specifies the reduce dim algorithm in use
 pub trait ReduceDimAlgorithm<E: JitElement>: Send + Sync + 'static {
@@ -145,5 +145,6 @@ macro_rules! reduce_operation {
 // Autotunable reduce operation variants
 reduce_operation!(sum_dim, SumDim);
 reduce_operation!(mean_dim, MeanDim);
+reduce_operation!(prod_dim, ProdDim);
 reduce_operation!(argmin, ArgMin);
 reduce_operation!(argmax, ArgMax);
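ReduceDimAlgorithm is the seam that lets the two generic reduction shaders (reduce_dim_naive and reduce_dim_shared) serve every reduction, and the reduce_operation! registration above is what exposes prod_dim next to sum_dim, mean_dim, argmin and argmax. As a rough CPU-side analogue of the naive contract (a sketch of the same idea, not the GPU trait itself; all names below are mine):

/// CPU-side analogue of the naive part of ReduceDimAlgorithm (illustrative;
/// the real trait emits GPU instructions into a Scope rather than running on slices).
trait ReduceDimCpu {
    /// Neutral element of the reduction (0 for sum, 1 for prod, ...).
    fn initialize() -> f32;
    /// Combine one more value into the accumulator.
    fn inner_loop(acc: &mut f32, value: f32);
}

struct ProdDimCpu;

impl ReduceDimCpu for ProdDimCpu {
    fn initialize() -> f32 {
        1.0
    }
    fn inner_loop(acc: &mut f32, value: f32) {
        *acc *= value;
    }
}

/// One generic "shader" reduces a row for any algorithm implementing the trait.
fn reduce_row_naive<A: ReduceDimCpu>(row: &[f32]) -> f32 {
    let mut acc = A::initialize();
    for &value in row {
        A::inner_loop(&mut acc, value);
    }
    acc
}

Calling reduce_row_naive::<ProdDimCpu>(&[2.0, 3.0, 4.0]) yields 24.0; the GPU version expresses the same two hooks as gpu! instructions instead of running on a slice.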
4 changes: 4 additions & 0 deletions crates/burn-jit/src/kernel/reduce/mod.rs
@@ -3,6 +3,8 @@ mod argmin_dim;
 mod base;
 mod mean_dim;
 mod naive_reduce_shader;
+mod prod;
+mod prod_dim;
 mod shared_reduce_shader;
 mod sum;
 mod sum_dim;
@@ -13,6 +15,8 @@ pub(crate) use argmin_dim::*;
 pub use base::*;
 pub(crate) use mean_dim::*;
 pub use naive_reduce_shader::*;
+pub use prod::*;
+pub(crate) use prod_dim::*;
 pub use shared_reduce_shader::*;
 pub use sum::*;
 pub(crate) use sum_dim::*;
14 changes: 14 additions & 0 deletions crates/burn-jit/src/kernel/reduce/prod.rs
@@ -0,0 +1,14 @@
+use crate::{element::JitElement, tensor::JitTensor, Runtime};
+use burn_tensor::Shape;
+
+use super::{prod_dim, ReduceStrategy};
+
+/// Multiply all elements in the input buffer.
+pub fn prod<R: Runtime, E: JitElement, const D: usize>(
+    input: JitTensor<R, E, D>,
+    strategy: ReduceStrategy,
+) -> JitTensor<R, E, 1> {
+    let shape = Shape::new([input.shape.num_elements()]);
+    let input: JitTensor<R, E, 1> = JitTensor::new(input.client, input.device, shape, input.handle);
+    prod_dim(input, 0, strategy)
+}
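prod reuses prod_dim by reinterpreting the buffer as a rank-1 tensor (same handle, flat shape) and reducing its only dimension. A CPU analogue of the idea, assuming nothing beyond a flat slice (prod_all is my name, not part of the crate):

fn prod_all(values: &[f32]) -> f32 {
    // The flat view makes the full product a single-dimension reduction,
    // folded from the neutral element 1.
    values.iter().copied().fold(1.0, |acc, v| acc * v)
}

fn main() {
    // Same data as the burn-tensor aggregation test further down: 2 * 1 * 2 * 3 * 4 * 5 = 240.
    assert_eq!(prod_all(&[2.0, 1.0, 2.0, 3.0, 4.0, 5.0]), 240.0);
}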
88 changes: 88 additions & 0 deletions crates/burn-jit/src/kernel/reduce/prod_dim.rs
@@ -0,0 +1,88 @@
+use crate::{
+    codegen::dialect::gpu::{gpu, Item, Scope, Variable},
+    JitElement,
+};
+
+use super::ReduceDimAlgorithm;
+
+pub(crate) struct ProdDim;
+
+impl<E: JitElement> ReduceDimAlgorithm<E> for ProdDim {
+    type Accumulator = Variable;
+
+    fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
+        scope.create_with_value(1, output_item)
+    }
+
+    fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
+        gpu!(scope, accumulator *= value);
+    }
+
+    fn assign_naive(
+        scope: &mut Scope,
+        output: Variable,
+        accumulator: Variable,
+        _shape_reduce_dim: Variable,
+    ) {
+        let id = Variable::Id;
+        gpu!(scope, output[id] = accumulator);
+    }
+
+    fn initialize_shared(
+        scope: &mut Scope,
+        shared_memory_size: u32,
+        write_position: Variable,
+        input_item: Item,
+    ) -> Self::Accumulator {
+        let shared_memory = scope.create_shared(input_item, shared_memory_size);
+        let neutral_element = scope.create_with_value(1, shared_memory.item());
+        gpu!(scope, shared_memory[write_position] = neutral_element);
+        shared_memory
+    }
+
+    fn write_to_shared(
+        scope: &mut Scope,
+        shared_memory: Self::Accumulator,
+        write_position: Variable,
+        value: Self::Accumulator,
+    ) {
+        let current_value = scope.create_local(value.item());
+        let computed = scope.create_local(value.item());
+        gpu!(scope, current_value = shared_memory[write_position]);
+        gpu!(scope, computed = current_value * value);
+        gpu!(scope, shared_memory[write_position] = computed);
+    }
+
+    fn read_from_input(
+        scope: &mut Scope,
+        input: Variable,
+        read_position: Variable,
+        _i: Variable,
+    ) -> Self::Accumulator {
+        let value = scope.create_local(input.item());
+        gpu!(scope, value = input[read_position]);
+        value
+    }
+
+    fn read_from_shared(
+        scope: &mut Scope,
+        shared_memory: Self::Accumulator,
+        read_position: Variable,
+    ) -> Self::Accumulator {
+        let read_value = scope.create_local(shared_memory.item());
+        gpu!(scope, read_value = shared_memory[read_position]);
+        read_value
+    }
+
+    fn assign_shared(
+        scope: &mut Scope,
+        shared_memory: Self::Accumulator,
+        output: Variable,
+        write_position: Variable,
+        _shape_reduce_dim: Variable,
+    ) {
+        let final_value = scope.create_local(output.item());
+        gpu!(scope, final_value = shared_memory[0]);
+        gpu!(scope, output[write_position] = final_value);
+    }
+}
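Taken together, the shared-memory hooks describe a per-workgroup product: every slot starts at the neutral element 1, invocations multiply input values into their slot, and the value left in slot 0 is written out. A sequential model of how the generic shared-memory shader might use these hooks (the function, its power-of-two slot count, and the strided assignment of values to slots are assumptions of this sketch; the real shader runs the steps in parallel with barriers):

fn prod_dim_workgroup(values: &[f32], n_slots: usize) -> f32 {
    assert!(n_slots.is_power_of_two());

    // initialize_shared: every slot starts at the neutral element 1.
    let mut shared = vec![1.0f32; n_slots];

    // read_from_input + write_to_shared: each invocation folds a strided
    // share of the reduced dimension into its own slot.
    for (i, &value) in values.iter().enumerate() {
        shared[i % n_slots] *= value;
    }

    // read_from_shared + write_to_shared: halve the active slots until the
    // whole product sits in slot 0.
    let mut stride = n_slots / 2;
    while stride > 0 {
        for i in 0..stride {
            shared[i] *= shared[i + stride];
        }
        stride /= 2;
    }

    // assign_shared: shared_memory[0] is what gets written to the output.
    shared[0]
}

With values = [2.0, 3.0, 4.0, 5.0] and four slots this returns 120.0, the same result as a plain fold.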
11 changes: 11 additions & 0 deletions crates/burn-jit/src/ops/float_ops.rs
@@ -310,6 +310,17 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
         reduce::mean_dim(tensor, dim, Default::default())
     }
 
+    fn float_prod<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
+        reduce::prod(tensor, Default::default())
+    }
+
+    fn float_prod_dim<const D: usize>(
+        tensor: FloatTensor<Self, D>,
+        dim: usize,
+    ) -> FloatTensor<Self, D> {
+        reduce::prod_dim(tensor, dim, Default::default())
+    }
+
     fn float_to_full_precision<const D: usize>(
         tensor: &FloatTensor<Self, D>,
     ) -> FloatTensor<FullPrecisionBackend<Self>, D> {
13 changes: 4 additions & 9 deletions crates/burn-jit/src/ops/int_ops.rs
@@ -255,17 +255,12 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
         kernel::reduce::sum_dim(tensor, dim, Default::default())
     }
 
-    fn int_prod<const D: usize>(_tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
-        // kernel::reduce::prod(tensor, Default::default())
-        todo!("prod for int tensor")
+    fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
+        kernel::reduce::prod(tensor, Default::default())
     }
 
-    fn int_prod_dim<const D: usize>(
-        _tensor: IntTensor<Self, D>,
-        _dim: usize,
-    ) -> IntTensor<Self, D> {
-        // kernel::reduce::prod_dim(tensor, dim, Default::default())
-        todo!("prod_dim for int tensor")
+    fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
+        kernel::reduce::prod_dim(tensor, dim, Default::default())
     }
 
     fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
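With both the float and int ops now routed to the kernels, the user-facing Tensor::prod and Tensor::prod_dim calls reach real GPU code on the JIT backend instead of hitting todo!. A hedged usage sketch (the backend B and the from_floats constructor signature are assumptions that may differ between burn versions):

use burn_tensor::{backend::Backend, Tensor};

fn products<B: Backend>(device: &B::Device) {
    let x = Tensor::<B, 2>::from_floats([[2.0, 3.0], [4.0, 5.0]], device);

    // Full product of all elements -> rank-1 tensor holding 120.0
    // (float_prod -> reduce::prod).
    let all = x.clone().prod();

    // Product along dim 1 -> shape [2, 1] holding [[6.0], [20.0]]
    // (float_prod_dim -> reduce::prod_dim).
    let per_row = x.prod_dim(1);

    let _ = (all, per_row);
}

The same two calls on a Tensor<B, 2, Int> now reach kernel::reduce::prod and prod_dim as well, since int_prod and int_prod_dim no longer panic.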

This file was deleted.

38 changes: 37 additions & 1 deletion crates/burn-jit/src/tests/reduce.rs
@@ -1,7 +1,9 @@
 #[burn_tensor_testgen::testgen(reduction)]
 mod reduction {
     use super::*;
-    use burn_jit::kernel::reduce::{argmax, argmin, mean_dim, sum, sum_dim, ReduceStrategy};
+    use burn_jit::kernel::reduce::{
+        argmax, argmin, mean_dim, prod, prod_dim, sum, sum_dim, ReduceStrategy,
+    };
     use burn_tensor::{ops::IntTensorOps, Data, Distribution, Int, Shape, Tensor};
 
     #[test]
@@ -22,6 +24,24 @@ mod reduction {
         val_ref.into_data().assert_approx_eq(&val.into_data(), 2);
     }
 
+    #[test]
+    fn reduction_prod_dim_should_work_with_multiple_invocations() {
+        let tensor =
+            Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default, &Default::default());
+        let tensor_ref =
+            Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data(), &Default::default());
+        let reduce_dim = 1;
+
+        let val = Tensor::<TestBackend, 2>::from_primitive(prod_dim::<TestRuntime, f32, f32, 2>(
+            tensor.into_primitive(),
+            reduce_dim,
+            ReduceStrategy::Naive,
+        ));
+        let val_ref = tensor_ref.prod_dim(1);
+
+        val_ref.into_data().assert_approx_eq(&val.into_data(), 2);
+    }
+
     #[test]
     fn reduction_argmin_dim_should_work_with_multiple_invocations() {
         let tensor =
@@ -237,6 +257,22 @@ mod reduction {
         val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
     }
 
+    #[test]
+    fn reduction_prod_should_work_with_multiple_invocations() {
+        let tensor =
+            Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default, &Default::default());
+        let tensor_ref =
+            Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data(), &Default::default());
+
+        let val = Tensor::<TestBackend, 1>::from_primitive(prod(
+            tensor.into_primitive(),
+            ReduceStrategy::default(),
+        ));
+        let val_ref = tensor_ref.prod();
+
+        val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
+    }
+
     #[test]
     fn reduction_argmax_shared_memory_extreme_values_float() {
         let data: Data<f32, 1> = Data::from([-999999., -999997., -999998.]);
2 changes: 1 addition & 1 deletion crates/burn-tensor/src/tests/ops/aggregation.rs
@@ -129,7 +129,7 @@ mod tests {
         let data_actual = tensor.prod().to_data();
 
         // 2 * 1 * 2 * 3 * 4 * 5 = 240 but we need to check the precision because of the float
-        Data::from([240.0]).assert_approx_eq(&data_actual, 4);
+        Data::from([240.0]).assert_approx_eq(&data_actual, 3);
 
         let tensor_with_zero = TestTensor::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);
         let data_actual = tensor_with_zero.prod().to_data();
