diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0601a84b2..d45ffbdd2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -30,7 +30,7 @@ repos:
         args: ["--profile", "black", --line-length=125]
 
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.0
+    rev: v3.19.1
     hooks:
       - id: pyupgrade
         args: ["--py36-plus"]
diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py
index 2c05b7501..57c42f401 100644
--- a/hls4ml/converters/pytorch/core.py
+++ b/hls4ml/converters/pytorch/core.py
@@ -1,6 +1,25 @@
+import numpy as np
+
 from hls4ml.converters.pytorch_to_hls import pytorch_handler
 
 
+@pytorch_handler('Constant')
+def parse_constant_layer(operation, layer_name, node):
+    assert 'Constant' in operation
+
+    layer = {}
+    layer['inputs'] = []
+
+    layer['class_name'] = 'Constant'
+    layer['name'] = layer_name
+
+    constant = np.array(node._args)
+    layer['value'] = constant
+    output_shape = constant.shape
+
+    return layer, output_shape
+
+
 @pytorch_handler('Linear')
 def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
     assert 'Linear' in operation
diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py
index 79ca1fa5c..871026bc4 100644
--- a/hls4ml/converters/pytorch_to_hls.py
+++ b/hls4ml/converters/pytorch_to_hls.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 
 from hls4ml.model import ModelGraph
@@ -159,6 +160,23 @@ def parse_pytorch_model(config, verbose=True):
 
     n_inputs = 0
 
+    # check for constant nodes
+    merge_layers = ['add', 'mul', 'sub', 'fmin', 'fmax']
+    i = 0  # count number of consts and use it in the name
+    for node in traced_model.graph.nodes:
+        if node.name.split('_')[0] in merge_layers:
+            for arg in node.args:
+                if np.isscalar(arg):
+                    # add an input node with the constant value
+                    new_node = traced_model.graph.placeholder(
+                        name='const_' + str(i), type_expr=torch.Tensor, default_value=arg
+                    )
+                    node.prepend(new_node)
+                    node.update_arg(1, new_node)
+                    i += 1
+
+    traced_model.graph.lint()
+
     for node in traced_model.graph.nodes:
         if node.op == 'call_module':
             # modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x',
@@ -249,13 +267,26 @@ def parse_pytorch_model(config, verbose=True):
 
             input_layer = {}
             input_layer['name'] = node.name
-            input_layer['class_name'] = 'InputLayer'
-            input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
-            layer_list.insert(n_inputs, input_layer)
 
-            output_shapes[input_layer['name']] = list(input_shapes[n_inputs])
-            input_layers.append(input_layer['name'])
-            n_inputs += 1
+            if 'const' in node.name:
+                pytorch_class = 'Constant'
+                layer, output_shape = layer_handlers[pytorch_class](pytorch_class, node.name, node)
+
+                layer_list.append(layer)
+
+                assert output_shape is not None
+                output_shapes[layer['name']] = output_shape
+
+            else:
+
+                input_layer['class_name'] = 'InputLayer'
+                input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
+                layer_list.insert(n_inputs, input_layer)
+
+                output_shapes[input_layer['name']] = list(input_shapes[n_inputs])
+
+                input_layers.append(input_layer['name'])
+                n_inputs += 1
 
         layer_counter += 1
 
diff --git a/hls4ml/model/optimizer/passes/convert_to_channels_last.py b/hls4ml/model/optimizer/passes/convert_to_channels_last.py
index 0b5f12c00..606f42e54 100644
--- a/hls4ml/model/optimizer/passes/convert_to_channels_last.py
+++ b/hls4ml/model/optimizer/passes/convert_to_channels_last.py
@@ -97,12 +97,17 @@ def transform(self, model, node):
             if (
                 isinstance(node, Reshape)
                 and len(node.attributes['target_shape']) == 1
-                and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal"
+                and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "off"
             ):
                 previous_node = node.get_input_node(node.inputs[0])
                 input = previous_node.name
                 outshape = previous_node.get_output_variable().shape
 
+                if (model.config.config['IOType'] == 'io_stream') and len(outshape) == 3:
+                    raise Exception(
+                        'No 3D transpose available in io_stream, this model cannot be converted to channels-last'
+                    )
+
                 if len(outshape) == 2:
                     attributes = {'perm': [1, 0]}
                 else:
diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py
index e45008409..1db8e3c73 100644
--- a/hls4ml/utils/config.py
+++ b/hls4ml/utils/config.py
@@ -283,7 +283,7 @@ def config_from_pytorch_model(
     default_precision='ap_fixed<16,6>',
     default_reuse_factor=1,
     channels_last_conversion='full',
-    transpose_outputs=True,
+    transpose_outputs=False,
     max_precision=None,
 ):
     """Create an HLS conversion config given the PyTorch model.
diff --git a/hls4ml/writer/vivado_accelerator_writer.py b/hls4ml/writer/vivado_accelerator_writer.py
index cefa158e1..817847887 100644
--- a/hls4ml/writer/vivado_accelerator_writer.py
+++ b/hls4ml/writer/vivado_accelerator_writer.py
@@ -394,6 +394,8 @@ def write_board_script(self, model):
         f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '12.5%')))
         f.write('variable version\n')
         f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0')))
+        f.write('variable maximum_size\n')
+        f.write('set maximum_size {}\n'.format(model.config.get_config_value('MaximumSize', '4096')))
         if self.vivado_accelerator_config.get_interface() == 'axi_stream':
             in_bit, out_bit = self.vivado_accelerator_config.get_io_bitwidth()
             f.write(f'set bit_width_hls_output {in_bit}\n')
diff --git a/test/pytest/test_pytorch_api.py b/test/pytest/test_pytorch_api.py
index 3056bd13f..3de0b3f19 100644
--- a/test/pytest/test_pytorch_api.py
+++ b/test/pytest/test_pytorch_api.py
@@ -498,7 +498,7 @@ def test_pooling(pooling, padds, backend):
     model.eval()
     pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()
 
-    config = config_from_pytorch_model(model, input_shape_forHLS)
+    config = config_from_pytorch_model(model, input_shape_forHLS, transpose_outputs=True)
     output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_pooling_{pooling.__name__}_padds_{padds}_backend_{backend}')
     hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend)
     hls_model.compile()
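
A minimal usage sketch of the new scalar-constant path, assuming the public converter API already used in the tests above; the `AddConst` module, input shape, and output directory are illustrative, not part of this change. Tracing `self.linear(x) + 3.0` yields an `add` node with a scalar argument, which the new pre-pass rewrites into a `const_0` placeholder that `parse_constant_layer` then turns into a `Constant` layer.

# Hypothetical sketch, not part of the diff: a module whose forward
# adds a Python scalar, exercising the new 'Constant' parsing path.
import numpy as np
import torch.nn as nn

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model


class AddConst(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        # scalar operand -> rewritten into a 'const_0' placeholder during parsing
        return self.linear(x) + 3.0


model = AddConst()
model.eval()

# transpose_outputs now defaults to False, so no flag is needed here
config = config_from_pytorch_model(model, (8,))
hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir='hls4mlprj_const_add')
hls_model.compile()

x = np.random.rand(1, 8).astype(np.float32)
print(hls_model.predict(x))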