Skip to content

Commit

Permalink
Update optimizer match and default values for linebuffer impl.
Browse files Browse the repository at this point in the history
  • Loading branch information
jicampos committed Jan 17, 2025
1 parent a3ec210 commit faa71f6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
57 changes: 26 additions & 31 deletions hls4ml/backends/catapult/passes/conv_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ class GenerateConvStreamingInstructions(OptimizerPass):
'''Generates the instructions for streaming implementation of CNNs'''

def match(self, node):
return isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D))
is_match = (
isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D))
and node.model.config.get_config_value('IOType').casefold() == 'io_stream'
and node.get_attr('implementation').casefold() == 'encoded'
)
return is_match

def transform(self, model, node):
node_class = node.__class__.__name__
Expand All @@ -18,35 +23,25 @@ def transform(self, model, node):
raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})')

def _generate_1d_instructions(self, node):
if node.model.config.get_config_value('IOType') == 'io_stream':
min_w, instructions = node.model.config.backend.compute_conv1d_instructions(
node.get_input_variable().shape[0],
node.get_input_variable().shape[1],
node.get_attr('filt_width'),
node.get_attr('stride_width'),
)
instructions_str = ','.join(str(i) for i in instructions)
node.set_attr('min_width', min_w)
node.set_attr('instructions', instructions_str)
else:
# these are unused; just put dummy values
node.set_attr('min_width', node.get_attr('in_width'))
node.set_attr('instructions', '0')
min_w, instructions = node.model.config.backend.compute_conv1d_instructions(
node.get_input_variable().shape[0],
node.get_input_variable().shape[1],
node.get_attr('filt_width'),
node.get_attr('stride_width'),
)
instructions_str = ','.join(str(i) for i in instructions)
node.set_attr('min_width', min_w)
node.set_attr('instructions', instructions_str)

def _generate_2d_instructions(self, node):
if node.model.config.get_config_value('IOType') == 'io_stream':
min_h, min_w, instructions = node.model.config.backend.compute_conv2d_instructions(
node.get_input_variable().shape[0],
node.get_input_variable().shape[1],
node.get_input_variable().shape[2],
node.get_attr('filt_height'),
node.get_attr('stride_height'),
)
instructions_str = ','.join(str(i) for i in instructions)
node.set_attr('min_height', min_h)
node.set_attr('min_width', min_w)
node.set_attr('instructions', instructions_str)
else:
node.set_attr('min_height', node.get_attr('in_height'))
node.set_attr('min_width', node.get_attr('in_width'))
node.set_attr('instructions', '0')
min_h, min_w, instructions = node.model.config.backend.compute_conv2d_instructions(
node.get_input_variable().shape[0],
node.get_input_variable().shape[1],
node.get_input_variable().shape[2],
node.get_attr('filt_height'),
node.get_attr('stride_height'),
)
instructions_str = ','.join(str(i) for i in instructions)
node.set_attr('min_height', min_h)
node.set_attr('min_width', min_w)
node.set_attr('instructions', instructions_str)
17 changes: 17 additions & 0 deletions hls4ml/backends/catapult/passes/convolution_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def format(self, node):
else:
params['fill_fn'] = 'FillConv1DBuffer'

if (
node.get_attr('implementation').casefold() == 'linebuffer'
or node.model.config.get_config_value('IOType').casefold() == 'io_parallel'
):
# these are unused; just put dummy values
params['min_width'] = node.get_attr('in_width')
params['instructions'] = '0'

conv_config = self.template.format(**params)

mult_params = self._default_config_params(node)
Expand Down Expand Up @@ -210,6 +218,15 @@ def format(self, node):
else:
params['fill_fn'] = 'FillConv2DBuffer'

if (
node.get_attr('implementation').casefold() == 'linebuffer'
or node.model.config.get_config_value('IOType').casefold() == 'io_parallel'
):
# these are unused; just put dummy values
params['min_height'] = node.get_attr('in_height')
params['min_width'] = node.get_attr('in_width')
params['instructions'] = '0'

conv_config = self.template.format(**params)

mult_params = self._default_config_params(node)
Expand Down

0 comments on commit faa71f6

Please sign in to comment.