Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Constant nodes in pytorch parser #1123

Merged
merged 9 commits into from
Jan 10, 2025
19 changes: 19 additions & 0 deletions hls4ml/converters/pytorch/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
import numpy as np

from hls4ml.converters.pytorch_to_hls import pytorch_handler


@pytorch_handler('Constant')
def parse_constant_layer(operation, layer_name, node):
assert 'Constant' in operation

layer = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also add empty list as inputs here (and do the same in onnx parser)? The Constant node in the IR will override anyway, but it feels like that feature should be a check and an error, otherwise it will silently changing things making it harder to debug.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the suggested change here would just be layer['inputs'] = [], and then we'd add a check to the Constant node to throw an error if the input list is not empty?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and in qonnx parser. Then the self.inputs = [] in Constant layer could be a check and a warning before setting. This could be better for futureproofing for constant nodes coming from various sources (like optimizers). @jmitrevs what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added it here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1119 tried to address it. The issue is that if the inputs has length 0, graph._make_graph adds the previous layer in the layer list as an input, so I think the inputs would get overriden. In #1119 we override the incorrect input for Const nodes in the init. Overall it's a bit clunky.

layer['inputs'] = []

layer['class_name'] = 'Constant'
layer['name'] = layer_name

constant = np.array(node._args)
layer['value'] = constant
output_shape = constant.shape

return layer, output_shape


@pytorch_handler('Linear')
def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'Linear' in operation
Expand Down
43 changes: 37 additions & 6 deletions hls4ml/converters/pytorch_to_hls.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch

from hls4ml.model import ModelGraph
Expand Down Expand Up @@ -159,6 +160,23 @@ def parse_pytorch_model(config, verbose=True):

n_inputs = 0

# check for constant nodes
merge_layers = ['add', 'mul', 'sub', 'fmin', 'fmax']
i = 0 # count number of consts and use it in the name
for node in traced_model.graph.nodes:
if node.name.split("_")[0] in merge_layers:
jmitrevs marked this conversation as resolved.
Show resolved Hide resolved
for arg in node.args:
if np.isscalar(arg):
# add an input node with the constant value
new_node = traced_model.graph.placeholder(
name='const_' + str(i), type_expr=torch.Tensor, default_value=arg
)
node.prepend(new_node)
node.update_arg(1, new_node)
i += 1

traced_model.graph.lint()

for node in traced_model.graph.nodes:
if node.op == 'call_module':
# modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x',
Expand Down Expand Up @@ -249,13 +267,26 @@ def parse_pytorch_model(config, verbose=True):

input_layer = {}
input_layer['name'] = node.name
input_layer['class_name'] = 'InputLayer'
input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
layer_list.insert(n_inputs, input_layer)

output_shapes[input_layer['name']] = list(input_shapes[n_inputs])
input_layers.append(input_layer['name'])
n_inputs += 1
if 'const' in node.name:
pytorch_class = 'Constant'
layer, output_shape = layer_handlers[pytorch_class](pytorch_class, node.name, node)

layer_list.append(layer)

assert output_shape is not None
output_shapes[layer['name']] = output_shape

else:

input_layer['class_name'] = 'InputLayer'
input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
layer_list.insert(n_inputs, input_layer)

output_shapes[input_layer['name']] = list(input_shapes[n_inputs])

input_layers.append(input_layer['name'])
n_inputs += 1

layer_counter += 1

Expand Down
Loading