From 0292967000635d9a5cbfca055dfdef3bb36692da Mon Sep 17 00:00:00 2001 From: AlteredOxide <143226375+alteredoxide@users.noreply.github.com> Date: Wed, 28 Aug 2024 04:51:19 -0700 Subject: [PATCH] Feature/codegen gather indices greater than rank 1 (#2199) * implemented muli-dim index for GatherNode The `NodeCodegen` impl for `GatherNode` now performs gather in complete accordance with the ONNX Gather spec. - a `gather` function was added to the gather.rs file - `gather()` is now called within the codegen instead of `tensor.select()` - a test with two test cases have been added - test axes 0 and 1 - both use 2D index tensors * add gather_onnx to numeric api Added int and float implementations of gather to the burn-tensor numeric api: - named the methods `gather_onnx` to not be confused with the current `gather` - these implementations follow the `Gather` ONNX spec Updated the gather*.py variants and their onnx outputs * modified files didn't end up in last commit * tests passing for onnx gather The implementation of gather for the ONNX `Gather` spec is tentatively complete: - py test models are updated - onnx_tests are modified and passing: `gather`, `gather_scalar`, and `gather_shape` - node/gather tests are passing NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test the actual functionality of gather are likely to be deleted, since they are redundant to the tests in crates/burn-import/onnx-tests/tests/onnx_tests.rs. * inlined onnx gather within codegen * rm gather_onnx from public api; rm unnecessary tests * add comments to gather py models * some codegen changes; formatting to appease run-checks - Some necessary changes and improvements to the codegen inlined code after translating from public api (removed in previous commit). - Changed some formatting that run-checks complained about. * simplify gather codegen; include 1d and 2d onnx tests Modified the `Gather` codegen per requested changes: - combined match statements on index - remove use of `alloc::vec::Vec` - use map -> collect instead of procedural - include a 1d index gather onnx test - remove superflous tests * delete unused gather.onnx --- crates/burn-import/onnx-tests/build.rs | 3 +- .../onnx-tests/tests/gather/gather.onnx | 18 --- .../onnx-tests/tests/gather/gather.py | 47 ------ .../tests/gather/gather_1d_idx.onnx | Bin 0 -> 155 bytes .../onnx-tests/tests/gather/gather_1d_idx.py | 62 ++++++++ .../tests/gather/gather_2d_idx.onnx | Bin 0 -> 163 bytes .../onnx-tests/tests/gather/gather_2d_idx.py | 62 ++++++++ .../tests/gather/gather_scalar.onnx | Bin 181 -> 147 bytes .../onnx-tests/tests/gather/gather_scalar.py | 83 ++++++----- .../onnx-tests/tests/gather/gather_shape.py | 4 +- .../onnx-tests/tests/onnx_tests.rs | 23 ++- crates/burn-import/src/burn/node/gather.rs | 136 ++++++++++++++++-- crates/burn-tensor/src/tensor/api/check.rs | 1 + crates/onnx-ir/src/dim_inference.rs | 4 - 14 files changed, 321 insertions(+), 122 deletions(-) delete mode 100644 crates/burn-import/onnx-tests/tests/gather/gather.onnx delete mode 100644 crates/burn-import/onnx-tests/tests/gather/gather.py create mode 100644 crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.onnx create mode 100644 crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.py create mode 100644 crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.onnx create mode 100644 crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.py diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 2defcbe303..23312bec9c 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -32,7 +32,8 @@ fn main() { .input("tests/exp/exp.onnx") .input("tests/expand/expand.onnx") .input("tests/flatten/flatten.onnx") - .input("tests/gather/gather.onnx") + .input("tests/gather/gather_1d_idx.onnx") + .input("tests/gather/gather_2d_idx.onnx") .input("tests/gather/gather_scalar.onnx") .input("tests/gather/gather_shape.onnx") .input("tests/gather_elements/gather_elements.onnx") diff --git a/crates/burn-import/onnx-tests/tests/gather/gather.onnx b/crates/burn-import/onnx-tests/tests/gather/gather.onnx deleted file mode 100644 index 9589d8410e..0000000000 --- a/crates/burn-import/onnx-tests/tests/gather/gather.onnx +++ /dev/null @@ -1,18 +0,0 @@ -pytorch2.1.1:¤ -A -onnx::Gather_0 -onnx::Gather_12/Gather"Gather* -axis  -main_graphZ -onnx::Gather_0 -  - -Z -onnx::Gather_1 - - -b -2 -  - -B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/gather/gather.py b/crates/burn-import/onnx-tests/tests/gather/gather.py deleted file mode 100644 index 39688d34d6..0000000000 --- a/crates/burn-import/onnx-tests/tests/gather/gather.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 - -# used to generate model: onnx-tests/tests/gather/gather.onnx - -import torch -import torch.nn as nn - - -class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - - def forward(self, x, index): - gathered = torch.index_select(x, 1, index) - return gathered - - -def main(): - # Set random seed for reproducibility - torch.manual_seed(0) - - # Export to onnx - model = Model() - model.eval() - device = torch.device("cpu") - onnx_name = "gather.onnx" - - dummy_input = torch.randn(2, 3, device=device) - dummy_index = torch.tensor([0, 2], device=device, dtype=torch.int64) - - torch.onnx.export(model, (dummy_input, dummy_index), onnx_name, - verbose=False, opset_version=16) - - print("Finished exporting model to {}".format(onnx_name)) - - # Output some test data for use in the test - test_input = torch.tensor([[1.0, 2.0, 3.0], - [4.0, 5.0, 6.0]]) - test_index = torch.tensor([0, 2], dtype=torch.int64) - - print("Test input data: {}, {}".format(test_input, test_index)) - output = model.forward(test_input, test_index) - print("Test output data: {}".format(output)) - - -if __name__ == '__main__': - main() diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.onnx b/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.onnx new file mode 100644 index 0000000000000000000000000000000000000000..97b0ddefe9ac339013d5310530838e57624898aa GIT binary patch literal 155 zcmd-OF#lr?E3DBB^jwjN^B5Xi<^rju_Cirf^h*O zlMq*KVrE`^dQoCQMwA515FtJ;9u7t!4lX7RCLm^x5`i1b#l^wFF2n+oNs{D(S%Rj4 L$%%!FK|lZio0A|U literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.py b/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.py new file mode 100644 index 0000000000..b4e4a3bd1e --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/gather/gather.onnx + +# There is no current support for `Split`, and the `for` loop over the indices +# results in a `Split` node in the ONNX model. +# Therefore, this model is built and exported using ONNX directly. + +import onnx + + +def build_model(): + return onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 16)], + graph=onnx.helper.make_graph(name="main_graph", nodes=[ + onnx.helper.make_node( + "Gather", + inputs=["input1", "input2"], + outputs=["output1"], + name="/Gather", + axis=1 + ), + ], + inputs=[ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 3] + ), + ), + onnx.helper.make_value_info( + name="input2", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT64, shape=[2] + ), + ), + + ], + outputs=[ + onnx.helper.make_value_info( + name="output1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 2] + ), + ) + ]), + ) + + +def main(): + onnx_model = build_model() + file_name = "gather_1d_idx.onnx" + + # Ensure valid ONNX: + onnx.checker.check_model(onnx_model) + + onnx.save(onnx_model, file_name) + + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.onnx b/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ff64b029db1e83573e87de419016e1efd1cf6496 GIT binary patch literal 163 zcmd-OF#lr?E3DBB^jwjN^B5Xi<^rju_Cirf?)w8 zlMq*KVrE`^dQoCQMwA515FtJ;9u7t!4lX7RCLm^p%NT)W*x@ovNwQopi-d%@1b~`& PpqfA$nVeX-7z6|Wmm?r| literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.py b/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.py new file mode 100644 index 0000000000..767168eefb --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/gather/gather.onnx + +# There is no current support for `Split`, and the `for` loop over the indices +# results in a `Split` node in the ONNX model. +# Therefore, this model is built and exported using ONNX directly. + +import onnx + + +def build_model(): + return onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 16)], + graph=onnx.helper.make_graph(name="main_graph", nodes=[ + onnx.helper.make_node( + "Gather", + inputs=["input1", "input2"], + outputs=["output1"], + name="/Gather", + axis=0 + ), + ], + inputs=[ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 3] + ), + ), + onnx.helper.make_value_info( + name="input2", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT64, shape=[2, 2] + ), + ), + + ], + outputs=[ + onnx.helper.make_value_info( + name="output1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 2, 2] + ), + ) + ]), + ) + + +def main(): + onnx_model = build_model() + file_name = "gather_2d_idx.onnx" + + # Ensure valid ONNX: + onnx.checker.check_model(onnx_model) + + onnx.save(onnx_model, file_name) + + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_scalar.onnx b/crates/burn-import/onnx-tests/tests/gather/gather_scalar.onnx index 7afca04a85657476627ccbfe5a7514a29271cc9c..a8a5586c026f2d73c5c80e87aa3d1d5111483961 100644 GIT binary patch delta 113 zcmdnWIGIs_gTtzWk;|NmEi = gather::Model::default(); + fn gather_1d_idx() { + let model: gather_1d_idx::Model = gather_1d_idx::Model::default(); let device = Default::default(); let input = Tensor::::from_floats([[1., 2., 3.], [4., 5., 6.]], &device); let index = Tensor::::from_ints([0, 2], &device); - let output = model.forward(input, index); let expected = TensorData::from([[1f32, 3.], [4., 6.]]); + let output = model.forward(input, index); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn gather_2d_idx() { + let model: gather_2d_idx::Model = gather_2d_idx::Model::default(); + + let device = Default::default(); + + let input = Tensor::::from_data([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], &device); + let index = Tensor::::from_data([[0, 1], [1, 2]], &device); + let expected = TensorData::from([[[1f32, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]]); + let output = model.forward(input, index); assert_eq!(output.to_data(), expected); } diff --git a/crates/burn-import/src/burn/node/gather.rs b/crates/burn-import/src/burn/node/gather.rs index b79757518c..a9b66e1d0c 100644 --- a/crates/burn-import/src/burn/node/gather.rs +++ b/crates/burn-import/src/burn/node/gather.rs @@ -27,6 +27,12 @@ impl NodeCodegen for GatherNode { node_position: usize, ) -> proc_macro2::TokenStream { let dim = self.dim.to_tokens(); + let input_rank = match &self.input { + Type::Tensor(in_tensor) => in_tensor.dim, + Type::Shape(_) => 1, + _ => panic!("Gather needs Tensor or Shape input, got {:?}!", self.input), + }; + let input = match &self.input { Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position), Type::Shape(in_shape) => { @@ -34,10 +40,11 @@ impl NodeCodegen for GatherNode { // To copy just the values from the shape value without moving it // (which could lead to ownership problems if the same Shape is used multiple times) // borrow the array as a slice and use that to create the Tensor: - quote! { Tensor::from_data(&#in_shape_name as &[_], &*self.device) } + quote! { Tensor::::from_data(&#in_shape_name as &[_], &*self.device) } } _ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input), }; + let output = &self.output.name; match &self.index { @@ -46,14 +53,41 @@ impl NodeCodegen for GatherNode { // convert the 0-D index to a 1-D Tensor with len 1 to use burn's select, // then squeeze the dimension to reduce the rank let index = &idx_scalar.name; + let output_rank = input_rank - 1; quote! { - let #output = #input.select(#dim, Tensor::from_data([#index], &*self.device)).squeeze(#dim); + let indices = Tensor::::from_data([#index], &*self.device); + let slice = Tensor::select(#input, #dim, indices); + let #output = slice.squeeze::<#output_rank>(#dim); } } Type::Tensor(idx_tensor) => { let index = scope.tensor_use_owned(idx_tensor, node_position); - quote! { - let #output = #input.select(#dim, #index); + let index_rank = idx_tensor.dim; + let output_rank = index_rank + input_rank - 1; + match index_rank { + 1 => quote! { + let indices = #index; + let #output = Tensor::select(#input, #dim, indices); + }, + _ => quote! { + let indices = #index; + + let n_dims = indices.dims().len(); + let index_flat = match n_dims { + 1 => indices.reshape([1, -1]), + n if n >= 2 => indices.flatten::<2>(0, n - 2), + _ => panic!("Number of dimensions must be greater than 0"), + }; + + let out = index_flat + .iter_dim(0) + .map(|idxs| { + let idxs = idxs.squeeze::<1>(0); + Tensor::select(#input.clone(), #dim, idxs) + }) + .collect(); + let #output = Tensor::stack::<#output_rank>(out, #dim); + }, } } _ => panic!("Gather needs Scalar or Tensor index, got {:?}!", self.index), @@ -78,7 +112,7 @@ mod tests { }; #[test] - fn test_codegen_gather() { + fn test_codegen_gather_1d_idx() { let mut graph = BurnGraph::::default(); graph.register(GatherNode::new( @@ -121,8 +155,77 @@ mod tests { tensor1: Tensor, tensor2: Tensor ) -> Tensor { - let tensor3 = tensor1.select(0, tensor2); + let indices = tensor2; + let tensor3 = Tensor::select(tensor1, 0, indices); + tensor3 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_gather_2d_idx() { + let mut graph = BurnGraph::::default(); + + graph.register(GatherNode::new( + Type::Tensor(TensorType::new_float("tensor1", 2)), + Type::Tensor(TensorType::new_int("tensor2", 2)), + TensorType::new_float("tensor3", 3), + 0, + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + tensor1: Tensor, + tensor2: Tensor + ) -> Tensor { + let indices = tensor2; + + let n_dims = indices.dims().len(); + let index_flat = match n_dims { + 1 => indices.reshape([1, -1]), + n if n >= 2 => indices.flatten::<2>(0, n - 2), + _ => panic!("Number of dimensions must be greater than 0"), + }; + + let out = index_flat + .iter_dim(0) + .map(|idxs| { + let idxs = idxs.squeeze::<1>(0); + Tensor::select(tensor1.clone(), 0, idxs) + }) + .collect(); + let tensor3 = Tensor::stack::<3usize>(out, 0); tensor3 } } @@ -138,7 +241,7 @@ mod tests { graph.register(GatherNode::new( Type::Shape(ShapeType::new("shape1", 3)), Type::Tensor(TensorType::new_int("tensor1", 1)), - TensorType::new_float("tensor2", 2), + TensorType::new_int("tensor2", 1), 0, )); @@ -174,8 +277,14 @@ mod tests { &self, shape1: [usize; 3], tensor1: Tensor - ) -> Tensor { - let tensor2 = Tensor::from_data(&shape1 as &[_], &*self.device).select(0, tensor1); + ) -> Tensor { + let indices = tensor1; + + let tensor2 = Tensor::select( + Tensor::::from_data(&shape1 as &[_], &*self.device), + 0, + indices, + ); tensor2 } @@ -192,7 +301,7 @@ mod tests { graph.register(GatherNode::new( Type::Tensor(TensorType::new_float("tensor1", 2)), Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)), - TensorType::new_float("tensor2", 2), + TensorType::new_float("tensor2", 1), 0, )); @@ -227,8 +336,11 @@ mod tests { &self, tensor1: Tensor, scalar1: i64 - ) -> Tensor { - let tensor2 = tensor1.select(0, Tensor::from_data([scalar1], &*self.device)).squeeze(0); + ) -> Tensor { + let indices = Tensor::::from_data([scalar1], &*self.device); + + let slice = Tensor::select(tensor1, 0, indices); + let tensor2 = slice.squeeze::<1usize>(0); tensor2 } diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index 871a51e4d6..b62720c102 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -892,6 +892,7 @@ impl TensorCheck { check } + pub(crate) fn check_prelu_shape( shape_tensor: &Shape, shape_weight: &Shape<1>, diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 646887128f..97b251e0a3 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -816,10 +816,6 @@ fn gather_update_outputs(node: &mut Node) { _ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty), }; - if indices_dim > 1 { - panic!("Gather: indices tensor rank above 1 not supported") - } - match &node.inputs[0].ty { ArgType::Tensor(input_tensor) => { // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input