-
Notifications
You must be signed in to change notification settings - Fork 478
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
166 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
use crate::{element::JitElement, tensor::JitTensor, Runtime}; | ||
use burn_tensor::Shape; | ||
|
||
use super::{prod_dim, ReduceStrategy}; | ||
|
||
/// Multiply all elements in the input buffer. | ||
pub fn prod<R: Runtime, E: JitElement, const D: usize>( | ||
input: JitTensor<R, E, D>, | ||
strategy: ReduceStrategy, | ||
) -> JitTensor<R, E, 1> { | ||
let shape = Shape::new([input.shape.num_elements()]); | ||
let input: JitTensor<R, E, 1> = JitTensor::new(input.client, input.device, shape, input.handle); | ||
prod_dim(input, 0, strategy) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
use crate::{ | ||
codegen::dialect::gpu::{gpu, Item, Scope, Variable}, | ||
JitElement, | ||
}; | ||
|
||
use super::ReduceDimAlgorithm; | ||
|
||
pub(crate) struct ProdDim; | ||
|
||
impl<E: JitElement> ReduceDimAlgorithm<E> for ProdDim { | ||
type Accumulator = Variable; | ||
|
||
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable { | ||
scope.create_with_value(1, output_item) | ||
} | ||
|
||
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) { | ||
gpu!(scope, accumulator *= value); | ||
} | ||
|
||
fn assign_naive( | ||
scope: &mut Scope, | ||
output: Variable, | ||
accumulator: Variable, | ||
_shape_reduce_dim: Variable, | ||
) { | ||
let id = Variable::Id; | ||
gpu!(scope, output[id] = accumulator); | ||
} | ||
|
||
fn initialize_shared( | ||
scope: &mut Scope, | ||
shared_memory_size: u32, | ||
write_position: Variable, | ||
input_item: Item, | ||
) -> Self::Accumulator { | ||
let shared_memory = scope.create_shared(input_item, shared_memory_size); | ||
let neutral_element = scope.create_with_value(1, shared_memory.item()); | ||
gpu!(scope, shared_memory[write_position] = neutral_element); | ||
shared_memory | ||
} | ||
|
||
fn write_to_shared( | ||
scope: &mut Scope, | ||
shared_memory: Self::Accumulator, | ||
write_position: Variable, | ||
value: Self::Accumulator, | ||
) { | ||
let current_value = scope.create_local(value.item()); | ||
let computed = scope.create_local(value.item()); | ||
gpu!(scope, current_value = shared_memory[write_position]); | ||
gpu!(scope, computed = current_value * value); | ||
gpu!(scope, shared_memory[write_position] = computed); | ||
} | ||
|
||
fn read_from_input( | ||
scope: &mut Scope, | ||
input: Variable, | ||
read_position: Variable, | ||
_i: Variable, | ||
) -> Self::Accumulator { | ||
let value = scope.create_local(input.item()); | ||
gpu!(scope, value = input[read_position]); | ||
value | ||
} | ||
|
||
fn read_from_shared( | ||
scope: &mut Scope, | ||
shared_memory: Self::Accumulator, | ||
read_position: Variable, | ||
) -> Self::Accumulator { | ||
let read_value = scope.create_local(shared_memory.item()); | ||
gpu!(scope, read_value = shared_memory[read_position]); | ||
read_value | ||
} | ||
|
||
fn assign_shared( | ||
scope: &mut Scope, | ||
shared_memory: Self::Accumulator, | ||
output: Variable, | ||
write_position: Variable, | ||
_shape_reduce_dim: Variable, | ||
) { | ||
let final_value = scope.create_local(output.item()); | ||
gpu!(scope, final_value = shared_memory[0]); | ||
gpu!(scope, output[write_position] = final_value); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 0 additions & 92 deletions
92
crates/burn-jit/src/template/reduction/reduce_dim_shared_memory.wgsl
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters