Skip to content

Commit

Permalink
Merge pull request #1161 from JanFSchulte/transposefix
Browse files Browse the repository at this point in the history
Bug fixes for channel-last conversions in pytorch
  • Loading branch information
jmitrevs authored Jan 7, 2025
2 parents 4b7e12d + 31219e3 commit 3fa2902
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
7 changes: 6 additions & 1 deletion hls4ml/model/optimizer/passes/convert_to_channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,17 @@ def transform(self, model, node):
if (
isinstance(node, Reshape)
and len(node.attributes['target_shape']) == 1
and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal"
and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "off"
):
previous_node = node.get_input_node(node.inputs[0])
input = previous_node.name
outshape = previous_node.get_output_variable().shape

if (model.config.config['IOType'] == 'io_stream') and len(outshape) == 3:
raise Exception(
'No 3D transpose available in io_stream, this model cannot be converted to channels-last'
)

if len(outshape) == 2:
attributes = {'perm': [1, 0]}
else:
Expand Down
2 changes: 1 addition & 1 deletion hls4ml/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def config_from_pytorch_model(
default_precision='ap_fixed<16,6>',
default_reuse_factor=1,
channels_last_conversion='full',
transpose_outputs=True,
transpose_outputs=False,
max_precision=None,
):
"""Create an HLS conversion config given the PyTorch model.
Expand Down
2 changes: 1 addition & 1 deletion test/pytest/test_pytorch_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_pooling(pooling, padds, backend):
model.eval()
pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()

config = config_from_pytorch_model(model, input_shape_forHLS)
config = config_from_pytorch_model(model, input_shape_forHLS, transpose_outputs=True)
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_pooling_{pooling.__name__}_padds_{padds}_backend_{backend}')
hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend)
hls_model.compile()
Expand Down

0 comments on commit 3fa2902

Please sign in to comment.