Commit
Feature/codegen gather indices greater than rank 1 (#2199)
* implemented multi-dim index for GatherNode

The `NodeCodegen` impl for `GatherNode` now performs gather in complete
accordance with the ONNX Gather spec.
- a `gather` function was added to the gather.rs file
- `gather()` is now called within the codegen instead of `tensor.select()`
- a test with two test cases has been added
    - test axes 0 and 1
    - both use 2D index tensors
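For context, ONNX `Gather` generalizes `select`: with an input of rank r and an index tensor of rank q, the output has rank q + r - 1. A minimal sketch of those semantics, written over plain nested `Vec`s rather than burn tensors (the helper name `gather_2d` and the representation are illustrative assumptions, not the actual codegen output):

    // Sketch of ONNX Gather for a rank-2 input with rank-2 indices, covering
    // the two test cases above (axis 0 and axis 1).
    fn gather_2d(input: &[Vec<f64>], indices: &[Vec<usize>], axis: usize) -> Vec<Vec<Vec<f64>>> {
        match axis {
            // axis 0: each index picks an entire row of `input`
            0 => indices
                .iter()
                .map(|r| r.iter().map(|&i| input[i].clone()).collect())
                .collect(),
            // axis 1: for each input row, each index picks a single element
            1 => input
                .iter()
                .map(|row| {
                    indices
                        .iter()
                        .map(|r| r.iter().map(|&i| row[i]).collect())
                        .collect()
                })
                .collect(),
            _ => panic!("sketch only handles rank-2 input"),
        }
    }

    fn main() {
        let x = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let idx = vec![vec![0, 1], vec![1, 0]];
        // axis 0 output shape: [2, 2, 3]; axis 1 output shape: [2, 2, 2]
        assert_eq!(gather_2d(&x, &idx, 0)[0][1], vec![4.0, 5.0, 6.0]);
        assert_eq!(gather_2d(&x, &idx, 1)[0], vec![vec![1.0, 2.0], vec![2.0, 1.0]]);
    }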

* add gather_onnx to numeric api

Added int and float implementations of gather to the burn-tensor numeric
api:
- named the methods `gather_onnx` to avoid confusion with the existing
  `gather`
- these implementations follow the `Gather` ONNX spec

Updated the gather*.py variants and their onnx outputs
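Since the int and float implementations share the same logic, the pair can be pictured as one generic routine. Below is a hedged stand-in over plain `Vec`s, assuming rank-1 indices on axis 0 — `gather_onnx_axis0` is a hypothetical name, not the actual burn-tensor API:

    // Hypothetical stand-in for the int/float pair: one generic body serves
    // both element types via `T: Clone`.
    fn gather_onnx_axis0<T: Clone>(input: &[Vec<T>], indices: &[usize]) -> Vec<Vec<T>> {
        indices.iter().map(|&i| input[i].clone()).collect()
    }

    fn main() {
        let floats = vec![vec![1.0_f32, 2.0], vec![3.0, 4.0]];
        let ints = vec![vec![1_i64, 2], vec![3, 4]];
        // The same routine covers the float and int variants.
        assert_eq!(gather_onnx_axis0(&floats, &[1, 0]), vec![vec![3.0, 4.0], vec![1.0, 2.0]]);
        assert_eq!(gather_onnx_axis0(&ints, &[0, 0]), vec![vec![1, 2], vec![1, 2]]);
    }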

* modified files didn't end up in the last commit

* tests passing for onnx gather

The implementation of gather for the ONNX `Gather` spec is tentatively
complete:
- py test models are updated
- onnx_tests are modified and passing: `gather`, `gather_scalar`, and
  `gather_shape`
- node/gather tests are passing

NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test
the actual functionality of gather are likely to be deleted, since they
are redundant with the tests in
crates/burn-import/onnx-tests/tests/onnx_tests.rs.

* inlined onnx gather within codegen

* rm gather_onnx from public api; rm unnecessary tests

* add comments to gather py models

* some codegen changes; formatting to appease run-checks

- Some necessary changes and improvements to the inlined codegen after
  translating it from the public API (removed in the previous commit).
- Changed some formatting that run-checks complained about.

* simplify gather codegen; include 1d and 2d onnx tests

Modified the `Gather` codegen per requested changes:
- combined match statements on index
- remove use of `alloc::vec::Vec`
- use map -> collect instead of a procedural loop (sketched after this list)
- include a 1d index gather onnx test
- remove superfluous tests
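To make the map -> collect point concrete, here is the style change in miniature, over plain slices with hypothetical names (the real change applies to the generated tensor code, not these helpers):

    // Procedural style the review moved away from: mutable accumulator + push.
    fn gather_rows_loop(input: &[Vec<f64>], indices: &[usize]) -> Vec<Vec<f64>> {
        let mut out = Vec::with_capacity(indices.len());
        for &i in indices {
            out.push(input[i].clone());
        }
        out
    }

    // map -> collect style: same result, no mutable state.
    fn gather_rows_mapped(input: &[Vec<f64>], indices: &[usize]) -> Vec<Vec<f64>> {
        indices.iter().map(|&i| input[i].clone()).collect()
    }

    fn main() {
        let x = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(gather_rows_loop(&x, &[1, 0]), gather_rows_mapped(&x, &[1, 0]));
    }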

* delete unused gather.onnx
alteredoxide authored Aug 28, 2024
1 parent 795201d commit 0292967
Showing 14 changed files with 321 additions and 122 deletions.
3 changes: 2 additions & 1 deletion crates/burn-import/onnx-tests/build.rs
@@ -32,7 +32,8 @@ fn main() {
.input("tests/exp/exp.onnx")
.input("tests/expand/expand.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/gather/gather.onnx")
.input("tests/gather/gather_1d_idx.onnx")
.input("tests/gather/gather_2d_idx.onnx")
.input("tests/gather/gather_scalar.onnx")
.input("tests/gather/gather_shape.onnx")
.input("tests/gather_elements/gather_elements.onnx")
18 changes: 0 additions & 18 deletions crates/burn-import/onnx-tests/tests/gather/gather.onnx

This file was deleted.

47 changes: 0 additions & 47 deletions crates/burn-import/onnx-tests/tests/gather/gather.py

This file was deleted.

Binary file added (not shown): crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.onnx
62 changes: 62 additions & 0 deletions crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.py
@@ -0,0 +1,62 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/gather/gather_1d_idx.onnx

# There is no current support for `Split`, and the `for` loop over the indices
# results in a `Split` node in the ONNX model.
# Therefore, this model is built and exported using ONNX directly.

import onnx


def build_model():
    return onnx.helper.make_model(
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 16)],
        graph=onnx.helper.make_graph(name="main_graph", nodes=[
            onnx.helper.make_node(
                "Gather",
                inputs=["input1", "input2"],
                outputs=["output1"],
                name="/Gather",
                axis=1
            ),
        ],
        inputs=[
            onnx.helper.make_value_info(
                name="input1",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
                ),
            ),
            onnx.helper.make_value_info(
                name="input2",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.INT64, shape=[2]
                ),
            ),

        ],
        outputs=[
            onnx.helper.make_value_info(
                name="output1",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.FLOAT, shape=[2, 2]
                ),
            )
        ]),
    )


def main():
    onnx_model = build_model()
    file_name = "gather_1d_idx.onnx"

    # Ensure valid ONNX:
    onnx.checker.check_model(onnx_model)

    onnx.save(onnx_model, file_name)


if __name__ == '__main__':
    main()
Binary file added (not shown): crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.onnx
62 changes: 62 additions & 0 deletions crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.py
@@ -0,0 +1,62 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/gather/gather_2d_idx.onnx

# There is no current support for `Split`, and the `for` loop over the indices
# results in a `Split` node in the ONNX model.
# Therefore, this model is built and exported using ONNX directly.

import onnx


def build_model():
    return onnx.helper.make_model(
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 16)],
        graph=onnx.helper.make_graph(name="main_graph", nodes=[
            onnx.helper.make_node(
                "Gather",
                inputs=["input1", "input2"],
                outputs=["output1"],
                name="/Gather",
                axis=0
            ),
        ],
        inputs=[
            onnx.helper.make_value_info(
                name="input1",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
                ),
            ),
            onnx.helper.make_value_info(
                name="input2",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.INT64, shape=[2, 2]
                ),
            ),

        ],
        outputs=[
            onnx.helper.make_value_info(
                name="output1",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.FLOAT, shape=[2, 2, 2]
                ),
            )
        ]),
    )


def main():
    onnx_model = build_model()
    file_name = "gather_2d_idx.onnx"

    # Ensure valid ONNX:
    onnx.checker.check_model(onnx_model)

    onnx.save(onnx_model, file_name)


if __name__ == '__main__':
    main()
Binary file modified (not shown): crates/burn-import/onnx-tests/tests/gather/gather_scalar.onnx
83 changes: 49 additions & 34 deletions crates/burn-import/onnx-tests/tests/gather/gather_scalar.py
@@ -2,45 +2,60 @@

 # used to generate model: onnx-tests/tests/gather/gather_scalar.onnx

-import torch
-import torch.nn as nn
-
-
-class Model(nn.Module):
-    def __init__(self):
-        super(Model, self).__init__()
-
-    def forward(self, x, index):
-        gathered = torch.select(x, 0, index)
-        return gathered
+# There is no current support for `Split`, and the `for` loop over the indices
+# results in a `Split` node in the ONNX model.
+# Therefore, this model is built and exported using ONNX directly.
+
+import onnx
+
+
+def build_model():
+    return onnx.helper.make_model(
+        ir_version=8,
+        opset_imports=[onnx.helper.make_operatorsetid("", 16)],
+        graph=onnx.helper.make_graph(name="main_graph", nodes=[
+            onnx.helper.make_node(
+                "Gather",
+                inputs=["input1", "input2"],
+                outputs=["output1"],
+                name="/Gather",
+                axis=0
+            ),
+        ],
+        inputs=[
+            onnx.helper.make_value_info(
+                name="input1",
+                type_proto=onnx.helper.make_tensor_type_proto(
+                    elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
+                ),
+            ),
+            onnx.helper.make_value_info(
+                name="input2",
+                type_proto=onnx.helper.make_tensor_type_proto(
+                    elem_type=onnx.TensorProto.INT64, shape=[]
+                ),
+            ),
+
+        ],
+        outputs=[
+            onnx.helper.make_value_info(
+                name="output1",
+                type_proto=onnx.helper.make_tensor_type_proto(
+                    elem_type=onnx.TensorProto.FLOAT, shape=[3]
+                ),
+            )
+        ]),
+    )


 def main():
-    # Set random seed for reproducibility
-    torch.manual_seed(0)
-
-    # Export to onnx
-    model = Model()
-    model.eval()
-    device = torch.device("cpu")
-    onnx_name = "gather_scalar.onnx"
-
-    dummy_input = torch.randn(2, 3, device=device)
-    dummy_index = 0
-
-    torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
-                      verbose=False, opset_version=16)
-
-    print("Finished exporting model to {}".format(onnx_name))
+    onnx_model = build_model()
+    file_name = "gather_scalar.onnx"

-    # Output some test data for use in the test
-    test_input = torch.tensor([[1.0, 2.0, 3.0],
-                               [4.0, 5.0, 6.0]])
-    test_index = 0
+    # Ensure valid ONNX:
+    onnx.checker.check_model(onnx_model)

-    print("Test input data: {}, {}".format(test_input, test_index))
-    output = model.forward(test_input, test_index)
-    print("Test output data: {}".format(output))
+    onnx.save(onnx_model, file_name)


 if __name__ == '__main__':
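Aside: the scalar-index model above is the rank-reducing case — a rank-0 index drops the gathered axis entirely, so the [2, 3] input yields a [3] output (rank 0 + 2 - 1 = 1). A minimal sketch under the same plain-`Vec` assumption as the earlier snippets, with a hypothetical helper name:

    // Scalar-index Gather in miniature: a rank-0 index removes the gathered
    // axis, returning a single row. Not burn code; illustrative only.
    fn gather_scalar_axis0(input: &[Vec<f64>], index: usize) -> Vec<f64> {
        input[index].clone()
    }

    fn main() {
        let x = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        assert_eq!(gather_scalar_axis0(&x, 0), vec![1.0, 2.0, 3.0]);
    }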
4 changes: 2 additions & 2 deletions crates/burn-import/onnx-tests/tests/gather/gather_shape.py
@@ -33,7 +33,7 @@ def build_model():
             onnx.helper.make_value_info(
                 name="input1",
                 type_proto=onnx.helper.make_tensor_type_proto(
-                    elem_type=onnx.TensorProto.FLOAT, shape=[2,3]
+                    elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
                 ),
             ),
             onnx.helper.make_value_info(
@@ -66,4 +66,4 @@ def main():


 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
23 changes: 19 additions & 4 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -41,7 +41,8 @@ include_models!(
     exp,
     expand,
     flatten,
-    gather,
+    gather_1d_idx,
+    gather_2d_idx,
     gather_scalar,
     gather_shape,
     gather_elements,
@@ -451,15 +452,29 @@
     }

     #[test]
-    fn gather() {
-        let model: gather::Model<Backend> = gather::Model::default();
+    fn gather_1d_idx() {
+        let model: gather_1d_idx::Model<Backend> = gather_1d_idx::Model::default();

         let device = Default::default();

         let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
         let index = Tensor::<Backend, 1, Int>::from_ints([0, 2], &device);
-        let output = model.forward(input, index);
         let expected = TensorData::from([[1f32, 3.], [4., 6.]]);
+        let output = model.forward(input, index);

         assert_eq!(output.to_data(), expected);
     }
+
+    #[test]
+    fn gather_2d_idx() {
+        let model: gather_2d_idx::Model<Backend> = gather_2d_idx::Model::default();
+
+        let device = Default::default();
+
+        let input = Tensor::<Backend, 2>::from_data([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], &device);
+        let index = Tensor::<Backend, 2, Int>::from_data([[0, 1], [1, 2]], &device);
+        let expected = TensorData::from([[[1f32, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]]);
+        let output = model.forward(input, index);
+
+        assert_eq!(output.to_data(), expected);
+    }