Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into multi-head-attention
Browse files Browse the repository at this point in the history
  • Loading branch information
rianbrooksflynn committed Jan 13, 2025
2 parents a0b9390 + 5c85e9d commit a82a6aa
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
args: ["--profile", "black", --line-length=125]

- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
rev: v3.19.1
hooks:
- id: pyupgrade
args: ["--py36-plus"]
Expand Down
19 changes: 19 additions & 0 deletions hls4ml/converters/pytorch/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
import numpy as np

from hls4ml.converters.pytorch_to_hls import pytorch_handler


@pytorch_handler('Constant')
def parse_constant_layer(operation, layer_name, node):
assert 'Constant' in operation

layer = {}
layer['inputs'] = []

layer['class_name'] = 'Constant'
layer['name'] = layer_name

constant = np.array(node._args)
layer['value'] = constant
output_shape = constant.shape

return layer, output_shape


@pytorch_handler('Linear')
def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'Linear' in operation
Expand Down
43 changes: 37 additions & 6 deletions hls4ml/converters/pytorch_to_hls.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch

from hls4ml.model import ModelGraph
Expand Down Expand Up @@ -159,6 +160,23 @@ def parse_pytorch_model(config, verbose=True):

n_inputs = 0

# check for constant nodes
merge_layers = ['add', 'mul', 'sub', 'fmin', 'fmax']
i = 0 # count number of consts and use it in the name
for node in traced_model.graph.nodes:
if node.name.split('_')[0] in merge_layers:
for arg in node.args:
if np.isscalar(arg):
# add an input node with the constant value
new_node = traced_model.graph.placeholder(
name='const_' + str(i), type_expr=torch.Tensor, default_value=arg
)
node.prepend(new_node)
node.update_arg(1, new_node)
i += 1

traced_model.graph.lint()

for node in traced_model.graph.nodes:
if node.op == 'call_module':
# modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x',
Expand Down Expand Up @@ -249,13 +267,26 @@ def parse_pytorch_model(config, verbose=True):

input_layer = {}
input_layer['name'] = node.name
input_layer['class_name'] = 'InputLayer'
input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
layer_list.insert(n_inputs, input_layer)

output_shapes[input_layer['name']] = list(input_shapes[n_inputs])
input_layers.append(input_layer['name'])
n_inputs += 1
if 'const' in node.name:
pytorch_class = 'Constant'
layer, output_shape = layer_handlers[pytorch_class](pytorch_class, node.name, node)

layer_list.append(layer)

assert output_shape is not None
output_shapes[layer['name']] = output_shape

else:

input_layer['class_name'] = 'InputLayer'
input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
layer_list.insert(n_inputs, input_layer)

output_shapes[input_layer['name']] = list(input_shapes[n_inputs])

input_layers.append(input_layer['name'])
n_inputs += 1

layer_counter += 1

Expand Down
7 changes: 6 additions & 1 deletion hls4ml/model/optimizer/passes/convert_to_channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,17 @@ def transform(self, model, node):
if (
isinstance(node, Reshape)
and len(node.attributes['target_shape']) == 1
and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal"
and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "off"
):
previous_node = node.get_input_node(node.inputs[0])
input = previous_node.name
outshape = previous_node.get_output_variable().shape

if (model.config.config['IOType'] == 'io_stream') and len(outshape) == 3:
raise Exception(
'No 3D transpose available in io_stream, this model cannot be converted to channels-last'
)

if len(outshape) == 2:
attributes = {'perm': [1, 0]}
else:
Expand Down
2 changes: 1 addition & 1 deletion hls4ml/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def config_from_pytorch_model(
default_precision='ap_fixed<16,6>',
default_reuse_factor=1,
channels_last_conversion='full',
transpose_outputs=True,
transpose_outputs=False,
max_precision=None,
):
"""Create an HLS conversion config given the PyTorch model.
Expand Down
2 changes: 2 additions & 0 deletions hls4ml/writer/vivado_accelerator_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ def write_board_script(self, model):
f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '12.5%')))
f.write('variable version\n')
f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0')))
f.write('variable maximum_size\n')
f.write('set maximum_size {}\n'.format(model.config.get_config_value('MaximumSize', '4096')))
if self.vivado_accelerator_config.get_interface() == 'axi_stream':
in_bit, out_bit = self.vivado_accelerator_config.get_io_bitwidth()
f.write(f'set bit_width_hls_output {in_bit}\n')
Expand Down
2 changes: 1 addition & 1 deletion test/pytest/test_pytorch_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_pooling(pooling, padds, backend):
model.eval()
pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()

config = config_from_pytorch_model(model, input_shape_forHLS)
config = config_from_pytorch_model(model, input_shape_forHLS, transpose_outputs=True)
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_pooling_{pooling.__name__}_padds_{padds}_backend_{backend}')
hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend)
hls_model.compile()
Expand Down

0 comments on commit a82a6aa

Please sign in to comment.