diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0601a84b2d..d607959dab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,6 +9,11 @@ repos: args: ['--line-length=125', '--skip-string-normalization'] +- repo: https://github.com/tox-dev/pyproject-fmt + rev: v2.5.0 + hooks: + - id: pyproject-fmt + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: @@ -16,6 +21,7 @@ repos: - id: check-case-conflict - id: check-merge-conflict - id: check-symlinks + - id: check-toml - id: check-yaml - id: debug-statements - id: end-of-file-fixer @@ -27,7 +33,6 @@ repos: rev: 5.13.2 hooks: - id: isort - args: ["--profile", "black", --line-length=125] - repo: https://github.com/asottile/pyupgrade rev: v3.19.0 @@ -35,11 +40,6 @@ repos: - id: pyupgrade args: ["--py36-plus"] -- repo: https://github.com/asottile/setup-cfg-fmt - rev: v2.7.0 - hooks: - - id: setup-cfg-fmt - - repo: https://github.com/pycqa/flake8 rev: 7.1.1 hooks: @@ -47,7 +47,11 @@ repos: exclude: docs/conf.py additional_dependencies: [flake8-bugbear, flake8-print] args: ['--max-line-length=125', # github viewer width - '--extend-ignore=E203,T201'] # E203 is not PEP8 compliant + '--extend-ignore=E203,T201', # E203 is not PEP8 compliant + '--per-file-ignores=hls4ml/model/optimizer/passes/bit_exact.py:E741,hls4ml/converters/keras_v3/squark/_base.py:E741,__init__.py:F401', + # i for #int w/o sign, I for #int w/ sign when massively processing bw conversions ...... + # ignore unused imports in __init__.py ..... + ] - repo: https://github.com/mgedmin/check-manifest rev: "0.50" diff --git a/MANIFEST.in b/MANIFEST.in index 549cc6983c..5bec5fe2a6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,8 @@ -include LICENSE README.md CONTRIBUTING.md CITATION.cff pyproject.toml setup.py setup.cfg .clang-format +include LICENSE README.md CONTRIBUTING.md CITATION.cff pyproject.toml .clang-format graft example-models graft test graft contrib recursive-include hls4ml/templates * -global-exclude .git .gitmodules .gitlab-ci.yml +recursive-include hls4ml *.py +global-exclude .git .gitmodules .gitlab-ci.yml *.pyc include hls4ml/backends/vivado_accelerator/supported_boards.json diff --git a/hls4ml/__init__.py b/hls4ml/__init__.py index e3a7247b0d..0ff5e52ac9 100644 --- a/hls4ml/__init__.py +++ b/hls4ml/__init__.py @@ -1,33 +1,3 @@ -# Temporary workaround for QKeras installation requirement, will be removed after 1.0.0 -def maybe_install_qkeras(): - import subprocess - import sys - - QKERAS_PKG_NAME = 'QKeras' - # QKERAS_PKG_SOURCE = QKERAS_PKG_NAME - QKERAS_PKG_SOURCE = 'qkeras@git+https://github.com/fastmachinelearning/qkeras.git' - - def pip_list(): - p = subprocess.run([sys.executable, '-m', 'pip', 'list'], check=True, capture_output=True) - return p.stdout.decode() - - def pip_install(package): - subprocess.check_call([sys.executable, '-m', 'pip', 'install', package]) - - all_pkgs = pip_list() - if QKERAS_PKG_NAME not in all_pkgs: - print('QKeras installation not found, installing one...') - pip_install(QKERAS_PKG_SOURCE) - print('QKeras installed.') - - -try: - maybe_install_qkeras() -except Exception: - print('Could not find QKeras installation, make sure you have QKeras installed.') - -# End of workaround - from hls4ml import converters, report, utils # noqa: F401, E402 try: diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index fbfed71c5b..54d7fd6cd8 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -7,7 +7,7 @@ 
import numpy as np from hls4ml.backends.backend import Backend -from hls4ml.model.attributes import ChoiceAttribute, ConfigurableAttribute, TypeAttribute +from hls4ml.model.attributes import ConfigurableAttribute, TypeAttribute from hls4ml.model.layers import ( GRU, LSTM, @@ -32,7 +32,6 @@ SeparableConv1D, SeparableConv2D, SimpleRNN, - Softmax, ) from hls4ml.model.optimizer import model_optimizer from hls4ml.model.types import ( @@ -40,8 +39,6 @@ FixedPrecisionType, IntegerPrecisionType, PrecisionType, - RoundingMode, - SaturationMode, UnspecifiedPrecisionType, XnorPrecisionType, ) @@ -109,34 +106,6 @@ def __init__(self, name): act_attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8), description=descriptions.table_type)) self.attribute_map[Activation] = act_attrs - softmax_attrs = self.attribute_map.get(Softmax, []) - softmax_attrs.append( - ChoiceAttribute( - 'implementation', - ['latency', 'stable', 'argmax', 'legacy'], - default='stable', - description=descriptions.softmax_implementation, - ) - ) - softmax_attrs.append( - ConfigurableAttribute('skip', value_type=bool, default=False, description=descriptions.softmax_skip) - ) - softmax_attrs.append( - TypeAttribute( - 'exp_table', - default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), - description=descriptions.table_type, - ) - ) - softmax_attrs.append( - TypeAttribute( - 'inv_table', - default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), - description=descriptions.table_type, - ) - ) - self.attribute_map[Softmax] = softmax_attrs - def create_layer_class(self, layer_class): new_attrubutes = [] for cls, attributes in self.attribute_map.items(): diff --git a/hls4ml/backends/fpga/passes/fix_softmax_table_size.py b/hls4ml/backends/fpga/passes/fix_softmax_table_size.py index 4e04626d2e..860aa89597 100644 --- a/hls4ml/backends/fpga/passes/fix_softmax_table_size.py +++ b/hls4ml/backends/fpga/passes/fix_softmax_table_size.py @@ -6,7 +6,11 @@ class FixSoftmaxTableSize(OptimizerPass): def match(self, node): - return isinstance(node, Softmax) + if not isinstance(node, Softmax): + return False + if 'inv_table_size' in node.attributes: + return False # handler generating inv_table_size sets it properly + return True def transform(self, model, node: Layer): inp_layer = node.get_input_node() # type: ignore diff --git a/hls4ml/backends/fpga/passes/hgq_proxy_model.py b/hls4ml/backends/fpga/passes/hgq_proxy_model.py index 5ec1200ac7..77773bf131 100644 --- a/hls4ml/backends/fpga/passes/hgq_proxy_model.py +++ b/hls4ml/backends/fpga/passes/hgq_proxy_model.py @@ -52,10 +52,6 @@ def match(self, node: Layer): return isinstance(node, FixedPointQuantizer) def transform(self, model, node: FixedPointQuantizer): - if node.fusible: - model.remove_node(node, rewire=True) - return True - if model.config.config['IOType'] != 'io_parallel': raise NotImplementedError('Heterogenous quantization for activations is only supported with IOType=io_parallel') @@ -94,7 +90,6 @@ def __init__(self): def format(self, node): params = self._default_function_params(node) - node.attributes['result_t'].precision = node.attributes['table_t'].precision params['config'] = f'unary_lut_config{node.index}' params['table'] = node.get_weights('table').name diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 836da6e68a..8249f88bb8 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ 
b/hls4ml/backends/vivado/passes/core_templates.py @@ -150,13 +150,21 @@ def format(self, node): softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{ static const unsigned n_in = {n_in}; - static const unsigned table_size = {table_size}; + static const unsigned n_outer = {n_outer}; + static const unsigned n_inner = {n_inner}; + static const unsigned parallelization_factor = {parallelization_factor}; + static const unsigned exp_table_size = {exp_table_size}; + static const unsigned inv_table_size = {inv_table_size}; static const unsigned io_type = nnet::{iotype}; static const unsigned reuse_factor = {reuse}; static const unsigned axis = {axis}; static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation}; + static constexpr float exp_scale = {exp_scale}; typedef {exp_table_t.name} exp_table_t; typedef {inv_table_t.name} inv_table_t; + typedef {accum_t.name} accum_t; + typedef {inv_inp_t.name} inv_inp_t; + typedef {inp_norm_t_str} inp_norm_t; }};\n""" activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});' @@ -208,10 +216,44 @@ def __init__(self): super(ActivationConfigTemplate, self).__init__(Softmax) # Skip ActivationConfigTemplate's __init__ self.template = softmax_config_template + def format(self, node): + params = self._default_config_params(node) + params['type'] = node.get_attr('activation') + params.setdefault('exp_table_size', params['table_size']) + params.setdefault('inv_table_size', params['table_size']) + params.setdefault('n_inner', 1) + params.setdefault('n_outer', 1) + params.setdefault('exp_scale', 1.0) + params.setdefault('parallelization_factor', -1) + + if 'inp_norm_t' not in params: + input_t = node.get_input_variable().type.precision + width, iwidth = input_t.width, input_t.integer + params['inp_norm_t_str'] = f'ap_fixed<{width}, {iwidth}, AP_RND, AP_SAT>' + else: + params['inp_norm_t_str'] = params['inp_norm_t'].name # type: ignore + + return self.template.format(**params) + + +class SoftmaxFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Softmax, include_header=activ_include_list) + self.template = activ_function_template + + def format(self, node): + params = self._default_function_params(node) + use_multidim = node.get_attr('n_inner', 1) > 1 or node.get_attr('n_outer', 1) > 1 + use_multidim = use_multidim and node.model.config.get_config_value('IOType') == 'io_parallel' + params['activation'] = 'softmax' if not use_multidim else 'softmax_multidim' + params['config'] = f'softmax_config{node.index}' + + return self.template.format(**params) + class ActivationFunctionTemplate(FunctionCallTemplate): def __init__(self): - super().__init__((Activation, HardActivation, Softmax), include_header=activ_include_list) + super().__init__((Activation, HardActivation), include_header=activ_include_list) self.template = activ_function_template def format(self, node): diff --git a/hls4ml/backends/vivado/passes/einsum.py b/hls4ml/backends/vivado/passes/einsum.py new file mode 100644 index 0000000000..0d13a7078a --- /dev/null +++ b/hls4ml/backends/vivado/passes/einsum.py @@ -0,0 +1,105 @@ +from math import ceil + +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import Einsum + +from .reshaping_templates import transpose_config_gen + +# Shared Dense template +# Einsum template + +einsum_config_template = ''' +struct config{index} {{ + 
typedef config{index}_tpose_inp0 tpose_inp0_conf; + typedef config{index}_tpose_inp1 tpose_inp1_conf; + typedef config{index}_tpose_out tpose_out_conf; + + typedef {accum_t.name} accum_t; + + // Layer Sizes + static const unsigned n_free0 = {n_free0}; + static const unsigned n_free1 = {n_free1}; + static const unsigned n_contract = {n_contract}; + static const unsigned n_inplace = {n_inplace}; + + // Resource reuse info + static const unsigned io_type = nnet::{iotype}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned reuse_factor = {reuse_factor}; + static const unsigned multiplier_limit = {multiplier_limit}; + static const bool store_weights_in_bram = false; // NOT USED + + template + using product = nnet::product::{product_type}; +}}; +''' + +einsum_function_template = 'nnet::einsum<{input0_t}, {input1_t}, {output_t}, {config}>({input0}, {input1}, {output});' + +einsum_include_list = ['nnet_utils/nnet_einsum.h'] + + +class EinsumConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Einsum) + self.template = einsum_config_template + + def format(self, node: Einsum): + default_params = self._default_config_params(node) + + strategy = node.model.config.get_strategy(node) + io_type = node.model.config.get_config_value('IOType') + + assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now' + assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now' + + # EinsumDense config + params = default_params.copy() + params['strategy'] = strategy + params['n_free0'] = node.attributes.attributes['n_free0'] + params['n_free1'] = node.attributes.attributes['n_free1'] + params['n_contract'] = node.attributes.attributes['n_contract'] + params['n_inplace'] = node.attributes.attributes['n_inplace'] + inp0_t = node.get_input_variable(node.inputs[0]).type.precision + inp1_t = node.get_input_variable(node.inputs[1]).type.precision + params['product_type'] = get_backend('vivado').product_type(inp0_t, inp1_t) + + total_mults = params['n_free0'] * params['n_free1'] * params['n_contract'] * params['n_inplace'] + params['multiplier_limit'] = ceil(total_mults / params['reuse_factor']) + + einsum_conf = self.template.format(**params) + + # inp/out transpose config + inp0_shape = node.attributes.attributes['inp0_shape'] + inp1_shape = node.attributes.attributes['inp1_shape'] + out_interpert_shape = node.attributes.attributes['out_interpert_shape'] + inp0_tpose_idxs = node.attributes.attributes['inp0_tpose_idxs'] + inp1_tpose_idxs = node.attributes.attributes['inp1_tpose_idxs'] + out_tpose_idxs = node.attributes.attributes['out_tpose_idxs'] + tpose_inp0_conf_name = f'config{node.index}_tpose_inp0' + tpose_inp1_conf_name = f'config{node.index}_tpose_inp1' + tpose_out_conf_name = f'config{node.index}_tpose_out' + + inp0_tpose_conf = transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs) + inp1_tpose_conf = transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs) + out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + + return '\n\n'.join((inp0_tpose_conf, inp1_tpose_conf, out_tpose_conf, einsum_conf)) + + +class EinsumFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Einsum, include_header=einsum_include_list) + self.template = einsum_function_template + + def format(self, node: Einsum): + params = {} + params['config'] = f'config{node.index}' + params['input0_t'] = 
node.get_input_variable(node.inputs[0]).type.name + params['input1_t'] = node.get_input_variable(node.inputs[1]).type.name + params['output_t'] = node.get_output_variable().type.name + params['input0'] = node.get_input_variable(node.inputs[0]).name + params['input1'] = node.get_input_variable(node.inputs[1]).name + params['output'] = node.get_output_variable().name + return self.template.format(**params) diff --git a/hls4ml/backends/vivado/passes/einsum_dense.py b/hls4ml/backends/vivado/passes/einsum_dense.py new file mode 100644 index 0000000000..4edafa7f42 --- /dev/null +++ b/hls4ml/backends/vivado/passes/einsum_dense.py @@ -0,0 +1,120 @@ +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import EinsumDense + +from .reshaping_templates import transpose_config_gen + +# Shared Dense template + +dense_config_template = """struct config{index}_dense : nnet::dense_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned reuse_factor = {reuse}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned n_zeros = {nzeros}; + static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + template + using kernel = nnet::{dense_function}; + template + using product = nnet::product::{product_type}; +}};\n""" + +# EinsumDense template + +einsum_dense_config_template = ''' +struct config{index} {{ + typedef config{index}_tpose_inp tpose_inp_conf; + typedef config{index}_tpose_out tpose_out_conf; + typedef config{index}_dense dense_conf; + + // Layer Sizes + static const unsigned n_free_data = {n_free_data}; + static const unsigned n_free_kernel = {n_free_kernel}; + static const unsigned n_contract = {n_contract}; + static const unsigned n_inplace = {n_inplace}; + + // Resource reuse info + static const unsigned io_type = nnet::{iotype}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned reuse_factor = {reuse_factor}; + static const unsigned parallelization_factor = {parallelization_factor}; // Only useful when n_inplace > 1 + static const bool store_weights_in_bram = false; // NOT USED +}}; +''' + +einsum_dense_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' + +einsum_dense_include_list = ['nnet_utils/nnet_einsum_dense.h', 'nnet_utils/nnet_dense.h'] + + +class EinsumDenseConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(EinsumDense) + self.template = einsum_dense_config_template + self.dense_template = dense_config_template + + def format(self, node: EinsumDense): + default_params = self._default_config_params(node) + + strategy = node.model.config.get_strategy(node) + io_type = node.model.config.get_config_value('IOType') + + assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now' + assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now' + + # EinsumDense config + params = default_params.copy() + params['strategy'] = strategy + params['n_free_data'] = node.attributes.attributes['n_free_data'] + params['n_free_kernel'] = node.attributes.attributes['n_free_kernel'] + params['n_contract'] = node.attributes.attributes['n_contract'] + params['n_inplace'] = node.attributes.attributes['n_inplace'] + 
params['parallelization_factor'] = node.attributes.attributes['parallelization_factor'] + + einsum_conf = self.template.format(**params) + + # inp/out transpose config + inp_shape = node.attributes.attributes['inp_shape'] + out_interpert_shape = node.attributes.attributes['out_interpert_shape'] + inp_tpose_idxs = node.attributes.attributes['inp_tpose_idxs'] + out_tpose_idxs = node.attributes.attributes['out_tpose_idxs'] + tpose_inp_conf_name = f'config{node.index}_tpose_inp' + tpose_out_conf_name = f'config{node.index}_tpose_out' + + inp_tpose_conf = transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs) + out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + + # Dense config + dense_params = default_params.copy() + dense_params['strategy'] = strategy + dense_params['n_in'] = node.attributes.attributes['n_contract'] + dense_params['n_out'] = node.attributes.attributes['n_free_kernel'] + if node.attributes.attributes['n_inplace'] == 1: + dense_params['nzeros'] = node.get_weights('weight').nzeros # type: ignore + else: + dense_params['nzeros'] = '-1; // Not making sense when kernels are switching' + dense_params['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights('weight').type.precision # type: ignore + ) + + dense_params['dense_function'] = 'DenseLatency' # Latency only for now + + dense_config = self.dense_template.format(**dense_params) + + return '\n\n'.join((inp_tpose_conf, out_tpose_conf, dense_config, einsum_conf)) + + +class EinsumDenseFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(EinsumDense, include_header=einsum_dense_include_list) + self.template = einsum_dense_function_template + + def format(self, node): + params = self._default_function_params(node) + params['w'] = node.get_weights('weight').name + params['b'] = node.get_weights('bias').name + + return self.template.format(**params) diff --git a/hls4ml/backends/vivado/passes/reshaping_templates.py b/hls4ml/backends/vivado/passes/reshaping_templates.py index ec6705eb29..e59d81c8c5 100644 --- a/hls4ml/backends/vivado/passes/reshaping_templates.py +++ b/hls4ml/backends/vivado/passes/reshaping_templates.py @@ -1,3 +1,7 @@ +from math import prod + +import numpy as np + from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import Resize, Transpose, ZeroPadding1D, ZeroPadding2D @@ -97,16 +101,45 @@ def format(self, node): # Transpose templates -transpose_config_template = """struct config{index} : nnet::transpose_config {{ - static const unsigned depth = {depth}; - static const unsigned height = {height}; - static const unsigned width = {width}; - static constexpr unsigned perm[3] = {{{perm_str}}}; -}};\n""" -transpose_function_template = 'nnet::transpose_{dim}<{input_t}, {output_t}, {config}>({input}, {output});' +transpose_include_list = ['nnet_utils/nnet_transpose.h', 'nnet_utils/nnet_transpose_stream.h'] + +transpose_config_template = """struct {config_name} {{ + static const unsigned dims = {dims}; + static const unsigned N = {N}; + static const unsigned* const from_shape; + static const unsigned* const to_shape; + static const unsigned* const perm; + static const unsigned* const perm_strides; +}}; + +unsigned {config_name}_from_shape[{dims}] = {{{from_shape}}}; +unsigned {config_name}_to_shape[{dims}] = {{{to_shape}}}; +unsigned {config_name}_perm[{dims}] = {{{perm}}}; +unsigned {config_name}_perm_strides[{dims}] = 
{{{perm_strides}}}; + +const unsigned* const {config_name}::from_shape = {config_name}_from_shape; +const unsigned* const {config_name}::to_shape = {config_name}_to_shape; +const unsigned* const {config_name}::perm = {config_name}_perm; +const unsigned* const {config_name}::perm_strides = {config_name}_perm_strides; +""" + +transpose_function_template = 'nnet::transpose<{input_t}, {output_t}, {config_name}>({input}, {output});' -transpose_include_list = ['nnet_utils/nnet_array.h', 'nnet_utils/nnet_stream.h'] + +def transpose_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]): + new_shape = tuple(shape[i] for i in perm) + strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1] + perm_strides = tuple(int(strides[i]) for i in perm) + return transpose_config_template.format( + dims=len(shape), + N=prod(shape), + from_shape=', '.join(str(x) for x in shape), + perm=', '.join(str(x) for x in perm), + perm_strides=', '.join(str(x) for x in perm_strides), + to_shape=', '.join(str(x) for x in new_shape), + config_name=name, + ) class TransposeConfigTemplate(LayerConfigTemplate): @@ -115,18 +148,18 @@ def __init__(self): self.template = transpose_config_template def format(self, node): - params = self._default_config_params(node) - - return self.template.format(**params) + shape = tuple(node.get_input_variable().shape) + perm = tuple(node.get_attr('perm')) + name = f'config{node.index}' + return transpose_config_gen(name, shape, perm) class TransposeFunctionTemplate(FunctionCallTemplate): def __init__(self): - super().__init__(Transpose, include_header=transpose_include_list) self.template = transpose_function_template + super().__init__(Transpose, include_header=transpose_include_list) def format(self, node): params = self._default_function_params(node) - params['dim'] = node.get_attr('dim') - + params['config_name'] = f'config{node.index}' return self.template.format(**params) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 117805dd86..d2ba498a73 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -26,7 +26,6 @@ SeparableConv1D, SeparableConv2D, SimpleRNN, - Softmax, ) from hls4ml.model.optimizer import get_backend_passes, layer_optimizer from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, PackedType @@ -551,13 +550,6 @@ def init_pooling1d(self, layer): def init_pooling2d(self, layer): layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) - @layer_optimizer(Softmax) - def init_softmax(self, layer): - if layer.model.config.get_config_value('IOType') == 'io_parallel': - assert ( - len(layer.get_input_variable().shape) == 1 - ), 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.' 
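For reference, a minimal sketch of what the `transpose_config_gen` helper introduced in `reshaping_templates.py` above emits; the shape, permutation and config name are hypothetical, and the expected values follow directly from the stride computation in the helper:

```python
# Illustrative use of the new transpose_config_gen helper (hypothetical shape/perm/name).
from hls4ml.backends.vivado.passes.reshaping_templates import transpose_config_gen

# For a (2, 3, 4) tensor permuted with perm = (2, 0, 1): the row-major strides of the
# input are (12, 4, 1), so the generated struct has perm_strides = (1, 12, 4),
# to_shape = (4, 2, 3), dims = 3 and N = 24.
print(transpose_config_gen('config1_tpose_inp', (2, 3, 4), (2, 0, 1)))
```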
- @layer_optimizer(Embedding) def init_embed(self, layer): if layer.attributes['n_in'] is None: diff --git a/scripts/hls4ml b/hls4ml/cli/__init__.py similarity index 100% rename from scripts/hls4ml rename to hls4ml/cli/__init__.py diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 3d7ce1fe56..47569b1ad9 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -1,6 +1,5 @@ import importlib import os -import warnings import yaml @@ -10,33 +9,22 @@ from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401 from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401 from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler +from hls4ml.converters.keras_v3_to_hls import parse_keras_v3_model # noqa: F401 +from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers # noqa: F401 from hls4ml.converters.onnx_to_hls import parse_onnx_model # noqa: F401 +from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler +from hls4ml.converters.pytorch_to_hls import ( # noqa: F401 + get_supported_pytorch_layers, + pytorch_to_hls, + register_pytorch_layer_handler, +) + +# from hls4ml.converters.pytorch_to_hls import parse_pytorch_model # noqa: F401 from hls4ml.model import ModelGraph from hls4ml.utils.config import create_config +from hls4ml.utils.dependency import requires from hls4ml.utils.symbolic_utils import LUTFunction -# ----------Make converters available if the libraries can be imported----------# -try: - from hls4ml.converters.pytorch_to_hls import ( # noqa: F401 - get_supported_pytorch_layers, - pytorch_to_hls, - register_pytorch_layer_handler, - ) - - __pytorch_enabled__ = True -except ImportError: - warnings.warn("WARNING: Pytorch converter is not enabled!", stacklevel=1) - __pytorch_enabled__ = False - -try: - from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers # noqa: F401 - from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler - - __onnx_enabled__ = True -except ImportError: - warnings.warn("WARNING: ONNX converter is not enabled!", stacklevel=1) - __onnx_enabled__ = False - # ----------Layer handling register----------# model_types = ['keras', 'pytorch', 'onnx'] @@ -51,7 +39,7 @@ # and has 'handles' attribute # and is defined in this module (i.e., not imported) if callable(func) and hasattr(func, 'handles') and func.__module__ == lib.__name__: - for layer in func.handles: + for layer in func.handles: # type: ignore if model_type == 'keras': register_keras_layer_handler(layer, func) elif model_type == 'pytorch': @@ -93,10 +81,10 @@ def parse_yaml_config(config_file): """ def construct_keras_model(loader, node): - from tensorflow.keras.models import load_model - model_str = loader.construct_scalar(node) - return load_model(model_str) + import keras + + return keras.models.load_model(model_str) yaml.add_constructor('!keras_model', construct_keras_model, Loader=yaml.SafeLoader) @@ -124,15 +112,9 @@ def convert_from_config(config): model = None if 'OnnxModel' in yamlConfig: - if __onnx_enabled__: - model = onnx_to_hls(yamlConfig) - else: - raise Exception("ONNX not found. Please install ONNX.") + model = onnx_to_hls(yamlConfig) elif 'PytorchModel' in yamlConfig: - if __pytorch_enabled__: - model = pytorch_to_hls(yamlConfig) - else: - raise Exception("PyTorch not found. 
Please install PyTorch.") + model = pytorch_to_hls(yamlConfig) else: model = keras_to_hls(yamlConfig) @@ -174,6 +156,7 @@ def _check_model_config(model_config): return model_config +@requires('_keras') def convert_from_keras_model( model, output_dir='my-hls-test', @@ -237,6 +220,7 @@ def convert_from_keras_model( return keras_to_hls(config) +@requires('_torch') def convert_from_pytorch_model( model, output_dir='my-hls-test', @@ -308,6 +292,7 @@ def convert_from_pytorch_model( return pytorch_to_hls(config) +@requires('onnx') def convert_from_onnx_model( model, output_dir='my-hls-test', @@ -371,6 +356,7 @@ def convert_from_onnx_model( return onnx_to_hls(config) +@requires('sr') def convert_from_symbolic_expression( expr, n_symbols=None, diff --git a/hls4ml/converters/keras/hgq_proxy_model.py b/hls4ml/converters/keras/hgq_proxy_model.py index 1598759253..68b884a4fd 100644 --- a/hls4ml/converters/keras/hgq_proxy_model.py +++ b/hls4ml/converters/keras/hgq_proxy_model.py @@ -1,4 +1,5 @@ from hls4ml.converters.keras_to_hls import KerasReader, keras_handler, parse_default_keras_layer +from hls4ml.model.types import FixedPrecisionType @keras_handler('FixedPointQuantizer', 'HGQ>FixedPointQuantizer') @@ -10,11 +11,14 @@ def fixedpoint_quantizer_handler(keras_layer, input_names, input_shapes, data_re config['RND'] = keras_layer['config']['RND'] config['SAT'] = keras_layer['config']['SAT'] config['fusible'] = fusible - if not fusible: - k = data_reader.get_weights_data(name, 'keep_negative') - b = data_reader.get_weights_data(name, 'bits') - i = data_reader.get_weights_data(name, 'integers') - config['mask_kbi'] = k, b, i + k = data_reader.get_weights_data(name, 'keep_negative') + b = data_reader.get_weights_data(name, 'bits') + i = data_reader.get_weights_data(name, 'integers') + + if fusible: + k, b, i = k.ravel()[:1], b.ravel()[:1], i.ravel()[:1] + + config['mask_kbi'] = k, b, i config['overrides'] = keras_layer['config']['overrides'] layer = config @@ -27,10 +31,9 @@ def unary_lut_keras_handler(keras_layer, input_names, input_shapes, data_reader: table = data_reader.get_weights_data(config['name'], 'table') k, i, f = keras_layer['config']['kif_out'] - k, b, i = k, k + i + f, k + i - config['table_t'] = f'{"" if k else "u"}fixed<{b},{i}>' - config['table'] = table - config['table_size'] = len(table) + k, b, I = k, k + i + f, k + i # noqa: E741 + config['table_t'] = FixedPrecisionType(b, I, k) # noqa: E741 + config['table_data'] = table config['activation'] = 'unary_lut' layer = config diff --git a/hls4ml/converters/keras/qkeras.py b/hls4ml/converters/keras/qkeras.py index 7357d95aed..d1910c070d 100644 --- a/hls4ml/converters/keras/qkeras.py +++ b/hls4ml/converters/keras/qkeras.py @@ -1,5 +1,3 @@ -from qkeras.quantizers import get_quantizer - from hls4ml.converters.keras.convolution import parse_conv1d_layer, parse_conv2d_layer from hls4ml.converters.keras.core import parse_batchnorm_layer, parse_dense_layer from hls4ml.converters.keras.recurrent import parse_rnn_layer @@ -88,6 +86,8 @@ def parse_qrnn_layer(keras_layer, input_names, input_shapes, data_reader): @keras_handler('QActivation') def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader): + from qkeras.quantizers import get_quantizer + assert keras_layer['class_name'] == 'QActivation' supported_activations = [ 'quantized_relu', diff --git a/hls4ml/converters/keras/reshape.py b/hls4ml/converters/keras/reshape.py index 1f6dc2a759..08803df828 100644 --- a/hls4ml/converters/keras/reshape.py +++ 
b/hls4ml/converters/keras/reshape.py @@ -24,7 +24,7 @@ def parse_reshape_layer(keras_layer, input_names, input_shapes, data_reader): layer = parse_default_keras_layer(keras_layer, input_names) layer['target_shape'] = keras_layer['config']['target_shape'] - output_shape = input_shapes[0][:1] + keras_layer['config']['target_shape'] + output_shape = input_shapes[0][:1] + list(keras_layer['config']['target_shape']) return layer, output_shape diff --git a/hls4ml/converters/keras_to_hls.py b/hls4ml/converters/keras_to_hls.py index e31e2b96a9..aa7bfe8862 100644 --- a/hls4ml/converters/keras_to_hls.py +++ b/hls4ml/converters/keras_to_hls.py @@ -4,6 +4,8 @@ from hls4ml.model import ModelGraph +from .keras_v3_to_hls import parse_keras_v3_model + MAXMULT = 4096 @@ -160,9 +162,9 @@ def get_model_arch(config): # Model instance passed in config from API keras_model = config['KerasModel'] if isinstance(keras_model, str): - from tensorflow.keras.models import load_model + import keras - keras_model = load_model(keras_model) + keras_model = keras.models.load_model(keras_model) model_arch = json.loads(keras_model.to_json()) reader = KerasModelReader(keras_model) elif 'KerasJson' in config: @@ -323,6 +325,13 @@ def parse_keras_model(model_arch, reader): def keras_to_hls(config): + if 'KerasModel' in config: + import keras + + if keras.__version__ >= '3.0': + layer_list, input_layers, output_layers, _ = parse_keras_v3_model(config['KerasModel']) + return ModelGraph(config, layer_list, input_layers, output_layers) + model_arch, reader = get_model_arch(config) layer_list, input_layers, output_layers, _ = parse_keras_model(model_arch, reader) print('Creating HLS model') diff --git a/hls4ml/converters/keras_v3/__init__.py b/hls4ml/converters/keras_v3/__init__.py new file mode 100644 index 0000000000..eb9442ba91 --- /dev/null +++ b/hls4ml/converters/keras_v3/__init__.py @@ -0,0 +1,7 @@ +from . import conv # noqa: F401 +from . import core # noqa: F401 +from . import einsum_dense # noqa: F401 +from . import squark # noqa: F401 +from ._base import registry as layer_handlers + +__all__ = ['layer_handlers'] diff --git a/hls4ml/converters/keras_v3/_base.py b/hls4ml/converters/keras_v3/_base.py new file mode 100644 index 0000000000..6f50ed6523 --- /dev/null +++ b/hls4ml/converters/keras_v3/_base.py @@ -0,0 +1,216 @@ +import typing +from types import FunctionType +from typing import Any, Callable, Sequence, TypedDict, overload + + +class DefaultConfig(TypedDict, total=False): + name: str + class_name: str + module: str + input_keras_tensor_names: list[str] + input_shape: list[list[int]] + output_keras_tensor_names: list[str] + epsilon: float + use_bias: bool + data_format: str + + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + +T_kv3_handler = Callable[ + ['keras.Layer', Sequence['keras.KerasTensor'], Sequence['keras.KerasTensor']], tuple[dict[str, Any], ...] +] + +registry: dict[str, T_kv3_handler] = {} + + +@overload +def register(cls: type) -> type: ... + + +@overload +def register(cls: str) -> Callable[[T_kv3_handler], T_kv3_handler]: ... + + +def register(cls: str | type): + """Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class. + + Parameters + ---------- + cls : str|type + If str, the key to register the handler under. If type, the class to register the handler for. 
+ + Examples + -------- + ```python + @keras_dispatcher.register + class MyLayerHandler(KerasV3LayerHandler): + handles = ('my_package.src.submodule.MyLayer', 'MyLayer2') + + def handle(self, layer, inp_tensors, out_tensors): + # handler code + + + @keras_dispatcher.register('MyLayer3') + def my_layer_handler(layer, inp_tensors, out_tensors): + # handler code + ``` + """ + + def deco(func): + if isinstance(cls, str): + registry[cls] = func + for k in getattr(func, 'handles', ()): + registry[k] = func + if isinstance(cls, type): + return cls + return func + + if isinstance(cls, type): + return deco(cls()) + return deco + + +def maybe_add_attrs(config: dict[str, Any] | DefaultConfig, obj: Any, *attrs: str): + for attr in attrs: + if attr not in config and hasattr(obj, attr): + config[attr] = getattr(obj, attr) + + +class KerasV3LayerHandler: + """Base class for keras v3 layer handlers. Subclass this class to create a handler for a specific layer type.""" + + handles = () + default_config: DefaultConfig + + def __call__( + self, + layer: 'keras.Layer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ) -> tuple[dict[str, Any], ...]: + """Handle a keras layer. Return a tuple of dictionaries, each + dictionary representing a layer (module) in the HLS model. One + layer may correspond one or more dictionaries (e.g., layers with + activation functions will be split into two layers). + + Some common attributes are automatically added to the dictionary + if the handler returns a single dictionary. If the handler + returns multiple dictionaries, the attributes must be added + manually. Anything returned by the handler will override the + automatic attributes. + + Automatic attributes: - name - class_name - module - + input_keras_tensor_names - input_shape - + output_keras_tensor_names + + If the layer has an activation function, an additional + dictionary will be added to the return value representing the + activation function. + + + Parameters + ---------- + layer : keras.Layer + The layer to be converted to HLS configuration(s). + in_tensors : Sequence[KerasTensor] + The list of input tensors to the layer. + out_tensors : Sequence[KerasTensor] + The list of output tensors from the layer. + + Returns + ------- + dict[str, Any] | tuple[dict[str, Any], ...] 
+ layer configuration(s) for the HLS model to be consumed by + the ModelGraph constructor + """ + + name = layer.name + class_name = layer.__class__.__name__ + module = layer.__module__ + + default_config: DefaultConfig = { + 'name': name, + 'class_name': class_name, + 'module': module, + 'input_keras_tensor_names': [t.name for t in in_tensors], + 'input_shape': [list(t.shape[1:]) for t in in_tensors], # type: ignore + 'output_keras_tensor_names': [t.name for t in out_tensors], + } + + maybe_add_attrs(default_config, layer, 'epsilon', 'use_bias', 'data_format') + + mandatory_keys = ['name', 'class_name', 'output_keras_tensor_names', 'input_keras_tensor_names'] + + self.default_config = default_config + config0 = self.handle(layer, in_tensors, out_tensors) + del self.default_config + + if isinstance(config0, tuple): + for conf in config0: + for key in mandatory_keys: + assert key in conf, f"Key {key} missing from layer {name} handled by {self.__class__.__name__}" + return config0 + + config = {} + config.update(default_config) + config.update(config0) + ret = (config,) + + # If activation exists, append it + + act_config, intermediate_tensor_name = self.maybe_get_activation_config(layer, out_tensors) + if act_config is not None: + ret[0]['output_keras_tensor_names'] = [intermediate_tensor_name] + ret = *ret, act_config + + return ret + + def maybe_get_activation_config(self, layer, out_tensors): + import keras + + activation = getattr(layer, 'activation', None) + name = layer.name + if activation not in (keras.activations.linear, None): + assert len(out_tensors) == 1, f"Layer {name} has more than one output, but has an activation function" + assert isinstance(activation, FunctionType), f"Activation function for layer {name} is not a function" + intermediate_tensor_name = f'{out_tensors[0].name}_activation' + act_cls_name = activation.__name__ + act_config = { + 'class_name': 'Activation', + 'activation': act_cls_name, + 'name': f'{name}_{act_cls_name}', + 'input_keras_tensor_names': [intermediate_tensor_name], + 'output_keras_tensor_names': [out_tensors[0].name], + } + return act_config, intermediate_tensor_name + return None, None + + def handle( + self, + layer: 'keras.Layer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ) -> dict[str, Any] | tuple[dict[str, Any], ...]: + return {} + + def load_weight(self, layer: 'keras.Layer', key: str): + """Load a weight from a layer. + + Parameters + ---------- + layer : keras.Layer + The layer to load the weight from. + key : str + The key of the weight to load. + + Returns + ------- + np.ndarray + The weight. 
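+
+        Examples
+        --------
+        ```python
+        # e.g. inside a handler's handle(): fetch a Dense/Conv kernel as a numpy array
+        kernel = self.load_weight(layer, 'kernel')
+        ```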
+ """ + import keras + + return keras.ops.convert_to_numpy(getattr(layer, key)) diff --git a/hls4ml/converters/keras_v3/conv.py b/hls4ml/converters/keras_v3/conv.py new file mode 100644 index 0000000000..adf6221822 --- /dev/null +++ b/hls4ml/converters/keras_v3/conv.py @@ -0,0 +1,119 @@ +import typing +from math import ceil +from typing import Sequence + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + + +@register +class KV3ConvHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.convolutional.conv1d.Conv1D', + 'keras.src.layers.convolutional.conv2d.Conv2D', + 'keras.src.layers.convolutional.depthwise_conv1d.DepthwiseConv1D', + 'keras.src.layers.convolutional.depthwise_conv2d.DepthwiseConv2D', + 'keras.src.layers.convolutional.separable_conv1d.SeparableConv1D', + 'keras.src.layers.convolutional.separable_conv2d.SeparableConv2D', + ) + + def handle( + self, + layer: 'keras.layers.Conv1D|keras.layers.Conv2D|keras.layers.DepthwiseConv1D|keras.layers.DepthwiseConv2D', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + from keras.src.layers.convolutional.base_conv import BaseConv + from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv + from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv + + assert len(in_tensors) == 1, f"Layer {layer.name} has more than one input" + assert len(out_tensors) == 1, f"Layer {layer.name} has more than one output" + + in_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore + out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore + assert all(isinstance(x, int) for x in in_shape), f"Layer {layer.name} has non-fixed size input: {in_shape}" + assert all(isinstance(x, int) for x in out_shape), f"Layer {layer.name} has non-fixed size output: {out_shape}" + + kernel = self.load_weight(layer, 'kernel') + if layer.use_bias: + bias = self.load_weight(layer, 'bias') + else: + bias = None + + ker_px_shape: tuple[int, ...] 
= layer.kernel_size + data_format = layer.data_format + + if data_format == 'channels_last': + *px_in_shape, ch_in = in_shape + *px_out_shape, ch_out = out_shape + else: + ch_in, *px_in_shape = in_shape + ch_out, *px_out_shape = out_shape + + if layer.padding == 'same': + n_padding = [ceil(N / n) * n - N for N, n in zip(px_in_shape, ker_px_shape)] + n_padding0 = [p // 2 for p in n_padding] + n_padding1 = [p - p0 for p, p0 in zip(n_padding, n_padding0)] + elif layer.padding == 'valid': + n_padding0 = [0] * len(px_in_shape) + n_padding1 = [0] * len(px_in_shape) + elif layer.padding == 'causal': + n_padding0 = [ker_px_shape[0] - 1] + [0] * (len(px_in_shape) - 1) + n_padding1 = [0] * len(px_in_shape) + else: + raise ValueError(f"Invalid padding mode {layer.padding} for layer {layer.name}") + + config = { + 'bias_data': bias, + 'data_format': data_format, + 'weight_data': kernel, + 'n_filt': ch_out, + 'n_chan': ch_in, + } + + if layer.rank == 1: + config.update( + { + 'filt_width': ker_px_shape[0], + 'stride_width': layer.strides[0], + 'pad_left': n_padding0[0], + 'pad_right': n_padding1[0], + 'in_width': px_in_shape[0], + 'out_width': px_out_shape[0], + } + ) + elif layer.rank == 2: + config.update( + { + 'filt_height': ker_px_shape[0], + 'filt_width': ker_px_shape[1], + 'stride_height': layer.strides[0], + 'stride_width': layer.strides[1], + 'pad_top': n_padding0[0], + 'pad_bottom': n_padding1[0], + 'pad_left': n_padding0[1], + 'pad_right': n_padding1[1], + 'in_height': px_in_shape[0], + 'in_width': px_in_shape[1], + 'out_height': px_out_shape[0], + 'out_width': px_out_shape[1], + } + ) + else: + _cls = f"{layer.__class__.__module__}.{layer.__class__.__name__}" + raise ValueError(f"Only 1D and 2D conv layers are supported, got {_cls} (rank={layer.rank})") + if isinstance(layer, BaseDepthwiseConv): + config['depthwise_data'] = kernel + config['depth_multiplier'] = layer.depth_multiplier + elif isinstance(layer, BaseSeparableConv): + config['depthwise_data'] = kernel + config['pointwise_data'] = self.load_weight(layer, 'pointwise_kernel') + config['depth_multiplier'] = layer.depth_multiplier + elif isinstance(layer, BaseConv): + config['weight_data'] = kernel + + return config diff --git a/hls4ml/converters/keras_v3/core.py b/hls4ml/converters/keras_v3/core.py new file mode 100644 index 0000000000..f3ac9a0d75 --- /dev/null +++ b/hls4ml/converters/keras_v3/core.py @@ -0,0 +1,222 @@ +import inspect +import typing +from math import prod +from typing import Any, Sequence + +import numpy as np + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + from keras.src.layers.merging.base_merge import Merge + + +@register +class KV3DenseHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.dense.Dense',) + + def handle( + self, + layer: 'keras.layers.Dense', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + + kernel = self.load_weight(layer, 'kernel') + bias = self.load_weight(layer, 'bias') if layer.use_bias else None + n_in, n_out = kernel.shape + + config = { + 'data_format': 'channels_last', + 'weight_data': kernel, + 'bias_data': bias, + 'n_out': n_out, + 'n_in': n_in, + } + return config + + +@register +class KV3InputHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.input_layer.InputLayer',) + + def handle( + self, + layer: 'keras.layers.InputLayer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {'input_shape': 
list(layer._batch_shape[1:])} + return config + + +@register +class KV3MergeHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.merging.add.Add', + 'keras.src.layers.merging.multiply.Multiply', + 'keras.src.layers.merging.average.Average', + 'keras.src.layers.merging.maximum.Maximum', + 'keras.src.layers.merging.minimum.Minimum', + 'keras.src.layers.merging.concatenate.Concatenate', + 'keras.src.layers.merging.subtract.Subtract', + 'keras.src.layers.merging.dot.Dot', + ) + + def handle( + self, + layer: 'Merge', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + cls_name: str | None = None, + ): + assert len(out_tensors) == 1, f"Merge layer {layer.name} has more than one output" + output_shape = list(out_tensors[0].shape[1:]) + + cls_name = cls_name or layer.__class__.__name__ + config: dict[str, Any] = { + 'output_shape': output_shape, + 'op': cls_name.lower(), + } + + match cls_name.lower(): + case 'Concatenate': + rank = len(output_shape) + class_name = f'Concatenate{rank}d' + config['axis'] = layer.axis + case 'Dot': + class_name = f'Dot{len(output_shape)}d' + rank = len(output_shape) + assert rank == 1, f"Dot product only supported for 1D tensors, got {rank}D on layer {layer.name}" + case _: + class_name = 'Merge' + + config['class_name'] = class_name + return config + + +@register +class KV3ActivationHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.activation.Activation',) + + def handle( + self, + layer: 'keras.layers.Activation', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + import keras + + config = {} + config.update(self.default_config) + + activation = getattr(layer, 'activation', keras.activations.linear) + match activation: + case keras.activations.softmax: + class_name = 'Softmax' + config['axis'] = -1 + case keras.activations.hard_sigmoid: + class_name = 'HardActivation' + case keras.activations.leaky_relu: + class_name = 'LeakyReLU' + signature = inspect.signature(keras.activations.leaky_relu) + config['activ_param'] = signature.parameters['negative_slope'].default + case keras.activations.elu: + class_name = 'ELU' + signature = inspect.signature(keras.activations.elu) + config['activ_param'] = signature.parameters['alpha'].default + case _: + class_name = 'Activation' + + config['activation'] = activation.__name__ + config['class_name'] = class_name + return (config,) + + +@register +class KV3ReLUHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.activations.leaky_relu.LeakyReLU', + 'keras.src.layers.activations.prelu.PReLU', + 'keras.src.layers.activations.relu.ReLU', + ) + + def handle( + self, + layer: 'keras.layers.ReLU', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {} + config.update(self.default_config) + + if layer.__class__.__name__ == 'ReLU': + config['class_name'] = 'Activation' + config['activation'] = 'relu' + return config + + if layer.__class__.__name__ == 'PReLU': + config['class_name'] = 'PReLU' + config['param_data'] = np.array(layer.alpha) + config['activation'] = 'prelu' + else: + config['class_name'] = 'LeakyReLU' + config['activ_param'] = float(layer.negative_slope) + config['activation'] = 'leaky_relu' + + return (config,) + + +@register +class KV3SoftmaxHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.softmax.Softmax',) + + def handle( + self, + layer: 'keras.layers.Softmax', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + 
ax = layer.axis + ax = ax if ax >= 0 else len(in_tensors[0].shape) + ax + # io_stream asserts axis=-1, convert to -1 when it is + n_outer: int = prod(in_tensors[0].shape[1:ax]) # type: ignore + n_inner: int = prod(in_tensors[0].shape[ax + 1 :]) # type: ignore + ax = -1 if ax == len(in_tensors[0].shape) - 1 else ax + config = {} + config.update(self.default_config) + if len(in_tensors) == 2: + raise NotImplementedError("Masked softmax not supported yet") + config['class_name'] = 'MaskedSoftmax' + elif len(in_tensors) == 1: + config['class_name'] = 'Softmax' + else: + raise ValueError(f"Too many inputs for softmax layer {layer.name}: expected 1 or 2, got {len(in_tensors)}") + config['axis'] = layer.axis + config['activation'] = 'softmax' + config['n_outer'] = (n_outer,) + config['n_inner'] = n_inner + + return (config,) + + +@register +class KV3HardActivationHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.elu.ELU',) + + def handle( + self, + layer: 'keras.layers.ELU', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {} + config.update(self.default_config) + + config['class_name'] = 'ELU' + config['activ_param'] = float(layer.alpha) + config['activation'] = 'elu' + + return (config,) diff --git a/hls4ml/converters/keras_v3/einsum_dense.py b/hls4ml/converters/keras_v3/einsum_dense.py new file mode 100644 index 0000000000..8eb000fcf7 --- /dev/null +++ b/hls4ml/converters/keras_v3/einsum_dense.py @@ -0,0 +1,75 @@ +import typing +from typing import Sequence + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + + +def strip_batch_dim(equation: str, einsum_dense: bool = True): + """Remove the batch dimension from the equation. + + Args: + equation (str): The einsum equation. + einsum_dense (bool): Whether the equation is for EinsumDense layer. + + Returns: + str: The einsum equation without the batch dimension. + """ + + _inps, out = equation.split('->') + inp0, inp1 = _inps.split(',') + if einsum_dense: + if inp0.startswith('...'): + assert out.startswith('...'), f'Error in eq: {equation}: Batch dim mismatch for the input and output.' + else: + assert inp0[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the input and output.' + assert inp0[0] not in inp1, f'Error in eq: {equation}: Batch dim is used in the kernel.' + inp0, out = inp0[1:], out[1:] + else: + assert inp0[0] == inp1[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the inputs and output.' + inp0, inp1, out = inp0[1:], inp1[1:], out[1:] + return f'{inp0},{inp1}->{out}' + + +@register +class KV3EinsumDenseHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.einsum_dense.EinsumDense',) + + def handle( + self, + layer: 'keras.layers.EinsumDense', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + assert len(in_tensors) == 1, 'EinsumDense layer must have exactly one input tensor' + assert len(out_tensors) == 1, 'EinsumDense layer must have exactly one output tensor' + + inp_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore + out_shape: tuple[int, ...] 
= out_tensors[0].shape[1:] # type: ignore + + # fmt: off + assert all(d is not None for d in inp_shape), \ + f'Error when processing {layer.name}: EinsumDense layer requires fully inp shapes' + assert all(d is not None for d in out_shape), \ + f'Error when processing {layer.name}: EinsumDense layer requires fully out shapes' + # fmt: on + + equation = strip_batch_dim(layer.equation, True) + + kernel = self.load_weight(layer, 'kernel') + + bias = None + if layer.bias_axes: + bias = self.load_weight(layer, 'bias') + + return { + 'class_name': 'EinsumDense', + 'equation': equation, + 'weight_data': kernel, + 'bias_data': bias, + 'inp_shape': inp_shape, + 'out_shape': out_shape, + } diff --git a/hls4ml/converters/keras_v3/squark/__init__.py b/hls4ml/converters/keras_v3/squark/__init__.py new file mode 100644 index 0000000000..f0f8d1c89b --- /dev/null +++ b/hls4ml/converters/keras_v3/squark/__init__.py @@ -0,0 +1 @@ +from . import _base, einsum, multi_head_attention, softmax, unary_lut diff --git a/hls4ml/converters/keras_v3/squark/_base.py b/hls4ml/converters/keras_v3/squark/_base.py new file mode 100644 index 0000000000..383b617568 --- /dev/null +++ b/hls4ml/converters/keras_v3/squark/_base.py @@ -0,0 +1,195 @@ +from math import prod +from typing import TYPE_CHECKING, Any, Sequence + +import numpy as np + +from hls4ml.converters.keras_v3._base import KerasV3LayerHandler, register +from hls4ml.converters.keras_v3.conv import KV3ConvHandler +from hls4ml.converters.keras_v3.core import KV3ActivationHandler, KV3DenseHandler, KV3MergeHandler +from hls4ml.converters.keras_v3.einsum_dense import KV3EinsumDenseHandler + +if TYPE_CHECKING: + import squark + from keras.api import KerasTensor, Layer + + +def extract_fixed_quantizer_config(q, tensor: 'KerasTensor', is_input: bool) -> dict[str, Any]: + from keras.api.ops import convert_to_numpy + from squark.quantizer.internal.fixed_point_quantizer import FixedPointQuantizerKBI, FixedPointQuantizerKIF + + internal_q: FixedPointQuantizerKIF | FixedPointQuantizerKBI = q.quantizer + + shape: tuple[int, ...] 
= tensor.shape[1:] # type: ignore + if any([s is None for s in shape]): + raise ValueError(f"Tensor {tensor.name} has at least one dimension with no fixed size") + k, i, f = internal_q.kif + k, B, I = k, k + i + f, k + i # type: ignore + k, B, I = convert_to_numpy(k), convert_to_numpy(B), convert_to_numpy(I) + + k = np.broadcast_to(k.astype(np.int8), (1,) + shape) + B = np.broadcast_to(B.astype(np.int8), (1,) + shape) + I = np.broadcast_to(I.astype(np.int8), (1,) + shape) + + overflow_mode = internal_q.overflow_mode + round_mode = internal_q.round_mode + fusible = np.unique(k).size == 1 and np.unique(B).size == 1 and np.unique(I).size == 1 + + input_keras_tensor_names = tensor.name if is_input else f'{tensor.name}_q' + output_keras_tensor_names = f'{tensor.name}_q' if is_input else tensor.name + return { + 'name': q.name, + 'class_name': 'FixedPointQuantizer', + 'mask_kbi': (k, B, I), + 'SAT': overflow_mode, + 'RND': round_mode, + 'fusible': fusible, + 'input_keras_tensor_names': [input_keras_tensor_names], + 'output_keras_tensor_names': [output_keras_tensor_names], + 'overrides': {}, + } + + +def override_io_tensor_confs(confs: tuple[dict[str, Any], ...], overrides: dict[str, str]): + for conf in confs: + inp_tensor_names = conf['input_keras_tensor_names'] + out_tensor_names = conf['output_keras_tensor_names'] + conf['input_keras_tensor_names'] = [overrides.get(name, name) for name in inp_tensor_names] + conf['output_keras_tensor_names'] = [overrides.get(name, name) for name in out_tensor_names] + + +@register +class SQLayerHandler(KerasV3LayerHandler): + def __call__( + self, + layer: 'squark.layers.QLayerBase', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + ret = super().__call__(layer, in_tensors, out_tensors) + + if layer._enable_iq and hasattr(layer, '_iq'): + if len(in_tensors) > 1: + iq_confs = [extract_fixed_quantizer_config(q, tensor, True) for q, tensor in zip(layer._iq, in_tensors)] + else: + iq_confs = [extract_fixed_quantizer_config(layer._iq, in_tensors[0], True)] + else: + iq_confs = () + + if layer._enable_oq: + if len(out_tensors) > 1: + oq_confs = [extract_fixed_quantizer_config(q, tensor, False) for q, tensor in zip(layer._oq, out_tensors)] + else: + oq_confs = [extract_fixed_quantizer_config(layer._oq, out_tensors[0], False)] + else: + oq_confs = () + + if iq_confs: + _froms = [t.name for t in in_tensors] + _tos = [f'{t.name}_q' for t in in_tensors] + overrides = dict(zip(_froms, _tos)) + override_io_tensor_confs(ret, overrides) + + if oq_confs: + _froms = [t.name for t in out_tensors] + _tos = [f'{t.name}_q' for t in out_tensors] + overrides = dict(zip(_froms, _tos)) + override_io_tensor_confs(ret, overrides) + + return *iq_confs, *ret, *oq_confs + + def load_weight(self, layer: 'Layer', key: str): + from keras.api.ops import convert_to_numpy + + if hasattr(layer, f'q{key}'): + return convert_to_numpy(getattr(layer, f'q{key}')) + return super().load_weight(layer, key) + + +@register +class SQEinsumDenseHandler(SQLayerHandler, KV3EinsumDenseHandler): + handles = ( + 'squark.layers.core.einsum_dense.QEinsumDense', + 'squark.layers.einsum_dense_batchnorm.QEinsumDenseBatchnorm', + ) + + +@register +class SQStandaloneQuantizerHandler(KerasV3LayerHandler): + handles = ('squark.quantizer.quantizer.Quantizer',) + + def handle( + self, + layer: 'squark.quantizer.Quantizer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + conf = extract_fixed_quantizer_config(layer, in_tensors[0], True) + del 
conf['output_keras_tensor_names'] + return conf + + +@register +class SQConvHandler(SQLayerHandler, KV3ConvHandler): + handles = ( + 'squark.layers.conv.QConv1D', + 'squark.layers.conv.QConv2D', + # 'squark.layers.conv.QConv3D', + ) + + +@register +class SQDenseHandler(SQLayerHandler, KV3DenseHandler): + handles = ('squark.layers.core.dense.QDense',) + + +@register +class SQActivationHandler(SQLayerHandler, KV3ActivationHandler): + handles = ('squark.layers.activation.QActivation',) + + +@register +class SQBatchNormalizationHandler(SQLayerHandler): + handles = ('squark.layers.batch_normalization.QBatchNormalization',) + + def handle( + self, + layer: 'squark.layers.QBatchNormalization', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + from keras import ops + + scale, offset = layer.qscaler_and_qoffset + scale = ops.convert_to_numpy(scale) + offset = ops.convert_to_numpy(offset) + + assert layer.axis in (len(in_tensors[0].shape) - 1, -1), 'Only batch_norm with axis=-1 is supported' + + return { + 'n_filt': scale.size, + 'n_in': prod(in_tensors[0].shape[1:]), # type: ignore + 'scale_data': scale, + 'bias_data': offset, + } + + +@register +class SQMergeHandler(SQLayerHandler, KV3MergeHandler): + handles = ( + 'squark.layers.ops.merge.QAdd', + 'squark.layers.ops.merge.QSubtract', + 'squark.layers.ops.merge.QMultiply', + 'squark.layers.ops.merge.QAverage', + 'squark.layers.ops.merge.QMaximum', + 'squark.layers.ops.merge.QMinimum', + 'squark.layers.ops.merge.QConcatenate', + ) + + def handle( + self, + layer: 'squark.layers.ops.merge.QMerge', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + cls_name = layer.__class__.__name__[1:] + return super().handle(layer, in_tensors, out_tensors, cls_name) diff --git a/hls4ml/converters/keras_v3/squark/einsum.py b/hls4ml/converters/keras_v3/squark/einsum.py new file mode 100644 index 0000000000..0d0e0ed4c2 --- /dev/null +++ b/hls4ml/converters/keras_v3/squark/einsum.py @@ -0,0 +1,46 @@ +import typing +from typing import Sequence + +from ..einsum_dense import strip_batch_dim +from ._base import SQLayerHandler, register + +if typing.TYPE_CHECKING: + import squark + from keras.api import KerasTensor + + +@register +class SQEinsumHandler(SQLayerHandler): + handles = ('squark.layers.ops.einsum.QEinsum',) + + def handle( + self, + layer: 'squark.layers.QEinsum', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + assert len(in_tensors) == 2, 'Einsum layer must have exactly two input tensors' + assert len(out_tensors) == 1, 'Einsum layer must have exactly one output tensor' + + inp0_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore + inp1_shape: tuple[int, ...] = in_tensors[1].shape[1:] # type: ignore + out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore + + # fmt: off + assert all(d is not None for d in inp0_shape), \ + f'Error when processing {layer.name}: Einsum layer requires full inp shapes, got {inp0_shape} for inp1' + assert all(d is not None for d in inp1_shape), \ + f'Error when processing {layer.name}: Einsum layer requires full inp shapes, got {inp1_shape} for inp2' + assert all(d is not None for d in out_shape), \ + f'Error when processing {layer.name}: EinsumDense layer requires full out shapes. 
got {out_shape} for output' + # fmt: on + + equation = strip_batch_dim(layer.equation, einsum_dense=False) + + return { + 'class_name': 'Einsum', + 'equation': equation, + 'inp0_shape': inp0_shape, + 'inp1_shape': inp1_shape, + 'out_shape': out_shape, + } diff --git a/hls4ml/converters/keras_v3/squark/multi_head_attention.py b/hls4ml/converters/keras_v3/squark/multi_head_attention.py new file mode 100644 index 0000000000..b580bf90f2 --- /dev/null +++ b/hls4ml/converters/keras_v3/squark/multi_head_attention.py @@ -0,0 +1,119 @@ +import typing +from inspect import Signature +from typing import Sequence + +import numpy as np + +from ._base import SQEinsumDenseHandler, SQLayerHandler, register +from .einsum import SQEinsumHandler +from .softmax import SQSoftmaxHandler + +if typing.TYPE_CHECKING: + import squark + from keras.api import KerasTensor + + +@register +class SQMultiHeadAttentionHandler(SQLayerHandler): + handles = ('squark.layers.multi_head_attention.QMultiHeadAttention',) + + def handle( + self, + layer: 'squark.layers.QMultiHeadAttention', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + from keras import KerasTensor + from squark.layers import QEinsum + + assert len(in_tensors) in (3, 4), 'MultiHead layer must have 3 (Q, K, V) or 4 (Q, K, V, M) input tensors' + assert len(out_tensors) == 1, 'Attention score output is not supported yet' + assert len(in_tensors) == 3, 'Mask tensor is not supported yet' + tensor_q, *_ = in_tensors + tensor_O, *tensor_attn = out_tensors + unique_name: str = layer.name + + node_index: int = tensor_q._keras_history.node_index # type: ignore + assert all( + [node_index == inp._keras_history.node_index for inp in layer.input[1:]] + ), f'Critical error handling layer {layer.name}' + node = layer._inbound_nodes[node_index] + + args = node.arguments.args + kwargs = node.arguments.kwargs + sig: Signature = layer._call_signature + + # map everything to kwargs + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + tensor_q = bound.arguments['query'] + tensor_k = bound.arguments['key'] + tensor_v = bound.arguments['value'] + tensor_q_mask = bound.arguments['query_mask'] + tensor_k_mask = bound.arguments['key_mask'] + tensor_v_mask = bound.arguments['value_mask'] + tensor_attn_mask = bound.arguments['attention_mask'] + return_scores = bound.arguments['return_attention_scores'] # noqa: F841 + + n_mask_def = np.sum( + [ + tensor_q_mask is not None, + tensor_k_mask is not None, + tensor_v_mask is not None, + tensor_attn_mask is not None, + ] + ) + assert n_mask_def <= 1, f'Layer {layer.name} has {n_mask_def} masks defined, expected at most 1' + + unique_name = f'{layer.name}_{node_index}' + to_Q = layer.query_dense + to_K = layer.key_dense + to_V = layer.value_dense + to_O = layer.output_dense + softmax = layer._softmax + + Q_batch_shape = to_Q.full_output_shape + K_batch_shape = to_K.full_output_shape + V_batch_shape = to_V.full_output_shape + # O_batch_shape = to_O.full_output_shape + n_head = layer.num_heads + score_batch_shape = (None, n_head, *Q_batch_shape[1:-2], *K_batch_shape[1:-2]) + + einsum_QK = QEinsum(layer._dot_product_equation, name=f'{layer.name}_QK', enable_iq=False, enable_oq=False) + einsum_sV = QEinsum(layer._combine_equation, name=f'{layer.name}_aV', enable_iq=False, enable_oq=False) + + tensor_Q = KerasTensor(name=f'{unique_name}_Q', shape=Q_batch_shape) + tensor_K = KerasTensor(name=f'{unique_name}_K', shape=K_batch_shape) + tensor_V = KerasTensor(name=f'{unique_name}_V', shape=V_batch_shape) + + 
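# The QMultiHeadAttention layer is lowered into a chain of already-supported sub-layers:
+        # Q/K/V projections (EinsumDense) -> Einsum (attention scores) -> Softmax -> Einsum (scores x V)
+        # -> output projection (EinsumDense). The KerasTensors created here are never evaluated; they
+        # are symbolic placeholders used only to wire the generated sub-layer configs together by name.
+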
pre_O_shape = (None, *tensor_q.shape[1:-1], layer.num_heads, layer.value_dim) + tensor_pre_O = KerasTensor(name=f'{unique_name}_pre_O', shape=pre_O_shape) + # tensor_O = KerasTensor(name=f'{name}_QK', shape=O_batch_shape) + tensor_pre_score = KerasTensor(name=f'{unique_name}_pre_score', shape=score_batch_shape) + tensor_score = KerasTensor(name=f'{unique_name}_score', shape=score_batch_shape) + + einsum_handler = SQEinsumHandler() + einsum_dense_handler = SQEinsumDenseHandler() + softmax_handler = SQSoftmaxHandler() + + config_to_Q = einsum_dense_handler(to_Q, [tensor_q], [tensor_Q]) + config_to_K = einsum_dense_handler(to_K, [tensor_k], [tensor_K]) + config_to_V = einsum_dense_handler(to_V, [tensor_v], [tensor_V]) + config_einsum_KQ = einsum_handler(einsum_QK, [tensor_K, tensor_Q], [tensor_pre_score]) + config_softmax = softmax_handler(softmax, [tensor_pre_score], [tensor_score]) + config_einsum_sV = einsum_handler(einsum_sV, [tensor_score, tensor_V], [tensor_pre_O]) + config_to_O = einsum_dense_handler(to_O, [tensor_pre_O], [tensor_O]) + + configs = ( + *config_to_Q, + *config_to_K, + *config_to_V, + *config_einsum_KQ, + *config_softmax, + *config_einsum_sV, + *config_to_O, + ) + for conf in configs: + conf['name'] = f'{layer.name}_{conf["name"]}' + return configs diff --git a/hls4ml/converters/keras_v3/squark/softmax.py b/hls4ml/converters/keras_v3/squark/softmax.py new file mode 100644 index 0000000000..d27e4ede2a --- /dev/null +++ b/hls4ml/converters/keras_v3/squark/softmax.py @@ -0,0 +1,131 @@ +import typing +from math import prod +from typing import Sequence + +from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode + +from ._base import SQLayerHandler, register + +if typing.TYPE_CHECKING: + import squark + from keras.api import KerasTensor + from squark.quantizer.internal import FixedPointQuantizerBase + + +def fixed_quantizer_to_hls4ml_t(q: 'FixedPointQuantizerBase', take_max=False): + from keras import ops + + k, i, f = q.kif + k = ops.convert_to_numpy(k) + i = ops.convert_to_numpy(i) + f = ops.convert_to_numpy(f) + if not take_max: + assert k.size == 1 and i.size == 1 and f.size == 1, 'Only homogeneous quantizer is supported' + k = bool(k.ravel().item()) + i = int(i.ravel().item()) + f = int(f.ravel().item()) + else: + k = bool(k.max()) + i = int(i.max()) + f = int(f.max()) + + k, b, I = k, k + i + f, k + i # noqa: E741 + round_mode = q.round_mode + if round_mode.startswith('S_'): + round_mode = round_mode[2:] # stochastic rounding + round_mode = getattr(RoundingMode, round_mode) + sat_mode = getattr(SaturationMode, q.overflow_mode) + return FixedPrecisionType(b, I, k, rounding_mode=round_mode, saturation_mode=sat_mode) + + +@register +class SQSoftmaxHandler(SQLayerHandler): + handles = ('squark.layers.softmax.QSoftmax',) + + def handle( + self, + layer: 'squark.layers.QSoftmax', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + assert not layer._allow_heterogeneous_table, 'Heterogeneous table is not supported in QSoftmax layer' + if len(layer.axis) == 1: + ax = layer.axis[0] + ax = ax if ax >= 0 else len(in_tensors[0].shape) + ax + # io_stream asserts axis=-1, convert to -1 when it is + n_outer: int = prod(in_tensors[0].shape[1:ax]) # type: ignore + n_inner: int = prod(in_tensors[0].shape[ax + 1 :]) # type: ignore + n_in: int = in_tensors[0].shape[ax] # type: ignore + ax = -1 if ax == len(in_tensors[0].shape) - 1 else ax + else: # softmax along multiple axes + axs = [ax if ax >= 0 else len(in_tensors[0].shape) + ax for 
ax in layer.axis] + axs = sorted(axs) + assert all(ax1 - ax0 == 1 for ax0, ax1 in zip(axs[:-1], axs[1:])), 'Softmax must act on adjacent axes' + n_outer: int = prod(in_tensors[0].shape[1 : axs[0]]) # type: ignore + n_inner: int = prod(in_tensors[0].shape[axs[-1] + 1 :]) # type: ignore + n_in: int = prod(in_tensors[0].shape[axs[0] : axs[-1] + 1]) # type: ignore + ax = -1 # if n_inner == 1 else 999 # 999 as placeholder + + from keras import ops + from squark.quantizer.internal import FixedPointQuantizerBase + + impl = 'stable' if layer.stable else 'latency' + + if impl == 'stable': + exp_table_size = 2 ** int(ops.convert_to_numpy(ops.max(layer.exp_table.iq.quantizer.bits))) + else: + exp_table_size = None # Placeholder, will be overridden in bit-exact pass + + exp_oq = layer.exp_table.oq.quantizer + inv_oq = layer.inv_table.oq.quantizer + inv_iq = layer.inv_table.iq.quantizer + assert isinstance(exp_oq, FixedPointQuantizerBase), 'Only fixed-point quantizer is supported for exp_table' + exp_table_t = fixed_quantizer_to_hls4ml_t(exp_oq) + inv_table_t = fixed_quantizer_to_hls4ml_t(inv_oq) + inv_inp_t = fixed_quantizer_to_hls4ml_t(inv_iq) + exp_scale = layer.input_scaler + + inv_table_size = 2**inv_inp_t.width + + parallelization_factor = layer.parallelization_factor + + if parallelization_factor < 0: + parallelization_factor = n_outer * n_inner + + if len(in_tensors) == 2: + raise NotImplementedError("Masked softmax not supported yet") + class_name = 'MaskedSoftmax' + elif len(in_tensors) == 1: + class_name = 'Softmax' + else: + raise ValueError(f"Too many inputs for softmax layer {layer.name}: expected 1 or 2, got {len(in_tensors)}") + + config = {} + config.update(self.default_config) + config.update( + { + 'axis': ax, + 'n_in': n_in, + 'activation': 'softmax', + 'n_outer': n_outer, + 'n_inner': n_inner, + 'implementation': impl, + 'exp_table_t': exp_table_t, + 'exp_table_size': exp_table_size, + 'inv_table_t': inv_table_t, + 'inv_table_size': inv_table_size, + 'inv_inp_t': inv_inp_t, + 'exp_scale': exp_scale, + 'parallelization_factor': parallelization_factor, + 'class_name': class_name, + '_bit_exact': True, + } + ) + + if layer.stable: + inp_norm_t = fixed_quantizer_to_hls4ml_t(layer.exp_table.iq.quantizer) + inp_norm_t.saturation_mode = SaturationMode.WRAP + inp_norm_t.rounding_mode = RoundingMode.TRN + config['inp_norm_t'] = inp_norm_t + + return (config,) diff --git a/hls4ml/converters/keras_v3/squark/unary_lut.py b/hls4ml/converters/keras_v3/squark/unary_lut.py new file mode 100644 index 0000000000..8dee49540f --- /dev/null +++ b/hls4ml/converters/keras_v3/squark/unary_lut.py @@ -0,0 +1,99 @@ +import typing +from typing import Sequence + +import numpy as np +from quantizers import float_quantize, get_fixed_quantizer + +from hls4ml.model.types import FixedPrecisionType + +from ._base import KerasV3LayerHandler, SQLayerHandler, register + +if typing.TYPE_CHECKING: + import squark + from keras.api import KerasTensor + +from decimal import Decimal + +from hls4ml.utils.qinterval import minimal_kif + + +@register +class SQUnaryLUTHandler(SQLayerHandler, KerasV3LayerHandler): + handles = ('squark.layers.activation.QUnaryFunctionLUT',) + + def handle( + self, + layer: 'squark.layers.QUnaryFunctionLUT', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + from keras import ops + from squark.quantizer.internal import FixedPointQuantizerBase, FloatPointQuantizer + + if not layer.enable_iq and not layer.enable_oq: + raise ValueError('Currently only support 
input_quantizer enabled UnaryFunctionLUT layer') + assert not layer._allow_heterogeneous_table, 'Heterogeneous table is not supported in QUnaryFunctionLUT layer' + + iq = layer.iq.quantizer + _min = Decimal(float(ops.min(iq.min))) # type: ignore + _max = Decimal(float(ops.max(iq.max))) # type: ignore + _eps = Decimal(float(ops.min(iq.epsilon))) # type: ignore + N = (_max - _min) / _eps + 1 + assert float(N).is_integer(), 'Invalid quantizer range' + N = int(N) + assert N <= 1e6, 'Too large quantizer range' + assert np.log2(N).is_integer(), f'Invalid quantizer range: N must be power of 2, got {N}' + + all_inputs = ops.linspace(float(_min), float(_max), N) + + config = {} + config.update(self.default_config) + + if isinstance(iq, FixedPointQuantizerBase): + table = ops.convert_to_numpy(layer.activation(all_inputs)) + if _min < 0: + # idx by binary repr, move the positive part to the front + table_pos, table_neg = table[N // 2 :], table[: N // 2] + table = np.concatenate([table_pos, table_neg]) + else: + assert isinstance(iq, FloatPointQuantizer), f'{layer.name}: Unknown quantizer class {type(iq)}' + mee0 = (ops.convert_to_numpy(x) for x in (iq.m, iq.e, iq.e0)) + assert all( + x.size == 1 for x in mee0 + ), f'{layer.name}: Only homogeneous input quantizer is supported for minifloat' + m, e, e0 = (int(x.ravel().item()) for x in mee0) + all_inputs = float_quantize(all_inputs, m, e, e0) + table = ops.convert_to_numpy(layer.activation(all_inputs)) + + oq = layer.oq.quantizer + if isinstance(oq, FixedPointQuantizerBase): + round_mode = oq.round_mode + if round_mode.startswith('S_'): + round_mode = round_mode[2:] + overflow_mode = oq.overflow_mode + fixed_q = get_fixed_quantizer(round_mode, overflow_mode) + k, i, f = (ops.convert_to_numpy(x).ravel().item() for x in oq.kif) + table = fixed_q(table, k, i, f) # type: ignore + + k, b, I = bool(k), k + i + f, k + i # noqa: E741 + table_t = FixedPrecisionType(b, I, k) + else: + assert isinstance(oq, FloatPointQuantizer) + m, e, e0 = (ops.convert_to_numpy(x).ravel().item() for x in (oq.m, oq.e, oq.e0)) + table = float_quantize(table, m, e, e0) + k, i, f = (int(np.min(x)) for x in minimal_kif(table)) + + raise NotImplementedError('FloatPointQuantizer is not supported yet') + table_t = FixedPrecisionType(k + i + f, k + i, bool(k)) + table = ops.convert_to_numpy(table) + + config.update( + { + 'class_name': 'UnaryLUT', + 'table_data': table, + 'table_t': table_t, + 'activation': 'unary_lut', + } + ) + + return (config,) diff --git a/hls4ml/converters/keras_v3_to_hls.py b/hls4ml/converters/keras_v3_to_hls.py new file mode 100644 index 0000000000..5c0168cc1e --- /dev/null +++ b/hls4ml/converters/keras_v3_to_hls.py @@ -0,0 +1,284 @@ +import typing +from itertools import chain +from types import FunctionType +from typing import Any, Callable, Sequence + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + +import numpy as np + +from .keras_v3 import layer_handlers as v3_layer_handlers + +T_kv3_handler = Callable[ + ['keras.Layer', Sequence['keras.KerasTensor'], Sequence['keras.KerasTensor']], tuple[dict[str, Any], ...] +] + + +def get_io_tensors(layer: 'keras.Layer', node_whitelist: set[int] | None = None): + """Given a keras layer, return a list of tuples of input and output + tensors. If the layer is called only once (i.e., no shared layers), + the list will contain only one tuple. + + The layer must have been built before calling this function. 
+ + Parameters + ---------- + layer : keras.Layer + The layer to get input and output tensors from. + node_whitelist : set[int]|None, optional + If not None, only return tensors from nodes with ids in this + set, used to filter out nodes that are not part of the model, by + default None + + + Returns + ------- + list[tuple[tuple['KerasTensor', ...], tuple['KerasTensor', ...]]] + A list of tuples of input and output tensors. + """ + in_nodes = layer._inbound_nodes + if node_whitelist is not None: + in_nodes = [node for node in in_nodes if id(node) in node_whitelist] + + ret: list[tuple[tuple['KerasTensor', ...], tuple['KerasTensor', ...]]] = [] + for node in in_nodes: + in_tensors = tuple(node.arguments.keras_tensors) + out_tensors = tuple(node.outputs) + ret.append((in_tensors, out_tensors)) + return ret + + +def resolve_dependency_relation(model: 'keras.Model'): + """Given a keras model, return the following information: + - A list of input tensor names + - A list of output tensor names + - A list of (layer_name, input_tensor_names, output_tensor_names) tuples + - A dictionary of tensor_name -> KerasTensor + + Parameters + ---------- + model : keras.Model + The keras model to analyze. + + Returns + ------- + tuple[tuple[str, ...], tuple[str, ...], list[tuple[str, tuple[str, ...], tuple[str, ...]]], dict[str, KerasTensor]] + inp_tensor_names, out_tensor_names, layer_io, tensors + """ + tensors: dict[str, 'KerasTensor'] = {} + "tensor_name -> KerasTensor" + depends_on: dict[str, tuple[str, ...]] = {} + "tensor_name -> {tensor_name}" + layer_io: list[tuple[str, tuple[str, ...], tuple[str, ...]]] = [] + "layer_name -> ((input_tensor_names), (output_tensor_names))" + + inputs = tuple(t.name for t in model.inputs) + outputs = tuple(t.name for t in model.outputs) + node_whitelist = {id(node) for v in model._nodes_by_depth.values() for node in v} + + for layer in model.layers: + for in_tensors, out_tensors in get_io_tensors(layer, node_whitelist): + in_tensor_names = tuple(t.name for t in in_tensors) + out_tensor_names = tuple(t.name for t in out_tensors) + for t in chain(in_tensors, out_tensors): + tensors[t.name] = t + for o_name in out_tensor_names: + depends_on[o_name] = in_tensor_names + layer_io.append((layer.name, in_tensor_names, out_tensor_names)) + + return inputs, outputs, layer_io, tensors + + +class UniqueName: + """Helper class to generate unique names for layers, if one being used multiple times.""" + + def __init__(self): + self.used_names: set[str] = set() + + def next_name(self, name: str): + i = 0 + if name in self.used_names: + while f'{name}_{i}' in self.used_names: + i += 1 + name = f'{name}_{i}' + self.used_names.add(name) + return name + + def __call__(self, name: str): + return self.next_name(name) + + def reset(self): + self.used_names.clear() + + +class KerasV3HandlerDispatcher: + """Dispatcher class to handle different types of keras v3 layers.""" + + def __init__(self, layer_handlers: dict[str, T_kv3_handler], v2_layer_handlers=None): + self.registry = layer_handlers + self.v2_layer_handlers = v2_layer_handlers or {} + + def __call__( + self, layer: 'keras.Layer', in_tensors: Sequence['keras.KerasTensor'], out_tensors: Sequence['keras.KerasTensor'] + ) -> tuple[dict[str, Any], ...]: + assert layer.built, f"Layer {layer.name} is not built" + + ret = self.v3_call(layer, in_tensors, out_tensors) + if ret is not None: + return ret + ret = self.v2_call(layer, in_tensors, out_tensors) + if ret is not None: + return ret + + raise ValueError( + f"Layer 
{layer.__class__.__module__}.{layer.__class__.__name__} not found in either v3 or v2 handlers" + ) + + def v3_call( + self, layer: 'keras.layers.Layer', inp_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'] + ): + cls_name = layer.__class__.__name__ + module = layer.__module__ + key = f"{module}.{cls_name}" + + # keras v3 handlers + handler = self.registry.get(key, None) + handler = handler or self.registry.get(cls_name, None) + + if handler is None: + return None + return handler(layer, inp_tensors, out_tensors) + + def v2_call( + self, layer: 'keras.layers.Layer', inp_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'] + ): + # keras v2 handlers fallback + print(f"v2 handler used for layer {layer.name}") + + import keras + + config = layer.get_config() + layer_dict = {'config': config, 'class_name': layer.__class__.__name__} + + class DummyReader: + def get_weights_data(self, layer_name, var_name): + assert layer_name == layer.name, f"Processing {layer.name}, but handler tried to read {layer_name}" + for w in layer.weights: + if var_name in w.name: + return np.array(w) + return None + + reader = DummyReader() + input_shapes = [list(t.shape) for t in inp_tensors] + input_names = [t.name for t in inp_tensors] + output_names = [t.name for t in out_tensors] + key = layer.__class__.__name__ + handler = self.v2_layer_handlers.get(key, None) + if handler is None: + return None + + ret, _ = handler(layer_dict, input_names, input_shapes, reader) + ret['output_keras_tensor_names'] = output_names + ret['input_keras_tensor_names'] = input_names + ret = (ret,) + + activation = getattr(layer, 'activation', None) + if activation not in (keras.activations.linear, None): + assert isinstance(activation, FunctionType), f"Activation function for layer {layer.name} is not a function" + intermediate_tensor_name = f'{output_names[0]}_activation' + ret[0]['output_keras_tensor_names'] = (intermediate_tensor_name,) + act_cls_name = activation.__name__ + act_config = { + 'class_name': 'Activation', + 'activation': act_cls_name, + 'name': f'{layer.name}_{act_cls_name}', + 'input_keras_tensor_names': (intermediate_tensor_name,), + 'output_keras_tensor_names': output_names, + } + ret = *ret, act_config + return ret + + +def parse_keras_v3_model(model: 'keras.Model'): + """Parse a keras model into a list of dictionaries, each + representing a layer in the HLS model, and a list of input and + output layer names. + + Parameters + ---------- + model : keras.Model + + Returns + ------- + tuple[list[dict[str, Any]], list[str], list[str], list[list[int]]] + layer_list, input_layer_names, output_layer_names, + batch_output_shapes + + Raises + ------ + ValueError + If a circular dependency is detected. 
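+
+    Examples
+    --------
+    Minimal usage sketch (illustrative; assumes a small functional model whose layers all have
+    registered handlers)::
+
+        import keras
+
+        inp = keras.Input((4,))
+        out = keras.layers.Dense(2)(inp)
+        model = keras.Model(inp, out)
+
+        layer_list, input_names, output_names, output_shapes = parse_keras_v3_model(model)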
+ """ + + assert model.built, "Model must be built before parsing" + + import keras + + if isinstance(model, keras.Sequential): + model = model._functional # everything is functional under the hood lol + + from .keras_to_hls import layer_handlers as v2_layer_handlers # Delayed import to avoid circular import + + keras_v3_dispatcher = KerasV3HandlerDispatcher(v3_layer_handlers, v2_layer_handlers) + + model_inputs, model_outputs, dependency, tensors = resolve_dependency_relation(model) + + satisfied = set() + + unique_name = UniqueName() + + layer_list: list[dict[str, Any]] = [] + + while any(t not in satisfied for t in model_outputs): + # Until all tensors in the model are satisfied + for i, (layer_name, in_tensor_names, out_tensor_names) in enumerate(dependency): + if not all(t in satisfied for t in in_tensor_names): + continue # Skip layer if some inputs are not ready + if all(t in satisfied for t in out_tensor_names): + continue # Skip layer if the outputs are already satisfied + + layer: 'keras.Layer' = model.get_layer(layer_name) + inp_tensors = [tensors[t] for t in in_tensor_names] + out_tensors = [tensors[t] for t in out_tensor_names] + + _configs = keras_v3_dispatcher(layer, inp_tensors, out_tensors) + # Dispatch to v3 handler if available, else fallback to v2 handler + + # Prevent name conflicts. If a layer is used multiple times, add a suffix to the name. + # At this stage connections between modules are recorded by i/o keras tensor names + for _conf in _configs: + _conf['name'] = unique_name(_conf['name']) + + layer_list.extend(_configs) # Add the layer to the list + satisfied.update(out_tensor_names) # Mark the outputs as satisfied + dependency.pop(i) + break # Restart the loop to add another layer + else: + # If no layer was added in the loop, then there is a circular dependency + raise ValueError("Circular dependency detected") + + # Mark inputs[inp layer name] for ModelGraph to parse from i/o keras tensor names + provides: dict[str, str] = {} # tensor_name -> src_layer_name + for conf in layer_list: + for out_name in conf['output_keras_tensor_names']: + provides[out_name] = conf['name'] + inputs = [provides[tname] for tname in conf['input_keras_tensor_names']] + conf['inputs'] = inputs + + input_layer_names = [provides[tname] for tname in model_inputs] + output_layer_names = [provides[tname] for tname in model_outputs] + batch_output_shapes = [list(tensors[tname].shape) for tname in model_outputs] + + return layer_list, input_layer_names, output_layer_names, batch_output_shapes diff --git a/hls4ml/converters/onnx_to_hls.py b/hls4ml/converters/onnx_to_hls.py index 75850fa93e..d51701e726 100644 --- a/hls4ml/converters/onnx_to_hls.py +++ b/hls4ml/converters/onnx_to_hls.py @@ -1,7 +1,5 @@ -import onnx -from onnx import helper, numpy_helper - from hls4ml.model import ModelGraph +from hls4ml.utils.dependency import requires # ----------------------Helpers--------------------- @@ -20,7 +18,10 @@ def replace_char_inconsitency(name): return name.replace('.', '_') +@requires('onnx') def get_onnx_attribute(operation, name, default=None): + from onnx import helper + attr = next((x for x in operation.attribute if x.name == name), None) if attr is None: value = default @@ -74,8 +75,11 @@ def get_input_shape(graph, node): return rv +@requires('onnx') def get_constant_value(graph, constant_name): tensor = next((x for x in graph.initializer if x.name == constant_name), None) + from onnx import numpy_helper + return numpy_helper.to_array(tensor) @@ -257,6 +261,7 @@ def 
parse_onnx_model(onnx_model): return layer_list, input_layers, output_layers +@requires('onnx') def onnx_to_hls(config): """Convert onnx model to hls model from configuration. @@ -273,6 +278,8 @@ def onnx_to_hls(config): # Extract model architecture print('Interpreting Model ...') + import onnx + onnx_model = onnx.load(config['OnnxModel']) if isinstance(config['OnnxModel'], str) else config['OnnxModel'] layer_list, input_layers, output_layers = parse_onnx_model(onnx_model) diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 79ca1fa5c6..f279a1970a 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -1,6 +1,5 @@ -import torch - from hls4ml.model import ModelGraph +from hls4ml.utils.dependency import requires class PyTorchModelReader: @@ -24,8 +23,11 @@ def get_weights_data(self, layer_name, var_name): return data +@requires('_torch') class PyTorchFileReader(PyTorchModelReader): # Inherit get_weights_data method def __init__(self, config): + import torch + self.config = config if not torch.cuda.is_available(): @@ -103,6 +105,7 @@ def decorator(function): # ---------------------------------------------------------------- +@requires('_torch') def parse_pytorch_model(config, verbose=True): """Convert PyTorch model to hls4ml ModelGraph. @@ -368,6 +371,7 @@ def parse_pytorch_model(config, verbose=True): return layer_list, input_layers +@requires('_torch') def pytorch_to_hls(config): layer_list, input_layers = parse_pytorch_model(config) print('Creating HLS model') diff --git a/hls4ml/model/__init__.py b/hls4ml/model/__init__.py index fc504392b6..4ca72e3cd6 100644 --- a/hls4ml/model/__init__.py +++ b/hls4ml/model/__init__.py @@ -1,8 +1 @@ from hls4ml.model.graph import HLSConfig, ModelGraph # noqa: F401 - -try: - from hls4ml.model import profiling # noqa: F401 - - __profiling_enabled__ = True -except ImportError: - __profiling_enabled__ = False diff --git a/hls4ml/model/attributes.py b/hls4ml/model/attributes.py index d03d2bd108..9d7b78c9db 100644 --- a/hls4ml/model/attributes.py +++ b/hls4ml/model/attributes.py @@ -36,7 +36,7 @@ class Attribute: """ - def __init__(self, name, value_type=Integral, default=None, configurable=False, description=None): + def __init__(self, name, value_type: type = Integral, default=None, configurable=False, description=None): self.name = name self.value_type = value_type self.default = default diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 520f96ba5f..07339c9709 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -417,6 +417,8 @@ def _apply_sub_flow(self, flow_name, applied_flows): for sub_flow in flow.requires: if sub_flow not in applied_flows.keys(): + # if sub_flow != 'convert': + # continue self._apply_sub_flow(sub_flow, applied_flows) if len(flow.optimizers) > 0: @@ -732,7 +734,7 @@ def _get_top_function(self, x): if x0.dtype in [np.single, np.float32]: top_function = getattr(self._top_function_lib, self.config.get_project_name() + '_float') ctype = ctypes.c_float - elif x0.dtype in [np.double, np.float64, np.float_]: + elif x0.dtype in [np.double, np.float64]: top_function = getattr(self._top_function_lib, self.config.get_project_name() + '_double') ctype = ctypes.c_double else: diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 3847cda9cf..f0d20b824a 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -21,16 +21,24 @@ FixedPrecisionType, IntegerPrecisionType, NamedType, + PrecisionType, + RoundingMode, + 
SaturationMode, TensorVariable, UnspecifiedPrecisionType, WeightVariable, find_minimum_width, ) from hls4ml.utils import attribute_descriptions as descriptions +from hls4ml.utils.einsum_utils import parse_einsum from hls4ml.utils.string_utils import convert_to_snake_case +if typing.TYPE_CHECKING: + from hls4ml.model import ModelGraph # TODO move this to some utility module + + class classproperty: def __init__(self, func): self.func = func @@ -80,7 +88,7 @@ def __init__(self, model, name, attributes, inputs, outputs=None): "No model layer should be named 'input' because that is a reserved;" + "layer name in ModelGraph; Please rename the layer in your model" ) - self.model = model + self.model: 'ModelGraph' = model self.name = name self.index = model.next_layer() self.inputs = inputs @@ -145,6 +153,9 @@ def _validate_attributes(self): # Validate existing attributes for attr_name, attr_value in self.attributes.items(): + if isinstance(attr_value, PrecisionType): + attr_value = self._wrap_precision_to_type(f'{self.name}_{attr_name}', attr_value) + self.set_attr(attr_name, attr_value) exp_attr = all_attributes.pop(attr_name, None) if exp_attr is not None: if not exp_attr.validate_value(attr_value): @@ -910,7 +921,8 @@ def initialize(self): shape = inp.shape dims = inp.dim_names self.add_output_variable(shape, dims) - self.set_attr('n_in', self.get_input_variable().size()) + if 'n_in' not in self.attributes: + self.set_attr('n_in', self.get_input_variable().size()) class ParametrizedActivation(Activation): @@ -975,6 +987,31 @@ def initialize(self): class Softmax(Activation): + _expected_attributes = [ + Attribute('n_in'), + Attribute('activation', value_type=str), + Attribute('n_outer', value_type=int, default=1), + Attribute('n_inner', value_type=int, default=1), + ChoiceAttribute('implementation', ['latency', 'stable', 'argmax', 'legacy'], default='stable'), + ConfigurableAttribute('skip', value_type=bool, default=False), + TypeAttribute( + 'exp_table', + default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), + ), + TypeAttribute( + 'inv_table', + default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), + ), + TypeAttribute( + 'inv_inp', + default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), + ), + TypeAttribute( + 'accum', + default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), + ), + ] + def initialize(self): super().initialize() @@ -1016,16 +1053,21 @@ def initialize(self): dims = inp.dim_names self.add_output_variable(shape, dims) - gamma = self.get_attr('gamma_data') - beta = self.get_attr('beta_data') - mean = self.get_attr('mean_data') - var = self.get_attr('variance_data') - - scale = gamma / np.sqrt(var + self.get_attr('epsilon')) - bias = beta - scale * mean + if self.get_attr('scale_data') is None: + gamma = self.get_attr('gamma_data') + var = self.get_attr('variance_data') + scale = gamma / np.sqrt(var + self.get_attr('epsilon')) + self.add_weights_variable(name='scale', var_name='s{index}', data=scale) + else: + self.add_weights_variable(name='scale', var_name='s{index}') - self.add_weights_variable(name='scale', var_name='s{index}', data=scale) - self.add_weights_variable(name='bias', var_name='b{index}', data=bias) + if self.get_attr('bias_data') is None: + beta = self.get_attr('beta_data') + mean = self.get_attr('mean_data') + bias = beta - scale * mean + 
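# Fold the normalization into an affine map: y = gamma * (x - mean) / sqrt(var + eps) + beta
+            #   = scale * x + bias, with scale = gamma / sqrt(var + eps) and bias = beta - scale * mean.
+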
self.add_weights_variable(name='bias', var_name='b{index}', data=bias) + else: + self.add_weights_variable(name='bias', var_name='b{index}') # TODO: discuss whether this should be renamed to soemthing more descriptive, and whether the class hierarchy makes sense @@ -1221,8 +1263,7 @@ def initialize(self): perm = self.get_attr('perm') self.set_attr('dim', f'{len(inp.shape)}d') - if len(perm) > 3: - raise Exception('ERROR: Transpose of tensors with rank > 3 is not yet supported.') + # TODO: dim>3 is only supported for vivado/vitis backend # ONNX double transpose specific, sometimes ONNX injects # useless double transpose layers when converting @@ -1242,11 +1283,14 @@ def initialize(self): self.set_attr('depth', 1) self.set_attr('height', inp.shape[0]) self.set_attr('width', inp.shape[1]) - elif len(shape) > 2: + elif len(shape) == 3: dims = [f'OUT_DEPTH_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}'] self.set_attr('depth', inp.shape[0]) self.set_attr('height', inp.shape[1]) self.set_attr('width', inp.shape[2]) + elif len(shape) > 3: + # Differentiate between 2/3/3+ dim does not really appear to be needed. To be removed? + dims = [f'OUT_DIM_{i}_{self.index}' for i in range(1, len(shape) + 1)] self.add_output_variable(shape, dims, precision=inp.type.precision) @@ -1616,6 +1660,131 @@ def initialize(self): self.add_output_variable([len(self.get_attr('expression'))], [f'N_OUTPUTS_{self.index}'], var_name='y') +class EinsumDense(Layer): + _expected_attributes = [ + WeightAttribute('weight'), + WeightAttribute('bias'), + TypeAttribute('weight'), + TypeAttribute('bias'), + TypeAttribute('accum'), + Attribute('equation', value_type=str), + Attribute('inp_shape', value_type=tuple), + Attribute('out_shape', value_type=tuple), + ] + + def initialize(self): + out_shape = self.attributes['out_shape'] + if len(out_shape) > 1: + dims = [f'N_LAYER_{self.index}_D{i}' for i in range(1, len(out_shape) + 1)] + else: + dims = [f'N_LAYER_{self.index}'] + self.add_output_variable(list(out_shape), dims) + + kernel: np.ndarray = self.attributes.attributes['weight_data'] + bias: np.ndarray | None = self.attributes.attributes['bias_data'] + equation = self.attributes['equation'] + inp_shape = self.attributes['inp_shape'] + out_shape = self.attributes['out_shape'] + + kernel_shape = kernel.shape + recipe = parse_einsum(equation, inp_shape, kernel_shape) + assert not any(recipe['direct_sum_axis']), ( + 'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.' + 'Use explicit addition operator before instead.' + ) + inp_tpose_idxs, ker_tpose_idxs = recipe['in_transpose_idxs'] + out_tpose_idxs = recipe['out_transpose_idxs'] + + # Pre-transpose kernel (and bias) to save a transpose in cpp. Shouldn't matter for latency strategy though. + # hls4ml dense acts like i,ij->j + # parser assumes ij,j->i, so we need to transpose the kernel to match + kernel = kernel.transpose(ker_tpose_idxs) + kernel = kernel.reshape(recipe['I'], recipe['L1'], recipe['C']).transpose(0, 2, 1) + + def to_original_kernel(tkernel: np.ndarray) -> np.ndarray: + _kernel = tkernel.transpose(0, 2, 1) + _kernel = _kernel.reshape(tuple(kernel_shape[i] for i in ker_tpose_idxs)) + return _kernel.transpose(np.argsort(ker_tpose_idxs)) + + # TODO: for weight in bram mode (resource), broadcasting bias here shall be avoided. 
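+        # The Keras-side bias only spans the axes named in bias_axes (for an assumed equation
+        # 'btf,fh->bth' with bias_axes='h', its shape would be (h,)); broadcasting it to the full
+        # out_shape lets the HLS kernel add it element-wise without tracking the original bias axes.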
+        if bias is not None:
+            bias = np.broadcast_to(bias, out_shape).transpose(np.argsort(out_tpose_idxs))
+        else:
+            # The automatically created bias covers only the last dimension of the output shape,
+            # which is too small in general for einsum dense, so build a zero bias of the full output shape.
+            # The transpose only mirrors the real-bias case to keep shapes identical; it has no numerical effect.
+            bias = np.zeros(out_shape).transpose(np.argsort(out_tpose_idxs))
+
+        self.attributes.attributes['weight_data'] = kernel
+        self.attributes.attributes['to_original_kernel'] = to_original_kernel
+        self.attributes.attributes['bias_data'] = bias
+        self.attributes['inp_tpose_idxs'] = inp_tpose_idxs
+        self.attributes['out_tpose_idxs'] = out_tpose_idxs
+        self.attributes['out_interpert_shape'] = recipe['out_interpert_shape']
+        self.attributes['n_free_data'] = recipe['L0']
+        self.attributes['n_free_kernel'] = recipe['L1']
+        self.attributes['n_inplace'] = recipe['I']
+        self.attributes['n_contract'] = recipe['C']
+        pf = self.attributes.attributes.get('parallelization_factor', recipe['L0'])
+        self.attributes['parallelization_factor'] = pf
+
+        self.add_weights(compression=self.model.config.get_compression(self))
+        self.add_bias()
+
+
+class Matmul(Layer):
+    _expected_attributes = [
+        TypeAttribute('accum'),
+        Attribute('inp1_shape', value_type=tuple),
+        Attribute('inp2_shape', value_type=tuple),
+    ]
+
+
+class Einsum(Layer):
+    _expected_attributes = [
+        TypeAttribute('accum'),
+        Attribute('equation', value_type=str),
+        Attribute('inp0_shape', value_type=tuple),
+        Attribute('inp1_shape', value_type=tuple),
+        Attribute('out_shape', value_type=tuple),
+    ]
+
+    def initialize(self):
+        out_shape = self.attributes['out_shape']
+        if len(out_shape) > 1:
+            dims = [f'N_LAYER_{self.index}_D{i}' for i in range(1, len(out_shape) + 1)]
+        else:
+            dims = [f'N_LAYER_{self.index}']
+        self.add_output_variable(list(out_shape), dims)
+
+        equation = self.attributes['equation']
+        inp0_shape = self.attributes['inp0_shape']
+        inp1_shape = self.attributes['inp1_shape']
+        out_shape = self.attributes['out_shape']
+
+        recipe = parse_einsum(equation, inp0_shape, inp1_shape)
+        assert not any(recipe['direct_sum_axis']), (
+            'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.'
+            'Use explicit addition operator before instead.'
+ ) + inp0_tpose_idxs, inp1_tpose_idxs = recipe['in_transpose_idxs'] + out_tpose_idxs = recipe['out_transpose_idxs'] + + self.attributes.attributes.update(recipe) + self.attributes['n_free0'] = recipe['L0'] + self.attributes['n_free1'] = recipe['L1'] + self.attributes['n_inplace'] = recipe['I'] + self.attributes['n_contract'] = recipe['C'] + self.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] + + self.attributes['inp0_tpose_idxs'] = inp0_tpose_idxs + self.attributes['inp1_tpose_idxs'] = inp1_tpose_idxs + self.attributes['out_tpose_idxs'] = out_tpose_idxs + + pf = self.attributes.attributes.get('parallelization_factor', recipe['L0']) + self.attributes['parallelization_factor'] = pf + + layer_map = { 'Input': Input, 'InputLayer': Input, @@ -1684,6 +1853,8 @@ def initialize(self): 'SymbolicExpression': SymbolicExpression, # TensorFlow-specific layers: 'BiasAdd': BiasAdd, + 'EinsumDense': EinsumDense, + 'Einsum': Einsum, } diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index a745eceba1..391a3934b5 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -71,7 +71,10 @@ 'fuse_consecutive_batch_normalization', 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', - 'enforce_proxy_model_embedded_config', + # 'enforce_proxy_model_embedded_config', + 'bit_exact', + 'fuse_fixed_point_quantizer', + 'fix_input_precision', 'eliminate_linear_activation', 'merge_linear_activation', # many of the above optimzers need to be done before this diff --git a/hls4ml/model/optimizer/passes/bit_exact.py b/hls4ml/model/optimizer/passes/bit_exact.py new file mode 100644 index 0000000000..85c212119c --- /dev/null +++ b/hls4ml/model/optimizer/passes/bit_exact.py @@ -0,0 +1,659 @@ +import typing +from copy import copy +from functools import reduce, singledispatch +from math import ceil, log2, prod +from typing import Sequence +from warnings import warn + +import numpy as np +from numpy.typing import NDArray + +from hls4ml.model.layers import ( + Activation, + BatchNormalization, + Concatenate, + Conv1D, + Conv2D, + Dense, + Einsum, + EinsumDense, + GlobalPooling1D, + GlobalPooling2D, + Input, + Layer, + Merge, + Pooling1D, + Pooling2D, + Reshape, + Softmax, +) +from hls4ml.model.optimizer import ModelOptimizerPass, OptimizerPass +from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer, UnaryLUT +from hls4ml.model.types import FixedPrecisionType, NamedType, RoundingMode, SaturationMode, WeightVariable +from hls4ml.utils.qinterval import QIntervalArray, einsum, minimal_kif + +if typing.TYPE_CHECKING: + from hls4ml.model import ModelGraph + + +KIF_t = tuple[NDArray[np.int8], NDArray[np.int8], NDArray[np.int8]] + + +def to_hls4ml_fixed(k, i, f, name, *args): + signed, b, I = k != 0, int(k + i + f), int(k + i) + if b <= 0: + b = 1 + I = 0 + args = [arg.upper() for arg in args] + ptype = FixedPrecisionType(b, I, signed, *args) + return NamedType(name, ptype) + + +def get_input_layers(layer: Layer): + model: 'ModelGraph' = layer.model + inp_names = layer.inputs + return [model.graph[name] for name in inp_names] + + +def get_output_layers(layer: Layer): + model: 'ModelGraph' = layer.model + return [l for l in model.graph.values() if layer.name in l.attributes.get('inputs', ())] + + +def get_output_shape(layer: Layer) -> tuple[int, ...]: + return tuple(layer.get_output_variable().shape) + + +def get_input_shapes(layer: Layer) -> list[tuple[int, ...]]: + return [get_output_shape(inp) for inp in 
get_input_layers(layer)] + + +def _maximum_kif_at_shape(shape: tuple[int, ...]): + k = np.ones(shape, dtype=np.int8) + i = np.full(shape, 126, dtype=np.int8) + f = np.full(shape, 126, dtype=np.int8) + return k, i, f + + +@singledispatch +def request_kif(layer: Layer) -> tuple[KIF_t, ...]: + input_shapes = get_input_shapes(layer) + return tuple(_maximum_kif_at_shape(shape) for shape in input_shapes) + + +@request_kif.register +def _(layer: FixedPointQuantizer): + assert layer.mask_kbi is not None + k, b, I = layer.mask_kbi + k, i, f = k, I - k, b - I + + out_shape = get_output_shape(layer) + k = np.broadcast_to(k[0], out_shape).astype(np.int8) + i = np.broadcast_to(i[0], out_shape).astype(np.int8) + f = np.broadcast_to(f[0], out_shape).astype(np.int8) + + if layer.SAT != 'WRAP': + k[:] = 1 + i[:] = 126 + if layer.RND == 'TRN': + pass + elif layer.RND == 'RND': + f += 1 + else: + f += 3 + return ((k, i, f),) + + +@request_kif.register +def _(layer: Reshape): + inp_shape = get_input_shapes(layer)[0] + k, i, f = requested_kif(layer) + k = k.reshape(inp_shape) + i = i.reshape(inp_shape) + f = f.reshape(inp_shape) + return ((k, i, f),) + + +@request_kif.register +def _(layer: Activation): + fn_name = layer.attributes.attributes.get('activation') + if fn_name == 'linear': + return (requested_kif(layer),) + if fn_name == 'relu': + k, i, f = requested_kif(layer) + k = np.ones_like(k) + return ((k, i, f),) + inp_shape = get_input_shapes(layer)[0] + return (_maximum_kif_at_shape(inp_shape),) + + +@request_kif.register +def _(layer: Concatenate): + inp_shape0, inp_shape1 = get_input_shapes(layer) + k, i, f = requested_kif(layer) + ax = layer.attributes['axis'] + n_split = inp_shape0[ax] + + k0, k1 = np.split(k, [n_split], axis=ax) + i0, i1 = np.split(i, [n_split], axis=ax) + f0, f1 = np.split(f, [n_split], axis=ax) + + return ((k0, i0, f0), (k1, i1, f1)) + + +def requested_kif(layer: Layer) -> KIF_t: + out_layers = get_output_layers(layer) + out_shape = get_output_shape(layer) + if not out_layers: + return _maximum_kif_at_shape(out_shape) + + k = np.zeros(out_shape, dtype=np.int8) + i = np.full(out_shape, -127, dtype=np.int8) + f = i.copy() + for out_layer in out_layers: + _kif_s = request_kif(out_layer) + out_layer_inp_layers = get_input_layers(out_layer) + idx = out_layer_inp_layers.index(layer) + k = np.maximum(k, _kif_s[idx][0]) + i = np.maximum(i, _kif_s[idx][1]) + f = np.maximum(f, _kif_s[idx][2]) + + return k, i, f + + +@singledispatch +def produce_kif(layer: Layer) -> KIF_t: + raise NotImplementedError(f'No implementation of produce_kif for {layer.__class__}') + + +@produce_kif.register +def _(layer: Input): + k = np.ones(get_output_shape(layer), dtype=np.int8) + i = f = np.full(get_output_shape(layer), 126, dtype=np.int8) + return k, i, f + + +def get_input_kifs(layer: Layer): + return [produce_kif(l) for l in get_input_layers(layer)] + + +@produce_kif.register +def _(layer: FixedPointQuantizer): + assert layer.mask_kbi is not None + k, b, I = layer.mask_kbi + k, i, f = k, I - k, b - I + + out_shape = get_output_shape(layer) + k = np.broadcast_to(k[0], out_shape) + i = np.broadcast_to(i[0], out_shape) + f = np.broadcast_to(f[0], out_shape) + + return k, i, f + + +@produce_kif.register +def _(layer: Reshape): + out_shape = get_output_shape(layer) + k, i, f = produce_kif(get_input_layers(layer)[0]) + return k.reshape(out_shape), i.reshape(out_shape), f.reshape(out_shape) + + +@produce_kif.register +def _(layer: Merge): + op = layer.attributes.attributes['op'].lower() + kif_ins = 
get_input_kifs(layer) + match op: + case 'add': + qint_ins = [QIntervalArray.from_kif(*kif) for kif in kif_ins] + k, i, f = reduce(lambda a, b: a + b, qint_ins).to_kif() # type: ignore + return k.astype(np.int8), i, f + case 'concatename': + axis = layer.attributes.attributes['axis'] + _ks, _is, _fs = zip(*[kif for kif in kif_ins]) + k = np.concatenate(_ks, axis=axis) + i = np.concatenate(_is, axis=axis) + f = np.concatenate(_fs, axis=axis) + return k, i, f + case _: + raise NotImplementedError(f'No implementation of Merge for {op}') + + +@produce_kif.register +def _(layer: EinsumDense): + t_kernel = layer.attributes.attributes['weight'].data + to_original_kernel = layer.attributes.attributes['to_original_kernel'] + kernel = to_original_kernel(t_kernel) + _bias = layer.attributes.attributes['bias'] + eq = layer.attributes.attributes['equation'] + k_in, i_in, f_in = get_input_kifs(layer)[0] + qint_in = QIntervalArray.from_kif(k_in, i_in, f_in) + qint_out = einsum(eq, qint_in, kernel) + if _bias is not None: + t_bias = _bias.data + bias = t_bias.transpose(layer.attributes.attributes['out_tpose_idxs']) + qint_out = qint_out + bias + k, i, f = qint_out.to_kif() + return k.astype(np.int8), i, f + + +@produce_kif.register +def _(layer: Einsum): + kif_in1, kif_in2 = get_input_kifs(layer) + qint_in1 = QIntervalArray.from_kif(*kif_in1) + qint_in2 = QIntervalArray.from_kif(*kif_in2) + eq = layer.attributes.attributes['equation'] + qint_out = einsum(eq, qint_in1, qint_in2) + k, i, f = qint_out.to_kif() + return k.astype(np.int8), i, f + + +@produce_kif.register +def _(layer: Dense): + kernel = layer.attributes.attributes['weight'].data + _bias = layer.attributes.attributes['bias'] + k_in, i_in, f_in = get_input_kifs(layer)[0] + qint_in = QIntervalArray.from_kif(k_in, i_in, f_in) + qint_out = qint_in @ kernel + if _bias is not None: + qint_out = qint_out + _bias.data + k, i, f = qint_out.to_kif() + return k.astype(np.int8), i, f + + +def r_im2col(kernel_size: Sequence[int], arr: np.ndarray, buffer: np.ndarray, axis: int): + w = kernel_size[0] + if len(kernel_size) == 3: # 1D + for i in range(arr.shape[axis] - w + 1): + patch = np.take(arr, range(i, i + w), axis=axis) + buffer[i] = patch.flatten() + else: # 2D+ + for i in range(arr.shape[axis] - w + 1): + patch = arr[i : i + w] + r_im2col(kernel_size[1:], patch, buffer[i], axis + 1) + + +def _im2col(kernel_size: Sequence[int], arr: np.ndarray): + if len(kernel_size) < 3: + return arr + shape = [inp_d - ker_d + 1 for inp_d, ker_d in zip(arr.shape, kernel_size[:-2])] + shape.append(np.prod(kernel_size[:-1])) # type: ignore + buf = np.empty(shape, dtype=arr.dtype) + r_im2col(kernel_size, arr, buf, 0) + return buf + + +def im2col(kernel_size: Sequence[int], *arrs: np.ndarray): + """im2col for multidimensional arrays. Assumes Channel Last format. 
+ + Parameters + ---------- + kernel_size : Sequence[int] + The size of the kernel, in the form (*kernel_shape, ch_in, ch_out) + + *arrs : np.ndarray + The input arrays to be transformed + + Returns + ------- + list[np.ndarray] + The transformed arrays + """ + return [_im2col(kernel_size, arr) for arr in arrs] + + +def pad_arrs(node: Layer, pad_val: float = 0, *arrs: np.ndarray): + out_arrs = [] + if node.class_name.endswith('2D'): + pad_top = node.attributes.attributes['pad_top'] + pad_bottom = node.attributes.attributes['pad_bottom'] + pad_left = node.attributes.attributes['pad_left'] + pad_right = node.attributes.attributes['pad_right'] + for arr in arrs: + r = np.pad(arr, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), constant_values=pad_val) + out_arrs.append(r) + elif node.class_name.endswith('1D'): + pad_left = node.attributes.attributes['pad_left'] + pad_right = node.attributes.attributes['pad_right'] + for arr in arrs: + r = np.pad(arr, ((pad_left, pad_right), (0, 0)), constant_values=pad_val) + out_arrs.append(r) + else: + raise ValueError(f'Layer {node.class_name} is not supported for pad_arrs') + return tuple(out_arrs) + + +def stride_arrs(node: Layer, *arrs: np.ndarray): + if node.class_name.endswith('2D'): + st_h = node.attributes.attributes['stride_height'] + st_w = node.attributes.attributes['stride_width'] + return tuple(arr[::st_h, ::st_w] for arr in arrs) + if node.class_name.endswith('1D'): + st_w = node.attributes.attributes['stride_width'] + return tuple(arr[::st_w] for arr in arrs) + raise ValueError(f'Layer {node.class_name} is not supported for stride_arrs') + + +@produce_kif.register(Conv1D) +@produce_kif.register(Conv2D) +def _(layer: Conv1D | Conv2D): + assert layer.attributes.attributes['data_format'] == 'channels_last', 'Only channels_last format is supported' + kernel = layer.attributes.attributes['weight'].data + _bias = layer.attributes.attributes['bias'] + bias = _bias.data if _bias is not None else 0 + k_in, i_in, f_in = get_input_kifs(layer)[0] + k_in, i_in, f_in = pad_arrs(layer, 0, k_in, i_in, f_in) + k_in, i_in, f_in = im2col(kernel.shape, k_in, i_in, f_in) + k_in, i_in, f_in = stride_arrs(layer, k_in, i_in, f_in) + kernel = kernel.reshape(-1, kernel.shape[-1]) + qint_in = QIntervalArray.from_kif(k_in, i_in, f_in) + qint_out = qint_in @ kernel + qint_out = qint_out + bias + k, i, f = qint_out.to_kif() + return k.astype(np.int8), i, f + + +@produce_kif.register(Pooling1D) +@produce_kif.register(Pooling2D) +@produce_kif.register(GlobalPooling1D) +@produce_kif.register(GlobalPooling2D) +def _(layer: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D): + if isinstance(layer, (Pooling1D, GlobalPooling1D)): + px_shape = (layer.attributes['pool_width'],) + else: + px_shape = (layer.attributes['pool_height'], layer.attributes['pool_width']) + ch_out = ch_in = layer.attributes['n_filt'] + + im2col_shape = *px_shape, ch_in, ch_out # conv kernel shape + k_in, i_in, f_in = get_input_kifs(layer)[0] + if isinstance(layer, (Pooling1D, Pooling2D)): + k_in, i_in, f_in = pad_arrs(layer, 0, k_in, i_in, f_in) + k_in, i_in, f_in = im2col(im2col_shape, k_in, i_in, f_in) + if isinstance(layer, (Pooling1D, Pooling2D)): + k_in, i_in, f_in = stride_arrs(layer, k_in, i_in, f_in) + + k_out = k_in.reshape(*k_in.shape[:-1], -1, ch_in).max(axis=-2).astype(np.int8) + i_out = i_in.reshape(*i_in.shape[:-1], -1, ch_in).max(axis=-2).astype(np.int8) + f_out = f_in.reshape(*f_in.shape[:-1], -1, ch_in).max(axis=-2).astype(np.int8) + + pool_op = layer.attributes['pool_op'] + 
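# Average pooling divides the window sum by the pool size; for a power-of-two pool size this is
+    # an exact right shift by log2(pool_size) bits, so only the fractional bit count grows (below).
+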
if pool_op == 'Average': + f_add = log2(prod(px_shape)) + if not f_add.is_integer(): + raise ValueError('Average pooling with non-power-of-2 pool size cannot be bit-exact') + f_out += int(f_add) + + return k_out, i_out, f_out + + +@produce_kif.register +def _(layer: BatchNormalization): + k_in, i_in, f_in = get_input_kifs(layer)[0] + qint_in = QIntervalArray.from_kif(k_in, i_in, f_in) + scale = layer.attributes.attributes['scale'].data + + _bias = layer.attributes.attributes['bias'] + bias = _bias.data if _bias is not None else 0 + + qint_out = qint_in * scale + bias + k, i, f = qint_out.to_kif() + return k.astype(np.int8), i, f + + +@produce_kif.register +def _(layer: Softmax): + out_shape = get_output_shape(layer) + + inv_table_t: FixedPrecisionType = layer.attributes['inv_table_t'].precision + exp_table_t: FixedPrecisionType = layer.attributes['exp_table_t'].precision + + b_exp, I_exp = exp_table_t.width, exp_table_t.integer + b_inv, I_inv = inv_table_t.width, inv_table_t.integer + + i_exp, f_exp = I_exp, b_exp - I_exp + i_inv, f_inv = I_inv, b_inv - I_inv + k = np.zeros(out_shape, dtype=np.int8) + i = np.full(out_shape, min(i_exp + i_inv, 1), dtype=np.int8) + f = np.full(out_shape, f_exp + f_inv, dtype=np.int8) + + return k, i, f + + +@produce_kif.register +def _(layer: Concatenate): + kifs_in = get_input_kifs(layer) + ks, is_, fs = zip(*kifs_in) + ax = layer.attributes.attributes['axis'] + k = np.concatenate(ks, axis=ax) + i = np.concatenate(is_, axis=ax) + f = np.concatenate(fs, axis=ax) + return k, i, f + + +@produce_kif.register +def _(layer: Activation): + fn_name = layer.attributes.attributes['activation'] + k, i, f = get_input_kifs(layer)[0] + + if fn_name == 'linear': + return k, i, f + if fn_name == 'relu': + print(k.__class__) + k = np.zeros_like(k) + return k, i, f + if fn_name == 'tanh': + i = np.minimum(i, 1) + f = np.full_like(f, 126) + return k, i, f + if fn_name == 'sigmoid': + k = np.zeros_like(k) + i = np.minimum(i, 1) + f = np.full_like(f, 126) + return k, i, f + + k = np.zeros_like(k) + i = np.full_like(i, 1) + f = np.full_like(f, 126) + return k, i, f + + +@produce_kif.register +def _(layer: UnaryLUT): + table_t = layer.attributes['table_t'].precision + k, I, f = table_t.signed, table_t.integer, table_t.fractional + i = I - k + shape = get_output_shape(layer) + k = np.full(shape, np.max(k), dtype=np.int8) + i = np.full(shape, np.max(i), dtype=np.int8) + f = np.full(shape, np.max(f), dtype=np.int8) + return k, i, f + + +def kif_arrs_to_ints(arr: tuple[np.ndarray, np.ndarray, np.ndarray]): + return tuple(int(np.max(a)) for a in arr) + + +def default_register_precision(layer: Layer): + _pk, _pi, _pf = produce_kif(layer) # Maximum possible k,i,f output from this layer + _rk, _ri, _rf = requested_kif(layer) # Maximum possible k,i,f may be utilized by the next layer + _ok, _oi, _of = np.minimum(_pk, _rk), np.minimum(_pi, _ri), np.minimum(_pf, _rf) + _oi += ((_pf > _rf) & (_pi <= _ri)).astype(np.int8) # Corner cases overflow prevention + + result_kif = kif_arrs_to_ints((_ok, _oi, _of)) + result_t = to_hls4ml_fixed(*result_kif, f'{layer.name}_t') + layer.attributes.attributes['result_t'] = result_t + layer.get_output_variable().type = result_t + + overrides = {} + + # Set accum_t, if exists ONLY for layers with accum_t directly at output (in general, linear DSP operations) + if 'accum_t' in layer.attributes.attributes: + accum_kif = kif_arrs_to_ints((_pk, _pi, _pf)) + accum_t = to_hls4ml_fixed(*accum_kif, f'{layer.name}_accum_t') + overrides['accum_t'] = accum_t + + # Set 
precision for fixed array (weight_t, bias_t, table_t, etc.) + for w_name_t, v in layer.attributes.attributes.items(): + if not isinstance(v, NamedType) and w_name_t.endswith('_t'): + continue # Not a precision, skip + + w_name = w_name_t[:-2] + if w_name not in layer.attributes.attributes: + continue # No matching data found, skip + + weight_var: WeightVariable = layer.attributes.attributes[w_name] + if weight_var is None: # Corresponding weight not exist, precision to be used nowhere. Put dummy. + precision = to_hls4ml_fixed(0, 0, 1, f'{layer.name}_{w_name_t}') + else: + data = weight_var.data + if not isinstance(data, np.ndarray): + raise ValueError(f'Expected data to be np.ndarray, got {type(data)} on layer {layer.name}') + k, i, f = kif_arrs_to_ints(minimal_kif(data)) + precision = to_hls4ml_fixed(k, i, f, f'{layer.name}_{w_name_t}') + overrides[w_name_t] = precision + + # Apply overrides + for w_name_t, v in overrides.items(): + layer.attributes.attributes[w_name_t] = v + if w_name_t[:-2] in layer.attributes.attributes: + # weight variables need extra steps to update precision + weight_var: WeightVariable = layer.attributes.attributes[w_name_t[:-2]] + weight_var.type = v + weight_var.update_precision(v.precision) + layer.model.config.layer_name_precision[f'{layer.name}_{w_name_t[:-2]}'] = str(v.precision) + + return (_pk, _pi, _pf), (_rk, _ri, _rf), (_ok, _oi, _of) + + +@singledispatch +def register_precision(node: Layer): + default_register_precision(node) + + +@register_precision.register +def _(node: Softmax): + if not node.attributes.get('_bit_exact', False): + # Softmax is not bit-exact by default + warn(f'Softmax layer {node.name} is converted from a frontend not supporting bit-exact softmax.') + return + + inv_inp_t: FixedPrecisionType = node.attributes['inv_inp_t'].precision + accum_t = copy(inv_inp_t) + if inv_inp_t.saturation_mode != SaturationMode.WRAP: + accum_t.saturation_mode = SaturationMode.WRAP + n_in = node.attributes['n_in'] + scale = ceil(log2(n_in)) + accum_t.width += scale + accum_t.integer += scale + if inv_inp_t.rounding_mode == RoundingMode.TRN: + pass + elif inv_inp_t.rounding_mode == RoundingMode.RND: + accum_t.width += 1 + else: + accum_t.width += 3 + accum_t.rounding_mode = RoundingMode.TRN + default_register_precision(node) + impl = node.attributes['implementation'] + match impl: + case 'latency': + k, i, f = get_input_kifs(node)[0] + b = np.max(k) + np.max(i) + np.max(f) + case 'stable': + inp_norm_t: FixedPrecisionType = node.attributes['inp_norm_t'].precision + b = inp_norm_t.width + case 'lagency': + raise ValueError('lagency softmax is not supported') + case 'argmax': + b = 0 + case _: + raise ValueError(f'Unknown softmax implementation {impl}') + + exp_table_size = 2 ** int(b) + node.attributes['exp_table_size'] = exp_table_size + node.attributes['accum_t'] = NamedType(f'{node.name}_accum_t', accum_t) + + +@register_precision.register +def _(node: UnaryLUT): + k, i, f = minimal_kif(node.attributes['table'].data) # type: ignore + k, i, f = bool(np.max(k)), int(np.max(i)), int(np.max(f)) + table_t = to_hls4ml_fixed(k, i, f, f'{node.name}_table_t') + node.attributes['table_t'] = table_t + default_register_precision(node) + + +@register_precision.register(Pooling1D) +@register_precision.register(Pooling2D) +@register_precision.register(GlobalPooling1D) +@register_precision.register(GlobalPooling2D) +def _(node: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D): + default_register_precision(node) + pool_op = node.attributes['pool_op'] + if 
pool_op != 'Average': + return + if isinstance(node, (Pooling1D, GlobalPooling1D)): + px_shape = (node.attributes['pool_width'],) + else: + px_shape = (node.attributes['pool_height'], node.attributes['pool_width']) + i_add = int(log2(prod(px_shape))) + node.attributes['accum_t'].precision.width += i_add + node.attributes['accum_t'].precision.integer += i_add + + +class BitExact(ModelOptimizerPass): + def __init__(self): + pass + + def _match(self, model: 'ModelGraph'): + if not any(isinstance(node, FixedPointQuantizer) for node in model.graph.values()): + return False + return True + + def transform(self, model): + if not self._match(model): + return False + + for node in model.graph.values(): + if node.attributes.get('bit_exact_transformed'): + return False + register_precision(node) + node.attributes['bit_exact_transformed'] = True + + return False + + +class FixInputPrecision(OptimizerPass): + def match(self, node: Layer): + if not isinstance(node, Input): + return False + + # Unhandled input precision, usually by a heterogeneous quantizer with non-WRAP saturation + return node.get_output_variable().type.precision.width > 120 + + def transform(self, model, node: Layer): + out_layers: list[FixedPointQuantizer] = get_output_layers(node) + + if len(out_layers) == 0: # Input connected to nothing + new_type = to_hls4ml_fixed(0, 0, 1, f'{node.name}_t') + node.get_output_variable().type = new_type + node.model.config.layer_name_precision[node.name] = str(new_type) + return False + + if not all(isinstance(l, FixedPointQuantizer) for l in out_layers): + warn(f'Input {node.name} has unhandled high precision. Consider setting it manually before synthesising.') + return False + + sat_modes = [l.SAT for l in out_layers] + sat_modes_set = set(sat_modes) + illegal_sat_modes = sat_modes_set - {'WRAP', 'SAT', 'SAT_SYM'} + if illegal_sat_modes: + raise ValueError(f'Input {node.name} has quantizer with illegal saturation mode {illegal_sat_modes} after.') + + kifs = [produce_kif(l) for l in out_layers] + i = np.max([np.max(i) for _, i, _ in kifs]) + k = np.max([np.max(k) for k, _, _ in kifs]) + f = node.get_output_variable().type.precision.fractional + new_type = to_hls4ml_fixed(k, i, f, f'{node.name}_t') + new_type.precision.saturation_mode = 'SAT' + node.get_output_variable().type = new_type + node.model.config.layer_name_precision[node.name] = str(new_type) + return False diff --git a/hls4ml/model/optimizer/passes/hgq_proxy_model.py b/hls4ml/model/optimizer/passes/hgq_proxy_model.py index 13e48aac43..10ff48a680 100644 --- a/hls4ml/model/optimizer/passes/hgq_proxy_model.py +++ b/hls4ml/model/optimizer/passes/hgq_proxy_model.py @@ -1,11 +1,19 @@ import re +import typing +from copy import copy from warnings import warn +import numpy as np + from hls4ml.backends.fpga.fpga_types import NamedType -from hls4ml.model.layers import Layer, register_layer +from hls4ml.model.attributes import Attribute, TypeAttribute, WeightAttribute +from hls4ml.model.layers import Layer, Reshape, register_layer from hls4ml.model.optimizer import OptimizerPass, register_pass from hls4ml.model.types import FixedPrecisionType, UnspecifiedPrecisionType, WeightVariable +if typing.TYPE_CHECKING: + from hls4ml.model import ModelGraph + re_purge_prefix = re.compile(r'(?]+)>\s*', re.IGNORECASE) @@ -20,33 +28,27 @@ def initialize(self): self.overrides = self.attributes['overrides'] self.fusible = self.attributes['fusible'] self.SAT, self.RND = self.attributes['SAT'], self.attributes['RND'] - self.mask_kbi = self.attributes.get('mask_kbi', 
None) + self.mask_kbi = self.attributes['mask_kbi'] class UnaryLUT(Layer): + _expected_attributes = [ + Attribute('n_in'), + TypeAttribute('table_t', default=FixedPrecisionType(18, 8, True)), + WeightAttribute('table'), + ] + def initialize(self): inp = self.get_input_variable() shape = inp.shape dims = inp.dim_names self.add_output_variable(shape, dims) self.set_attr('n_in', inp.size()) - self.table = self.attributes['table'] - self.table_size = self.attributes['table_size'] - - table_t = to_hls4ml_fixed(self.attributes['table_t']) - self.add_weights_variable(name='table', var_name='table{index}', precision=table_t, data=self.table) + self.table = self.attributes['table_data'] + self.attributes['table_size'] = len(self.table) + self.table_size = len(self.table) - -def to_hls4ml_fixed(fixed: str): - matched = re_parse_fixed.match(re_purge_prefix.sub('', fixed)) - assert matched is not None, f'Cannot parse {fixed}' - signed = matched.group(1) != 'u' - b, i, *args = matched.group(2).split(',') - b, i = int(b), int(i) - args = [arg.upper() for arg in args] - new_type = FixedPrecisionType(b, i, signed, *args) - # For some reason, __class__ is overwritten in hls4ml - return new_type + self.add_weights_variable(name='table') def userconf_ifdef(key: str, layer_name: str, model): @@ -74,6 +76,58 @@ def userconf_ifdef(key: str, layer_name: str, model): return key in layer_conf +q_kifRS_t = tuple[np.ndarray, np.ndarray, np.ndarray, str, str] + + +class FuseFixedPointQuantizer(OptimizerPass): + def match(self, node: Layer): + if not isinstance(node, FixedPointQuantizer): + return False + if any(np.unique(x).size > 1 for x in node.mask_kbi): + return False + return True + + def propagate(self, node: Layer, precision: FixedPrecisionType): + from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers + + node.get_output_variable().type.precision = precision + node.attributes.attributes['result_t'].precision = precision + + if not isinstance(node, Reshape): + return node + + inp_layer = get_input_layers(node)[0] + can_propagate = len(get_output_layers(inp_layer)) == 1 + + if not can_propagate: + return node + + new_precision = copy(precision) + precision.saturation_bits = 0 + precision.rounding_mode = 'TRN' + precision.saturation_mode = 'WRAP' + self.propagate(inp_layer, new_precision) + + def transform(self, model: 'ModelGraph', node: FixedPointQuantizer): + from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers + + # Rounding and saturation for FixedPointQuantizer are applied in generated code, thus not reflected in result_t. 
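As a standalone illustration of the fusion condition in `FuseFixedPointQuantizer.match` above: the pass only treats a quantizer as fusible when each of its keep-negative/width/integer masks is uniform, so a single scalar triple describes the whole tensor. The sketch below uses made-up mask values and plain NumPy; it is not hls4ml code, it only mirrors that check and the `FixedPrecisionType(b, i, k, RND, SAT)` construction used in the transform that follows.

    import numpy as np

    # Hypothetical (keep_negative, width, integer) masks of a FixedPointQuantizer.
    mask_kbi = (
        np.zeros((4, 4), dtype=np.int8),    # keep_negative: all unsigned
        np.full((4, 4), 8, dtype=np.int8),  # width: 8 bits everywhere
        np.full((4, 4), 3, dtype=np.int8),  # integer bits: 3 everywhere
    )

    # Same uniformity check as in match(): every mask collapses to one value.
    if all(np.unique(m).size == 1 for m in mask_kbi):
        k, b, i = (int(m.ravel()[0]) for m in mask_kbi)
        print(f'fusible as ap_{"" if k else "u"}fixed<{b},{i}>')  # ap_ufixed<8,3>
    else:
        print('heterogeneous masks: keep the quantizer node')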
+ if node.RND == 'TRN' and node.SAT == 'WRAP': + precision: FixedPrecisionType = copy(node.get_output_variable().type.precision) + else: + k, b, i = node.mask_kbi + k, b, i = bool(k.ravel()[0]), int(b.ravel()[0]), int(i.ravel()[0]) + precision = FixedPrecisionType(b, i, k, node.RND, node.SAT) + + inp_layer = get_input_layers(node)[0] + can_fuse = len(get_output_layers(inp_layer)) == 1 + if not can_fuse: + return False + self.propagate(inp_layer, precision) + model.remove_node(node) + return True + + class EnforceProxyModelEmbeddedConfig(OptimizerPass): def match(self, node: Layer): if not isinstance(node, FixedPointQuantizer): @@ -86,6 +140,17 @@ def transform(self, model, node: FixedPointQuantizer): if 'layers' not in node.overrides: return False + def to_hls4ml_fixed(fixed: str): + matched = re_parse_fixed.match(re_purge_prefix.sub('', fixed)) + assert matched is not None, f'Cannot parse {fixed}' + signed = matched.group(1) != 'u' + b, i, *args = matched.group(2).split(',') + b, i = int(b), int(i) + args = [arg.upper() for arg in args] + new_type = FixedPrecisionType(b, i, signed, *args) + # For some reason, __class__ is overwritten in hls4ml + return new_type + graph_changed = False layers = node.overrides['layers'] for name, conf in layers.items(): @@ -148,4 +213,5 @@ def register_hgq_proxy_model(): register_layer('HGQ>FixedPointQuantizer', FixedPointQuantizer) register_layer('UnaryLUT', UnaryLUT) register_layer('HGQ>UnaryLUT', UnaryLUT) - register_pass('enforce_proxy_model_embedded_config', EnforceProxyModelEmbeddedConfig) + # register_pass('enforce_proxy_model_embedded_config', EnforceProxyModelEmbeddedConfig) + register_pass('fuse_fixed_point_quantizer', FuseFixedPointQuantizer) diff --git a/hls4ml/model/optimizer/passes/qkeras.py b/hls4ml/model/optimizer/passes/qkeras.py index 03690bed0d..fb02d4eccf 100644 --- a/hls4ml/model/optimizer/passes/qkeras.py +++ b/hls4ml/model/optimizer/passes/qkeras.py @@ -1,5 +1,4 @@ import numpy as np -import tensorflow as tf from hls4ml.model.layers import ApplyAlpha from hls4ml.model.optimizer import ConfigurableOptimizerPass, OptimizerPass, register_pass @@ -113,6 +112,8 @@ def match(self, node): def transform(self, model, node): # The quantizer has to be applied to set the scale attribute # This must be applied to the _unquantized_ weights to obtain the correct scale + import tensorflow as tf + quantizer = node.weights['weight'].quantizer.quantizer_fn # get QKeras quantizer weights = node.weights['weight'].data_unquantized # get weights qweights = quantizer(tf.convert_to_tensor(weights)) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 84a83de23e..6def53f7d1 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -13,12 +13,11 @@ from hls4ml.model.layers import GRU, LSTM, SeparableConv1D, SeparableConv2D try: - import qkeras - from tensorflow import keras + import keras - __tf_profiling_enabled__ = True + __keras_profiling_enabled__ = True except ImportError: - __tf_profiling_enabled__ = False + __keras_profiling_enabled__ = False try: import torch @@ -27,6 +26,19 @@ except ImportError: __torch_profiling_enabled__ = False +try: + import qkeras + + __qkeras_profiling_enabled__ = True +except ImportError: + __qkeras_profiling_enabled__ = False + +__keras_activations = list() +if __keras_profiling_enabled__: + __keras_activations.append(keras.layers.Activation) +if __qkeras_profiling_enabled__: + __keras_activations.append(qkeras.QActivation) + def get_unoptimized_hlsmodel(model): from hls4ml.converters import 
convert_from_config @@ -482,7 +494,7 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): if hls_model_present: data = weights_hlsmodel(hls_model_unoptimized, fmt='summary', plot=plot) elif model_present: - if __tf_profiling_enabled__ and isinstance(model, keras.Model): + if __keras_profiling_enabled__ and isinstance(model, keras.Model): data = weights_keras(model, fmt='summary', plot=plot) elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): data = weights_torch(model, fmt='summary', plot=plot) @@ -520,7 +532,7 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): if X is not None: print("Profiling activations" + before) data = None - if __tf_profiling_enabled__ and isinstance(model, keras.Model): + if __keras_profiling_enabled__ and isinstance(model, keras.Model): data = activations_keras(model, X, fmt='summary', plot=plot) elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): data = activations_torch(model, X, fmt='summary', plot=plot) @@ -590,7 +602,7 @@ def get_ymodel_keras(keras_model, X): if ( hasattr(layer, 'activation') and layer.activation is not None - and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation)) + and not isinstance(layer, tuple(__keras_activations)) and layer.activation.__name__ != 'linear' ): tmp_activation = layer.activation diff --git a/hls4ml/model/quantizers.py b/hls4ml/model/quantizers.py index a5b9ceb8c4..eb313fc4ea 100644 --- a/hls4ml/model/quantizers.py +++ b/hls4ml/model/quantizers.py @@ -5,8 +5,6 @@ """ import numpy as np -import tensorflow as tf -from qkeras.quantizers import get_quantizer from hls4ml.model.types import ( ExponentPrecisionType, @@ -16,6 +14,7 @@ SaturationMode, XnorPrecisionType, ) +from hls4ml.utils.dependency import requires class Quantizer: @@ -86,7 +85,10 @@ class QKerasQuantizer(Quantizer): config (dict): Config of the QKeras quantizer to wrap. """ + @requires('qkeras') def __init__(self, config): + from qkeras.quantizers import get_quantizer + self.quantizer_fn = get_quantizer(config) self.alpha = config['config'].get('alpha', None) if config['class_name'] == 'quantized_bits': @@ -106,8 +108,8 @@ def __init__(self, config): self.hls_type = FixedPrecisionType(width=16, integer=6, signed=True) def __call__(self, data): - tf_data = tf.convert_to_tensor(data) - return self.quantizer_fn(tf_data).numpy() + data = np.array(data, dtype='float32') + return self.quantizer_fn(data).numpy() # return self.quantizer_fn(data) def _get_type(self, quantizer_config): @@ -131,7 +133,10 @@ class QKerasBinaryQuantizer(Quantizer): config (dict): Config of the QKeras quantizer to wrap. """ + @requires('qkeras') def __init__(self, config, xnor=False): + from qkeras.quantizers import get_quantizer + self.bits = 1 if xnor else 2 self.hls_type = XnorPrecisionType() if xnor else IntegerPrecisionType(width=2, signed=True) self.alpha = config['config']['alpha'] @@ -141,8 +146,8 @@ def __init__(self, config, xnor=False): self.binary_quantizer = BinaryQuantizer(1) if xnor else BinaryQuantizer(2) def __call__(self, data): - x = tf.convert_to_tensor(data) - y = self.quantizer_fn(x).numpy() + data = np.array(data, dtype='float32') + y = self.quantizer_fn(data).numpy() return self.binary_quantizer(y) @@ -153,15 +158,18 @@ class QKerasPO2Quantizer(Quantizer): config (dict): Config of the QKeras quantizer to wrap. 
""" + @requires('qkeras') def __init__(self, config): + from qkeras.quantizers import get_quantizer + self.bits = config['config']['bits'] self.quantizer_fn = get_quantizer(config) self.hls_type = ExponentPrecisionType(width=self.bits, signed=True) def __call__(self, data): # Weights are quantized to nearest power of two - x = tf.convert_to_tensor(data) - y = self.quantizer_fn(x) + data = np.array(data, dtype='float32') + y = self.quantizer_fn(data) if hasattr(y, 'numpy'): y = y.numpy() return y diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py index 9d0a97440f..b3b0dea383 100644 --- a/hls4ml/model/types.py +++ b/hls4ml/model/types.py @@ -206,6 +206,18 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: return super().__hash__() ^ hash((self.integer, self.rounding_mode, self.saturation_mode, self.saturation_bits)) + @property + def min(self): + if not self.signed: + return 0.0 + if self.saturation_mode == SaturationMode.SAT_SYM: + return -(2.0 ** (self.integer - 1)) + 2.0**-self.fractional + return -(2.0 ** (self.integer - 1)) + + @property + def max(self): + return 2.0 ** (self.integer - 1) - 2.0**-self.fractional + class XnorPrecisionType(PrecisionType): """ diff --git a/hls4ml/optimization/__init__.py b/hls4ml/optimization/__init__.py index c626b70c2b..2b49886e39 100644 --- a/hls4ml/optimization/__init__.py +++ b/hls4ml/optimization/__init__.py @@ -1,3 +1 @@ -from .dsp_aware_pruning import optimize_keras_model_for_hls4ml # noqa: F401 -from .dsp_aware_pruning.attributes import get_attributes_from_keras_model_and_hls4ml_config # noqa: F401 -from .dsp_aware_pruning.keras import optimize_model # noqa: F401 +# No imports as each of the optimization modules may contain different dependencies. diff --git a/hls4ml/optimization/dsp_aware_pruning/keras/__init__.py b/hls4ml/optimization/dsp_aware_pruning/keras/__init__.py index 29012bd39e..b525f58a33 100644 --- a/hls4ml/optimization/dsp_aware_pruning/keras/__init__.py +++ b/hls4ml/optimization/dsp_aware_pruning/keras/__init__.py @@ -4,9 +4,6 @@ import numpy as np import tensorflow as tf -# Enables printing of loss tensors during custom training loop -from tensorflow.python.ops.numpy_ops import np_config - import hls4ml.optimization.dsp_aware_pruning.keras.utils as utils from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES from hls4ml.optimization.dsp_aware_pruning.keras.builder import build_optimizable_model, remove_custom_regularizers @@ -15,7 +12,6 @@ from hls4ml.optimization.dsp_aware_pruning.keras.reduction import reduce_model from hls4ml.optimization.dsp_aware_pruning.scheduler import OptimizationScheduler -np_config.enable_numpy_behavior() default_regularization_range = np.logspace(-6, -2, num=16).tolist() diff --git a/hls4ml/report/quartus_report.py b/hls4ml/report/quartus_report.py index c337e5de10..677a931402 100644 --- a/hls4ml/report/quartus_report.py +++ b/hls4ml/report/quartus_report.py @@ -2,8 +2,7 @@ import webbrowser from ast import literal_eval -from calmjs.parse import asttypes, es5 -from tabulate import tabulate +from hls4ml.utils.dependency import requires def parse_quartus_report(hls_dir, write_to_file=True): @@ -42,6 +41,7 @@ def parse_quartus_report(hls_dir, write_to_file=True): return results +@requires('quantus-report') def read_quartus_report(hls_dir, open_browser=False): ''' Parse and print the Quartus report to print the report. Optionally open a browser. 
@@ -53,6 +53,8 @@ def read_quartus_report(hls_dir, open_browser=False): Returns: None ''' + from tabulate import tabulate + report = parse_quartus_report(hls_dir) print('HLS Resource Summary\n') @@ -90,6 +92,7 @@ def _find_project_dir(hls_dir): return top_func_name + '-fpga.prj' +@requires('quantus-report') def read_js_object(js_script): ''' Reads the JavaScript file and return a dictionary of variables definded in the script. @@ -100,6 +103,7 @@ def read_js_object(js_script): Returns: Dictionary of variables defines in script ''' + from calmjs.parse import asttypes, es5 def visit(node): if isinstance(node, asttypes.Program): diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h index 4683239d85..7df968bd94 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h @@ -130,37 +130,40 @@ enum class softmax_implementation { latency = 0, legacy = 1, stable = 2, argmax inline float exp_fcn_float(float input) { return std::exp(input); } -template inline float softmax_real_val_from_idx(unsigned i) { +template inline float softmax_real_val_from_idx(unsigned i) { // Treat the index as the top N bits - static constexpr int N = ceillog2(CONFIG_T::table_size); // number of address bits for table + static constexpr int N = ceillog2(table_size); // number of address bits for table data_T x(0); x(x.width - 1, x.width - N) = i; return (float)x; } -template inline unsigned softmax_idx_from_real_val(data_T x) { +template inline unsigned softmax_idx_from_real_val(data_T x) { // Slice the top N bits to get an index into the table - static constexpr int N = ceillog2(CONFIG_T::table_size); // number of address bits for table - ap_uint y = x(x.width - 1, x.width - N); // slice the top N bits of input + static constexpr int N = ceillog2(table_size); // number of address bits for table + ap_uint y = x(x.width - 1, x.width - N); // slice the top N bits of input return (unsigned)y(N - 1, 0); } template -void init_exp_table(typename CONFIG_T::exp_table_t table_out[CONFIG_T::table_size]) { +void init_exp_table(typename CONFIG_T::exp_table_t table_out[CONFIG_T::exp_table_size], bool negative = false) { // The template data_T is the data type used to address the table - for (unsigned i = 0; i < CONFIG_T::table_size; i++) { + for (unsigned i = 0; i < CONFIG_T::exp_table_size; i++) { // Slicing bits for address is going to round towards 0, so take the central value - float x = softmax_real_val_from_idx(i); + float x = softmax_real_val_from_idx(i) * CONFIG_T::exp_scale; + if (negative) { + x = -x; + } typename CONFIG_T::exp_table_t exp_x = exp_fcn_float(x); table_out[i] = exp_x; } } template -void init_invert_table(typename CONFIG_T::inv_table_t table_out[CONFIG_T::table_size]) { +void init_invert_table(typename CONFIG_T::inv_table_t table_out[CONFIG_T::inv_table_size]) { // The template data_T is the data type used to address the table - for (unsigned i = 0; i < CONFIG_T::table_size; i++) { - float x = softmax_real_val_from_idx(i); + for (unsigned i = 0; i < CONFIG_T::inv_table_size; i++) { + float x = softmax_real_val_from_idx(i); typename CONFIG_T::inv_table_t inv_x = 1 / x; table_out[i] = inv_x; } @@ -172,40 +175,39 @@ void softmax_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // Initialize the lookup tables #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::inv_table_t 
invert_table[CONFIG_T::table_size]; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { // Note we are exponentiating the inputs, which have type data_T init_exp_table(exp_table); // Note we are inverting the exponentials, which have type exp_table_t - init_invert_table(invert_table); + init_invert_table(invert_table); initialized = true; } // Calculate all the e^x's - typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in]; + typename CONFIG_T::accum_t exp_res[CONFIG_T::n_in]; #pragma HLS array_partition variable=exp_res complete - typename CONFIG_T::exp_table_t exp_sum(0); + typename CONFIG_T::inv_inp_t exp_sum(0); for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll - unsigned x = softmax_idx_from_real_val(data[i]); + unsigned x = softmax_idx_from_real_val(data[i]); exp_res[i] = exp_table[x]; } // Explicitly sum the results with an adder tree. // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing - Op_add op_add; - exp_sum = - reduce>(exp_res, op_add); + Op_add op_add; + exp_sum = reduce>(exp_res, op_add); typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_idx_from_real_val(exp_sum)]; + invert_table[softmax_idx_from_real_val(exp_sum)]; for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll res[i] = exp_res[i] * inv_exp_sum; @@ -218,19 +220,19 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // Initialize the lookup tables #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { // Note we are exponentiating the inputs, which have type data_T - init_exp_table(exp_table); + init_exp_table(exp_table, true); // Note we are inverting the exponentials, which have type exp_table_t - init_invert_table(invert_table); + init_invert_table(invert_table); initialized = true; } @@ -239,30 +241,29 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { data_T x_max = reduce>(data, op_max); // For the diffs, use the same type as the input but force rounding and saturation - ap_fixed d_xi_xmax[CONFIG_T::n_in]; + typename CONFIG_T::inp_norm_t d_xi_xmax[CONFIG_T::n_in]; for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll - d_xi_xmax[i] = data[i] - x_max; + d_xi_xmax[i] = x_max - data[i]; } // Calculate all the e^x's - typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in]; + typename CONFIG_T::accum_t exp_res[CONFIG_T::n_in]; #pragma HLS array_partition variable=exp_res 
complete - typename CONFIG_T::exp_table_t exp_sum(0); + typename CONFIG_T::inv_inp_t exp_sum(0); for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll - unsigned x = softmax_idx_from_real_val(d_xi_xmax[i]); + unsigned x = softmax_idx_from_real_val(d_xi_xmax[i]); exp_res[i] = exp_table[x]; } // Explicitly sum the results with an adder tree. // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing - Op_add op_add; - exp_sum = - reduce>(exp_res, op_add); + Op_add op_add; + exp_sum = reduce>(exp_res, op_add); typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_idx_from_real_val(exp_sum)]; + invert_table[softmax_idx_from_real_val(exp_sum)]; for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll res[i] = exp_res[i] * inv_exp_sum; @@ -299,16 +300,16 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // Initialize the lookup table #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::table_t invert_table[CONFIG_T::table_size]; + typename CONFIG_T::table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { - init_exp_table_legacy(exp_table); - init_invert_table_legacy(invert_table); + init_exp_table_legacy(exp_table); + init_invert_table_legacy(invert_table); initialized = true; } @@ -330,12 +331,12 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { if (ii == jj) exp_diff_res = 1; else { - data_round = (data_cache[jj] - data_cache[ii]) * CONFIG_T::table_size / 16; - index = data_round + 8 * CONFIG_T::table_size / 16; + data_round = (data_cache[jj] - data_cache[ii]) * CONFIG_T::exp_table_size / 16; + index = data_round + 8 * CONFIG_T::exp_table_size / 16; if (index < 0) index = 0; - if (index > CONFIG_T::table_size - 1) - index = CONFIG_T::table_size - 1; + if (index > CONFIG_T::exp_table_size - 1) + index = CONFIG_T::exp_table_size - 1; exp_diff_res = exp_table[index]; } exp_res[ii] += exp_diff_res; @@ -344,11 +345,11 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // Second loop to invert for (int ii = 0; ii < CONFIG_T::n_in; ii++) { - int exp_res_index = exp_res[ii] * CONFIG_T::table_size / 64; + int exp_res_index = exp_res[ii] * CONFIG_T::inv_table_size / 64; if (exp_res_index < 0) exp_res_index = 0; - if (exp_res_index > CONFIG_T::table_size - 1) - exp_res_index = CONFIG_T::table_size - 1; + if (exp_res_index > CONFIG_T::inv_table_size - 1) + exp_res_index = CONFIG_T::inv_table_size - 1; // typename CONFIG_T::table_t exp_res_invert = invert_table[exp_res_index]; res[ii] = (res_T)invert_table[exp_res_index]; } @@ -394,6 +395,30 @@ void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { } } +template +void softmax_multidim(data_T data[CONFIG_T::n_outer * CONFIG_T::n_in * CONFIG_T::n_inner], + res_T res[CONFIG_T::n_outer * CONFIG_T::n_in * CONFIG_T::n_inner]) { + #pragma HLS inline + #pragma HLS allocation instances = softmax limit = CONFIG_T::parallelization_factor function + data_T buffer_in[CONFIG_T::n_in]; + res_T 
buffer_out[CONFIG_T::n_in]; + for (signed i = 0; i < CONFIG_T::n_outer; i++) { + #pragma HLS UNROLL + for (signed k = 0; k < CONFIG_T::n_inner; k++) { + #pragma HLS UNROLL + for (signed j = 0; j < CONFIG_T::n_in; j++) { + #pragma HLS UNROLL + buffer_in[j] = data[i * CONFIG_T::n_in * CONFIG_T::n_inner + j * CONFIG_T::n_inner + k]; + } + softmax(buffer_in, buffer_out); + for (signed j = 0; j < CONFIG_T::n_in; j++) { + #pragma HLS UNROLL + res[i * CONFIG_T::n_in * CONFIG_T::n_inner + j * CONFIG_T::n_inner + k] = buffer_out[j]; + } + } + } +} + // ************************************************* // TanH Activation // ************************************************* diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h index ef687243bf..d117a565aa 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h @@ -109,19 +109,19 @@ void softmax_latency(hls::stream &data, hls::stream &res) { // Initialize the lookup tables #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { // Note we are exponentiating the inputs, which have type data_T init_exp_table(exp_table); // Note we are inverting the exponentials, which have type exp_table_t - init_invert_table(invert_table); + init_invert_table(invert_table); initialized = true; } @@ -129,9 +129,9 @@ void softmax_latency(hls::stream &data, hls::stream &res) { constexpr unsigned ii = data_T::size / multiplier_limit; // Calculate all the e^x's - typename CONFIG_T::exp_table_t exp_res[data_T::size]; + typename CONFIG_T::accum_t exp_res[data_T::size]; #pragma HLS array_partition variable=exp_res complete - typename CONFIG_T::exp_table_t exp_sum(0); + typename CONFIG_T::inv_inp_t exp_sum(0); SoftmaxExpLoop: for (unsigned i = 0; i < CONFIG_T::n_in / data_T::size; i++) { #pragma HLS PIPELINE II=ii @@ -140,18 +140,17 @@ void softmax_latency(hls::stream &data, hls::stream &res) { SoftmaxExpPackLoop: for (unsigned j = 0; j < data_T::size; j++) { #pragma HLS UNROLL - unsigned x = softmax_idx_from_real_val(in_pack[j]); + unsigned x = softmax_idx_from_real_val(in_pack[j]); exp_res[j] = exp_table[x]; } // Explicitly sum the results with an adder tree. 
// Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing - Op_add op_add; - exp_sum = - reduce>(exp_res, op_add); + Op_add op_add; + exp_sum = reduce>(exp_res, op_add); typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_idx_from_real_val(exp_sum)]; + invert_table[softmax_idx_from_real_val(exp_sum)]; res_T out_pack; PRAGMA_DATA_PACK(out_pack) @@ -171,19 +170,19 @@ void softmax_stable(hls::stream &data, hls::stream &res) { // Initialize the lookup tables #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { // Note we are exponentiating the inputs, which have type data_T - init_exp_table(exp_table); + init_exp_table(exp_table, true); // Note we are inverting the exponentials, which have type exp_table_t - init_invert_table(invert_table); + init_invert_table(invert_table); initialized = true; } @@ -209,30 +208,29 @@ void softmax_stable(hls::stream &data, hls::stream &res) { reduce>(data_array, op_max); // For the diffs, use the same type as the input but force rounding and saturation - ap_fixed d_xi_xmax[data_T::size]; + typename CONFIG_T::inp_norm_t d_xi_xmax[data_T::size]; for (unsigned j = 0; j < data_T::size; j++) { #pragma HLS UNROLL - d_xi_xmax[j] = data_array[j] - x_max; + d_xi_xmax[j] = x_max - data_array[j]; } // Calculate all the e^x's - typename CONFIG_T::exp_table_t exp_res[data_T::size]; + typename CONFIG_T::accum_t exp_res[data_T::size]; #pragma HLS ARRAY_PARTITION variable=exp_res complete - typename CONFIG_T::exp_table_t exp_sum(0); + typename CONFIG_T::inv_inp_t exp_sum(0); for (unsigned j = 0; j < data_T::size; j++) { #pragma HLS UNROLL - unsigned x = softmax_idx_from_real_val(d_xi_xmax[j]); + unsigned x = softmax_idx_from_real_val(d_xi_xmax[j]); exp_res[j] = exp_table[x]; } // Explicitly sum the results with an adder tree. 
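The sign flip in the stable kernels above comes with a matching change in the difference: with `d_xi_xmax = x_max - data` the table argument is always non-negative, so `init_exp_table(..., true)` only has to tabulate e^(-d). The snippet below is a plain NumPy restatement of that identity (not hls4ml code), with an arbitrary input vector:

    import numpy as np

    def softmax_stable_ref(x: np.ndarray) -> np.ndarray:
        d = np.max(x) - x               # d >= 0, as in d_xi_xmax above
        e = np.exp(-d)                  # what the negated exp table approximates
        return e * (1.0 / np.sum(e))    # the inverse table supplies the reciprocal

    x = np.array([0.5, -1.25, 2.0])
    assert np.allclose(softmax_stable_ref(x), np.exp(x) / np.sum(np.exp(x)))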
// Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing - Op_add op_add; - exp_sum = - reduce>(exp_res, op_add); + Op_add op_add; + exp_sum = reduce>(exp_res, op_add); typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_idx_from_real_val(exp_sum)]; + invert_table[softmax_idx_from_real_val(exp_sum)]; res_T out_pack; PRAGMA_DATA_PACK(out_pack) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_array.h b/hls4ml/templates/vivado/nnet_utils/nnet_array.h deleted file mode 100644 index d179102a99..0000000000 --- a/hls4ml/templates/vivado/nnet_utils/nnet_array.h +++ /dev/null @@ -1,52 +0,0 @@ -#ifndef NNET_ARRAY_H_ -#define NNET_ARRAY_H_ - -#include - -namespace nnet { - -struct transpose_config { - static const unsigned height = 10; - static const unsigned width = 10; - static const unsigned depth = 10; - static constexpr unsigned perm[3] = {2, 0, 1}; -}; - -template -void transpose_2d(data_T data[CONFIG_T::height * CONFIG_T::width], res_T data_t[CONFIG_T::height * CONFIG_T::width]) { - #pragma HLS PIPELINE - - for (int i = 0; i < CONFIG_T::height; i++) { - for (int j = 0; j < CONFIG_T::width; j++) { - data_t[j * CONFIG_T::height + i] = data[i * CONFIG_T::width + j]; - } - } -} - -template -void transpose_3d(data_T data[CONFIG_T::depth * CONFIG_T::height * CONFIG_T::width], - res_T data_t[CONFIG_T::depth * CONFIG_T::height * CONFIG_T::width]) { - unsigned dims[3] = {CONFIG_T::depth, CONFIG_T::height, CONFIG_T::width}; - unsigned dims_t[3]; - dims_t[0] = dims[CONFIG_T::perm[0]]; - dims_t[1] = dims[CONFIG_T::perm[1]]; - dims_t[2] = dims[CONFIG_T::perm[2]]; - - int idx[3] = {0}, idx_t[3] = {0}; - for (idx[0] = 0; idx[0] < dims[0]; idx[0]++) { - for (idx[1] = 0; idx[1] < dims[1]; idx[1]++) { - for (idx[2] = 0; idx[2] < dims[2]; idx[2]++) { - idx_t[0] = idx[CONFIG_T::perm[0]]; - idx_t[1] = idx[CONFIG_T::perm[1]]; - idx_t[2] = idx[CONFIG_T::perm[2]]; - - data_t[idx_t[0] * dims_t[1] * dims_t[2] + idx_t[1] * dims_t[2] + idx_t[2]] = - data[idx[0] * dims[1] * dims[2] + idx[1] * dims[2] + idx[2]]; - } - } - } -} - -} // namespace nnet - -#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h b/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h new file mode 100644 index 0000000000..cc2917783c --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h @@ -0,0 +1,83 @@ +#ifndef NNET_EINSUM_H_ +#define NNET_EINSUM_H_ + +#include "nnet_common.h" +#include "nnet_mult.h" +#include "nnet_transpose.h" + +namespace nnet { + +struct config_einsum { + typedef void tpose_inp0_conf; + typedef void tpose_inp1_conf; + typedef void tpose_out_conf; + + // Layer Sizes + static const unsigned n_free0; + static const unsigned n_free1; + static const unsigned n_contract; + static const unsigned n_inplace; + + // Resource reuse info + static const unsigned io_type; + static const unsigned strategy; + static const unsigned reuse_factor; + static const unsigned multiplier_limit; + static const bool store_weights_in_bram = false; // NOT USED + + template using product = nnet::product::mult; +}; + +template +void einsum(const data0_T data0[CONFIG_T::tpose_inp0_conf::N], const data1_T data1[CONFIG_T::tpose_inp1_conf::N], + res_T res[CONFIG_T::tpose_out_conf::N]) { + + #pragma HLS PIPELINE II = CONFIG_T::reuse_factor + #pragma HLS ALLOCATION operation instances = mul limit = CONFIG_T::multiplier_limit + + data0_T tpose_i0[CONFIG_T::tpose_inp0_conf::N]; + data1_T tpose_i1[CONFIG_T::tpose_inp1_conf::N]; + res_T tpose_o[CONFIG_T::tpose_out_conf::N]; + + 
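The `einsum` kernel introduced here contracts one shared axis between two transposed, flattened inputs; the loop nest further down computes, element by element, what a batched matrix product does. A NumPy reference with arbitrary example sizes (I in-place groups, L0/L1 free dimensions, C the contracted dimension) may make the indexing easier to follow; the array names mirror the kernel's `tpose_*` buffers, but the shapes are illustrative only.

    import numpy as np

    I, L0, L1, C = 2, 3, 4, 5             # example sizes only
    tpose_i0 = np.random.rand(I, L0, C)   # input0 after transpose, viewed as (I, L0, C)
    tpose_i1 = np.random.rand(I, L1, C)   # input1 after transpose, viewed as (I, L1, C)

    # Element-by-element form of the unrolled HLS loops:
    out = np.zeros((I, L0, L1))
    for i in range(I):
        for l0 in range(L0):
            for l1 in range(L1):
                out[i, l0, l1] = np.dot(tpose_i0[i, l0], tpose_i1[i, l1])

    # ...which is just a batched matmul over the in-place axis.
    assert np.allclose(out, np.einsum('ixc,iyc->ixy', tpose_i0, tpose_i1))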
#pragma HLS ARRAY_PARTITION variable = tpose_i0 complete + #pragma HLS ARRAY_PARTITION variable = tpose_i1 complete + #pragma HLS ARRAY_PARTITION variable = tpose_o complete + + nnet::transpose(data0, tpose_i0); + nnet::transpose(data1, tpose_i1); + + // for l0 in range(L0): + // for i in range(I): + // output[(i*L0+l0)*L1:(i*L0+l0+1)*L1] = input1[i*L1*C:(i+1)*L1*C].reshape((L1,C)) @ + // input0[(i*L0+l0)*C:(i*L0+l0+1)*C] + + constexpr unsigned L0 = CONFIG_T::n_free0; + constexpr unsigned L1 = CONFIG_T::n_free1; + constexpr unsigned C = CONFIG_T::n_contract; + constexpr unsigned I = CONFIG_T::n_inplace; + + typename CONFIG_T::accum_t accum_buf; + for (unsigned i = 0; i < I; i++) { + #pragma HLS UNROLL + for (unsigned l0 = 0; l0 < L0; l0++) { + #pragma HLS UNROLL + for (unsigned l1 = 0; l1 < L1; l1++) { + #pragma HLS UNROLL + accum_buf = 0; + for (unsigned c = 0; c < C; c++) { + #pragma HLS UNROLL + data0_T a = tpose_i0[(i * L0 + l0) * C + c]; + data1_T b = tpose_i1[i * L1 * C + l1 * C + c]; + accum_buf += CONFIG_T::template product::product(a, b); + } + tpose_o[(i * L0 + l0) * L1 + l1] = accum_buf; + } + } + } + + nnet::transpose(tpose_o, res); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h new file mode 100644 index 0000000000..1abb7c5d08 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h @@ -0,0 +1,78 @@ +#ifndef NNET_EINSUM_DENSE_H_ +#define NNET_EINSUM_DENSE_H_ + +#include "hls_stream.h" +#include "nnet_common.h" +#include "nnet_dense_latency.h" +#include "nnet_dense_resource.h" +#include "nnet_function_stubs.h" +#include "nnet_helpers.h" +#include "nnet_mult.h" +#include "nnet_transpose.h" + +namespace nnet { + +struct einsum_dense_config { + // Internal data type definitions + + typedef void tpose_inp_conf; + typedef void tpose_out_conf; + typedef void dense_conf; + + // Layer Sizes + static const unsigned n_free_data = 1; + static const unsigned n_free_kernel = 1; + static const unsigned n_contract = 1; + static const unsigned n_inplace = 1; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned strategy = latency; + static const unsigned reuse_factor = 1; + static const unsigned parallelization_factor = 1000; // Only useful when n_inplace > 1 + static const bool store_weights_in_bram = false; // NOT USED + + // Product function to use + template using product = nnet::product::mult; +}; + +template +void einsum_dense( + data_T data[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace], + res_T res[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace], + typename CONFIG_T::dense_conf::weight_t weights[CONFIG_T::n_free_kernel * CONFIG_T::n_contract * CONFIG_T::n_inplace], + typename CONFIG_T::dense_conf::bias_t biases[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]) { + data_T inp_tpose[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace]; + res_T out_tpose[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]; + res_T out_buffer[CONFIG_T::n_free_kernel]; + #pragma HLS ARRAY_PARTITION variable = inp_tpose complete + #pragma HLS ARRAY_PARTITION variable = out_tpose complete + + nnet::transpose(data, inp_tpose); + + constexpr unsigned L0 = CONFIG_T::n_free_data; + constexpr unsigned L1 = CONFIG_T::n_free_kernel; + constexpr unsigned C = CONFIG_T::n_contract; + constexpr unsigned I = CONFIG_T::n_inplace; + + for (unsigned l0 
= 0; l0 < L0; l0++) { + #pragma HLS UNROLL factor = CONFIG_T::parallelization_factor + for (unsigned i = 0; i < I; i++) { + #pragma HLS UNROLL + // even w/o explicit distributed arithmetic optimization, latency kernels are partially implemented as such + // so reusing the same multiplier for different weights doesn't really help... only full unrolling for now + dense(&inp_tpose[(i * L0 + l0) * C], out_buffer, + &weights[(i * L1 * C)], &biases[((i * L0 + l0) * L1)]); + for (unsigned j = 0; j < L1; j++) { + #pragma HLS UNROLL + out_tpose[(i * L0 + l0) * L1 + j] = out_buffer[j]; + } + } + } + + nnet::transpose(out_tpose, res); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_stream.h index 900db16c36..33538ede9f 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_stream.h @@ -179,29 +179,6 @@ void broadcast_stream(hls::stream &data, hls::stream &res) { } } -template -void transpose_2d(hls::stream &data, hls::stream &res) { - typename data_T::value_type data_array[CONFIG_T::height * CONFIG_T::width]; - #pragma HLS ARRAY_PARTITION variable=data_array complete - - for (int i = 0; i < CONFIG_T::height * CONFIG_T::width / data_T::size; i++) { - #pragma HLS PIPELINE - data_T in_data = data.read(); - for (int j = 0; j < data_T::size; j++) { - data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]); - } - } - - for (int i = 0; i < CONFIG_T::height * CONFIG_T::width / res_T::size; i++) { - #pragma HLS PIPELINE - res_T out_data; - PRAGMA_DATA_PACK(out_data) - for (int j = 0; j < res_T::size; j++) { - out_data[j] = typename res_T::value_type(data_array[j * data_T::size + i]); - } - res.write(out_data); - } -} } // namespace nnet #endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_transpose.h b/hls4ml/templates/vivado/nnet_utils/nnet_transpose.h new file mode 100644 index 0000000000..85238c25dd --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_transpose.h @@ -0,0 +1,39 @@ +#ifndef NNET_PERMUTE_H_ +#define NNET_PERMUTE_H_ + +namespace nnet { + +struct transpose_config { + static const unsigned dims; + static const unsigned N; + // vivado/vitis hls can't indexing constexpr array for some reason + // and vivado hls don't like template recursion either (vitis is fine) + // thus this appears to be the only workaround (or overkill it with codegen) + static const unsigned *const from_shape; + static const unsigned *const to_shape; + static const unsigned *const perm; + static const unsigned *const perm_strides; +}; + +template unsigned transfer_idx(int index) { + // Given output idx in c-order flat array, return input idx + int idx = 0; + for (int i = CONFIG_T::dims - 1; i >= 0; i--) { + idx += (index % CONFIG_T::to_shape[i]) * CONFIG_T::perm_strides[i]; + index /= CONFIG_T::to_shape[i]; + } + return idx; +} + +template +void transpose(const data_T data[CONFIG_T::N], res_T res[CONFIG_T::N]) { + for (int i = 0; i < CONFIG_T::N; i++) { + #pragma HLS UNROLL + int idx = transfer_idx(i); + res[i] = data[idx]; + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_transpose_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_transpose_stream.h new file mode 100644 index 0000000000..7f46a68bd2 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_transpose_stream.h @@ -0,0 +1,67 @@ +#ifndef NNET_TRANSPOSE_STREAM_H +#define NNET_TRANSPOSE_STREAM_H + +#include "hls_stream.h" +#include 
"nnet_transpose.h" +#include + +namespace nnet { + +template +typename std::enable_if::type transpose(hls::stream &data, hls::stream &res) { + // #pragma HLS INLINE RECURSIVE + typename data_T::value_type data_array[CONFIG_T::N]; + #pragma HLS ARRAY_PARTITION variable=data_array complete + + for (int i = 0; i < CONFIG_T::N / data_T::size; i++) { + #pragma HLS PIPELINE + data_T in_data = data.read(); + for (int j = 0; j < data_T::size; j++) { + #pragma HLS UNROLL + data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]); + } + } + + for (int i = 0; i < CONFIG_T::N / res_T::size; i++) { + #pragma HLS PIPELINE + res_T out_data; + PRAGMA_DATA_PACK(out_data) + for (int j = 0; j < res_T::size; j++) { + #pragma HLS UNROLL + out_data[j] = typename res_T::value_type(data_array[j * CONFIG_T::from_shape[1] + i]); + } + res.write(out_data); + } +} + +// This sfinae is for vivado_hls, which has some overhead using the transfer_idx in io_stream. +// In vitis both performs exactly the same, thus this is not removed out of convenience. +template +typename std::enable_if::type transpose(hls::stream &data, hls::stream &res) { + // #pragma HLS INLINE RECURSIVE + typename data_T::value_type data_array[CONFIG_T::N]; + #pragma HLS ARRAY_PARTITION variable=data_array complete + + for (int i = 0; i < CONFIG_T::N / data_T::size; i++) { + #pragma HLS PIPELINE + data_T in_data = data.read(); + for (int j = 0; j < data_T::size; j++) { + #pragma HLS UNROLL + data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]); + } + } + + for (int i = 0; i < CONFIG_T::N / res_T::size; i++) { + #pragma HLS PIPELINE + res_T out_data; + PRAGMA_DATA_PACK(out_data) + for (int j = 0; j < res_T::size; j++) { + #pragma HLS UNROLL + out_data[j] = typename res_T::value_type(data_array[transfer_idx(i * res_T::size + j)]); + } + res.write(out_data); + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index e450084095..f20aa49835 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -1,8 +1,8 @@ import json -import qkeras - import hls4ml +import hls4ml.converters.keras_v3_to_hls +from hls4ml.utils.dependency import requires def create_config(output_dir='my-hls-test', project_name='myproject', backend='Vivado', version='1.0.0', **kwargs): @@ -46,8 +46,11 @@ def create_config(output_dir='my-hls-test', project_name='myproject', backend='V return config +@requires('qkeras') def _get_precision_from_quantizer(quantizer): if isinstance(quantizer, str): + import qkeras + quantizer_obj = qkeras.get_quantizer(quantizer) quantizer = {} # Some activations are classes with get_config method @@ -157,12 +160,17 @@ def config_from_keras_model( if isinstance(model, dict): model_arch = model + reader = hls4ml.converters.KerasModelReader(model) + layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader) else: - model_arch = json.loads(model.to_json()) + import keras - reader = hls4ml.converters.KerasModelReader(model) - - layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader) + if keras.__version__ > '3.0': + layer_list, *_ = hls4ml.converters.parse_keras_v3_model(model) + else: + model_arch = json.loads(model.to_json()) + reader = hls4ml.converters.KerasModelReader(model) + layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader) def make_layer_config(layer): cls_name = layer['class_name'] diff --git a/hls4ml/utils/dependency.py b/hls4ml/utils/dependency.py new file mode 100644 index 
0000000000..e546dcb8c9 --- /dev/null +++ b/hls4ml/utils/dependency.py @@ -0,0 +1,55 @@ +import sys +from functools import wraps +from importlib.metadata import metadata +from inspect import ismethod + +extra_requires: dict[str, list[str]] = {} +subpackage = None +for k, v in metadata('hls4ml')._headers: # type: ignore + if k != 'Requires-Dist': + continue + if '; extra == ' not in v: + continue + + req, pkg = v.split('; extra == ') + pkg = pkg.strip('"') + + extra_requires.setdefault(pkg, []).append(req) + + +def requires(pkg: str): + """Mark a function or method as requiring a package to be installed. + 'name': requires hls4ml[name] to be installed. + '_name': requires name to be installed. + + Parameters + ---------- + pkg : str + The package to require. + """ + + def deco(f): + if ismethod(f): + qualifier = f"Method {f.__self__.__class__.__name__}.{f.__name__}" + else: + qualifier = f"Function {f.__name__}" + + if not pkg.startswith("_"): + reqs = ", ".join(extra_requires[pkg]) + msg = f"{qualifier} requires {reqs}, but package {{ename}} is missing" + "Please consider install it with `pip install hls4ml[{pkg}]` for full functionality with {pkg}." + else: + msg = f"{qualifier} requires {pkg[1:]}, but package {{ename}} is missing." + "Consider install it with `pip install {pkg}`." + + @wraps(f) + def inner(*args, **kwargs): + try: + return f(*args, **kwargs) + except ImportError as e: + print(msg.format(ename=e.name), file=sys.stderr) + raise e + + return inner + + return deco diff --git a/hls4ml/utils/einsum_utils.py b/hls4ml/utils/einsum_utils.py new file mode 100644 index 0000000000..43ceb2ba96 --- /dev/null +++ b/hls4ml/utils/einsum_utils.py @@ -0,0 +1,256 @@ +from math import prod +from typing import TypedDict + +import numpy as np + + +class EinsumRecipe(TypedDict): + direct_sum_axis: tuple[tuple[int, ...], tuple[int, ...]] + in_transpose_idxs: tuple[tuple[int, ...], tuple[int, ...]] + L0: int + L1: int + I: int + C: int + out_interpert_shape: tuple[int, ...] + out_transpose_idxs: tuple[int, ...] + + +def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]): + """Validate, resolve broadcasting, and compute output shape for einsum string + + Parameters + ---------- + fn : str + einsum string, e.g. 'ij,jk->ik' + shape0 : tuple[int,...] + shape of input0 + shape1 : tuple[int,...] + shape of input1 + + Returns + ------- + tuple[str, tuple[int,...]] + einsum string w/o broadcasting, and output shape + + Raises + ------ + ValueError + If the einsum string is invalid, or if it is incompatible with the input shapes + """ + inp, out = map(str.strip, fn.split('->')) + in0, in1 = map(str.strip, inp.split(',')) + alphabets = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' + s_alphabets = set(alphabets) + + # Invalid characters + if not (s_alphabets >= set(in0.replace('...', '') + in1.replace('...', '') + out.replace('...', ''))): + raise ValueError(f"einsum string {fn} is invalid: subscripts should be in [a-zA-Z] and '...' only") + + in0 = in0.replace('...', '0') + in1 = in1.replace('...', '0') + out = out.replace('...', '0') + ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out) + sax_in0, sax_in1, sax_out = set(ax_in0), set(ax_in1), set(ax_out) + free_indices = ''.join(sorted(s_alphabets - sax_in0 - sax_in1 - sax_out)) + + # Repeated indices + if len(sax_in0) != len(ax_in0): + for a in in0: + if in0.count(a) == 1: + continue + a = a if a != '0' else '...' 
+ raise ValueError(f"einsum string {fn} is invalid: input0 subscripts includes '{a}' multiple times") + if len(sax_in1) != len(ax_in1): + for a in in1: + if in1.count(a) == 1: + continue + a = a if a != '0' else '...' + raise ValueError(f"einsum string {fn} is invalid: input1 subscripts includes '{a}' multiple times") + if len(sax_out) != len(ax_out): + for a in out: + if out.count(a) == 1: + continue + a = a if a != '0' else '...' + raise ValueError(f"einsum string {fn} is invalid: output subscripts includes '{a}' multiple times") + + # Invalid broadcasting + if '0' in sax_in0 or '0' in sax_in1 or '0' in sax_out: + if '0' in sax_in0 and '0' in sax_in1: + raise ValueError(f"einsum string {fn} is invalid: both input0 and input1 allows broadcasting") + if '0' not in sax_out: + raise ValueError(f"einsum string {fn} is invalid: output does not allow broadcasting, but inputs do") + if '0' not in sax_in0 and '0' not in sax_in1: + raise ValueError(f"einsum string {fn} is invalid: output allows broadcasting, but inputs do not") + + # Output index out of nowhere + if remaining := sax_out - sax_in0 - sax_in1: + raise ValueError(f"einsum string {fn} is invalid: output subscripts {remaining} not found in inputs") + + _common_in = sax_in0 & sax_in1 + + # Invalid input dimensions + if '0' in sax_in0: + if len(sax_in0) - 1 > len(shape0): + raise ValueError(f"Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given") + # Replace broadcasting indices with free indices + n_broadcast = len(shape0) - len(sax_in0) + 1 + in0 = in0.replace('0', free_indices[:n_broadcast]) + out = out.replace('0', free_indices[:n_broadcast]) + ax_in0 = list(in0) + ax_out = list(out) + else: + if len(sax_in0) != len(shape0): + raise ValueError(f"Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given") + if '0' in sax_in1: + if len(sax_in1) - 1 > len(shape1): + raise ValueError(f"Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given") + # Replace broadcasting indices with free indices + n_broadcast = len(shape1) - len(sax_in1) + 1 + in1 = in1.replace('0', free_indices[:n_broadcast]) + out = out.replace('0', free_indices[:n_broadcast]) + ax_in1 = list(in1) + ax_out = list(out) + else: + if len(sax_in1) != len(shape1): + raise ValueError(f"Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given") + + # Input dimension mismatch + for a in _common_in: + ax_0 = ax_in0.index(a) + ax_1 = ax_in1.index(a) + if shape0[ax_0] != shape1[ax_1]: + raise ValueError( + f"Input dimension size mismatches for common subscript '{a}': {shape0[ax_0]} and {shape1[ax_1]}" + ) + + out_shape = tuple(shape0[ax_in0.index(a)] if a in ax_in0 else shape1[ax_in1.index(a)] for a in ax_out) + return f'{in0},{in1}->{out}', out_shape + + +def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe: + """Parse einsum operation on two input arrays, return a recipe for execution + + Parameters + ---------- + fn : str + einsum string, e.g. 
'ij,jk->ik' + input : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + EinsumRecipe + einsum recipe; executed by _exec_einsum + """ + + fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1) + + _in, _out = fn.split('->') + _in0, _in1 = _in.split(',') + + in0, in1, out = list(_in0), list(_in1), list(_out) + s_in0, s_in1, s_out = set(in0), set(in1), set(out) + _common = s_in0 & s_in1 + _contract = _common - s_out + _inplace = _common & s_out + contract = sorted(_contract, key=lambda x: in1.index(x)) + inplace = sorted(_inplace, key=lambda x: in1.index(x)) + invariant0 = sorted((s_out - _common) & s_in0, key=lambda x: in0.index(x)) + invariant1 = sorted((s_out - _common) & s_in1, key=lambda x: in1.index(x)) + direct_sum0 = s_in0 - s_out - _common + direct_sum1 = s_in1 - s_out - _common + direct_sum_axis = ( + tuple(sorted(in0.index(x) for x in direct_sum0)), + tuple(sorted(in1.index(x) for x in direct_sum1)), + ) + + contract_idxs = tuple(map(in0.index, contract)), tuple(map(in1.index, contract)) + inplace_idxs = tuple(map(in0.index, inplace)), tuple(map(in1.index, inplace)) + invariant_idxs = tuple(map(in0.index, invariant0)), tuple(map(in1.index, invariant1)) + + inplace_shape = tuple(input_shape0[i] for i in inplace_idxs[0]) + inplace_size = prod(inplace_shape) + contract_size = prod(input_shape0[i] for i in contract_idxs[0]) + invariant_shape0 = tuple(input_shape0[i] for i in invariant_idxs[0]) + invariant_shape1 = tuple(input_shape1[i] for i in invariant_idxs[1]) + invariant_size0, invariant_size1 = prod(invariant_shape0), prod(invariant_shape1) + + transpose_idx0 = inplace_idxs[0] + invariant_idxs[0] + contract_idxs[0] + transpose_idx1 = inplace_idxs[1] + invariant_idxs[1] + contract_idxs[1] + + out_shape_pretranspose = inplace_shape + invariant_shape0 + invariant_shape1 + _out_transpose_idx = np.argsort(tuple(map(out.index, inplace + invariant0 + invariant1))) + out_transpose_idx = tuple(int(i) for i in _out_transpose_idx) + + return EinsumRecipe( + direct_sum_axis=direct_sum_axis, + in_transpose_idxs=(transpose_idx0, transpose_idx1), + out_interpert_shape=out_shape_pretranspose, + out_transpose_idxs=out_transpose_idx, + L0=invariant_size0, + L1=invariant_size1, + I=inplace_size, + C=contract_size, + ) + + +def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: + """Execute einsum operation on two input arrays + + Parameters + ---------- + recipe : EinsumRecipe + einsum recipe + input0 : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + np.ndarray + output array + """ + sum_axis0, sum_axis1 = recipe['direct_sum_axis'] + if sum_axis0: + input0 = np.sum(input0, axis=sum_axis0) + if sum_axis1: + input1 = np.sum(input1, axis=sum_axis1) + input0 = input0.transpose(recipe['in_transpose_idxs'][0]).ravel() + input1 = input1.transpose(recipe['in_transpose_idxs'][1]).ravel() + output = np.zeros(recipe['L0'] * recipe['L1'] * recipe['I'], dtype=input0.dtype) + + L0, L1, I, C = recipe['L0'], recipe['L1'], recipe['I'], recipe['C'] + + for l0 in range(L0): + for i in range(I): + A = input1[i * L1 * C : (i + 1) * L1 * C].reshape((L1, C)) + B = input0[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C] + output[(i * L0 + l0) * L1 : (i * L0 + l0 + 1) * L1] = A @ B + + return output.reshape(recipe['out_interpert_shape']).transpose(recipe['out_transpose_idxs']) + + +def einsum(fn: str, input0: np.ndarray, input1: np.ndarray) 
-> np.ndarray: + """Execute einsum operation on two input arrays. + + WARNING: Order of multiplication is reversed -- watchout if you are using non-commutative operators + + Parameters + ---------- + fn : str + einsum string, e.g. 'ij,jk->ik' + input : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + np.ndarray + output array + """ + recipe = parse_einsum(fn, input0.shape, input1.shape) + return _exec_einsum(recipe, input0, input1) diff --git a/hls4ml/utils/qinterval.py b/hls4ml/utils/qinterval.py new file mode 100644 index 0000000000..54d47e7f23 --- /dev/null +++ b/hls4ml/utils/qinterval.py @@ -0,0 +1,335 @@ +from functools import singledispatchmethod +from typing import Any, Sequence, overload + +import numpy as np + +from hls4ml.utils.einsum_utils import EinsumRecipe, parse_einsum + + +def _minimal_f(array: np.ndarray): + _low, _high = np.full(array.shape, -32, dtype=np.int8), np.full(array.shape, 32, dtype=np.int8) + while np.any(_low < _high - 1): + _mid = (_low + _high) // 2 + scaled = array * 2.0**_mid + mask = scaled != scaled.astype(np.int64) + _low = np.where(mask, _mid, _low) + _high = np.where(mask, _high, _mid) + return _high + + +def minimal_kif(array: np.ndarray): + """Given a constant array, determine the minimal k, i, f values that can contain it with no loss of precision. + + Parameters + ---------- + array : np.ndarray + The constant array to be represented. + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + The minimal k, i, f values that can contain the array with no loss of precision. + """ + f = _minimal_f(array) + with np.errstate(divide='ignore', invalid='ignore'): + i = np.ceil(np.log2(np.maximum(array + 2.0**-f, -array))).astype(np.int8) + k = array < 0 + null_mask = array == 0 + i, f = np.where(null_mask, 0, i), np.where(null_mask, 0, f) + return k, i, f + + +class _QIntervalArray: + def __init__(self, min: np.ndarray, max: np.ndarray, delta: np.ndarray): + self.min = min.astype(np.float64) + self.max = max.astype(np.float64) + self.delta = delta.astype(np.float64) + self._validate() + + def _validate(self): + with np.errstate(divide='ignore', invalid='ignore'): + assert np.all(self.min <= self.max), "min must be less than or equal to max" + assert np.all( + (self.max % self.delta == 0) | ((self.max == 0) & (self.delta == 0)) + ), "max must be a multiple of delta" + assert np.all( + (self.min % self.delta == 0) | ((self.min == 0) & (self.delta == 0)) + ), "min must be a multiple of delta" + + +class QIntervalArray(_QIntervalArray): + """Symbolic array for quantized interval arithmetic. + + Available operations are: + - Addition + - Subtraction + - Multiplication + - Division (not recommended) + - Matrix multiplication + + Parameters + ---------- + min : np.ndarray + The minimum value of the interval. + max : np.ndarray + The maximum value of the interval. + delta : np.ndarray + The quantization step of the interval. 
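An illustration that is not part of the patch: what the `minimal_kif` helper a few lines above returns. For every element it gives the sign flag `k`, integer bits `i` and fractional bits `f` of the narrowest fixed-point format that holds the value exactly; the inputs and the expected output below are worked out by hand from that definition.

```python
import numpy as np

from hls4ml.utils.qinterval import minimal_kif

arr = np.array([0.75, -1.5, 0.0])
k, i, f = minimal_kif(arr)

# 0.75 -> unsigned, 0 integer bits, 2 fractional bits (binary 0.11)
# -1.5 -> signed, 1 integer bit, 1 fractional bit
# 0.0  -> exact zeros collapse to (0, 0, 0)
print(k.astype(int), i, f)  # [0 1 0] [0 1 0] [2 1 0]
```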
+ """ + + @singledispatchmethod + def __add__(self, other): + _min = self.min + other + _max = self.max + other + _delta = np.minimum(self.delta, 2.0 ** -_minimal_f(other)) + return QIntervalArray(_min, _max, _delta) + + @__add__.register + def _(self, other: _QIntervalArray): + _min = self.min + other.min + _max = self.max + other.max + _delta = np.minimum(self.delta, other.delta) + return QIntervalArray(_min, _max, _delta) + + def __sub__(self, other): + return self + (-other) + + @singledispatchmethod + def __mul__(self, other): + other = np.float64(other) + v1 = self.min * other + v2 = self.max * other + _min = np.minimum(v1, v2) + _max = np.maximum(v1, v2) + _delta = self.delta * other + return QIntervalArray(_min, _max, _delta) + + @__mul__.register + def _(self, other: _QIntervalArray): + v1 = self.min * other.min + v2 = self.min * other.max + v3 = self.max * other.min + v4 = self.max * other.max + _min = np.minimum(np.minimum(v1, v2), np.minimum(v3, v4)) + _max = np.maximum(np.maximum(v1, v2), np.maximum(v3, v4)) + _delta = self.delta * other.delta + return QIntervalArray(_min, _max, _delta) + + def __truediv__(self, other): + return self * (1 / other) + + def __neg__(self): + return QIntervalArray(-self.max, -self.min, self.delta) + + @singledispatchmethod + def __matmul__(self, other: np.ndarray): + seq = ''.join(chr(ord('a') + i) for i in range(self.min.ndim)) + eq = f'{seq},{seq[-1]}...->{seq}...' + ax = self.min.ndim - 1 + v1 = np.einsum(eq, self.min, other, optimize=True) + v2 = np.einsum(eq, self.max, other, optimize=True) + other_delta = 2.0 ** -_minimal_f(other) + _delta = np.einsum(eq, self.delta, other_delta, optimize=True) + delta = np.min(np.where(_delta == 0, np.inf, _delta), axis=ax) + _min = np.sum(np.minimum(v1, v2), axis=ax) + _max = np.sum(np.maximum(v1, v2), axis=ax) + return QIntervalArray(_min, _max, delta) + + @__matmul__.register + def _(self, other: _QIntervalArray): + seq = ''.join(chr(ord('a') + i) for i in range(self.min.ndim)) + eq = f'{seq},{seq[-1]}...->{seq}...' + ax = self.min.ndim - 1 + v1 = np.einsum(eq, self.min, other.min, optimize=True) + v2 = np.einsum(eq, self.max, other.max, optimize=True) + v3 = np.einsum(eq, self.min, other.max, optimize=True) + v4 = np.einsum(eq, self.max, other.min, optimize=True) + + _max = np.sum(np.maximum(np.maximum(v1, v2), np.maximum(v3, v4)), axis=ax) + _min = np.sum(np.minimum(np.minimum(v1, v2), np.minimum(v3, v4)), axis=ax) + + _delta = np.einsum(eq, self.delta, other.delta, optimize=True) + delta = np.min(_delta, axis=ax) + + return QIntervalArray(_min, _max, delta) + + def __rmatmul__(self, other: np.ndarray): + seq = ''.join(chr(ord('a') + i) for i in range(other.ndim)) + eq = f'{seq},{seq[-1]}...->{seq}...' 
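Another aside, not part of the patch: a minimal sketch of how the operator overloads above propagate ranges and quantization steps. The numbers are hand-picked; the expected output follows directly from the definitions.

```python
import numpy as np

from hls4ml.utils.qinterval import QIntervalArray

# x models a value known to lie on the grid -1.0, -0.5, ..., 1.5 (step 0.5)
x = QIntervalArray(min=np.array([-1.0]), max=np.array([1.5]), delta=np.array([0.5]))

# Bounds follow ordinary interval arithmetic; delta tracks the step of the result grid.
# The constant is passed as a numpy scalar because __add__ hands it straight to
# _minimal_f (which expects a .shape), while __mul__ coerces through np.float64.
y = x * 2 + np.float64(0.25)
print(y.min, y.max, y.delta)  # [-1.75] [3.25] [0.25]
```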
+ ax = other.ndim - 1 + v1 = np.einsum(eq, other, self.min, optimize=True) + v2 = np.einsum(eq, other, self.max, optimize=True) + other_delta = 2.0 ** -_minimal_f(other) + _delta = np.einsum(eq, other_delta, self.delta, optimize=True) + delta = np.min(np.where(_delta == 0, np.inf, _delta), axis=ax) + _min = np.sum(np.minimum(v1, v2), axis=ax) + _max = np.sum(np.maximum(v1, v2), axis=ax) + return QIntervalArray(_min, _max, delta) + + def transpose(self, axes: Sequence[int]): + return QIntervalArray(self.min.transpose(axes), self.max.transpose(axes), self.delta.transpose(axes)) + + @property + def shape(self): + return self.min.shape + + def reshape(self, shape: Sequence[int]): + return QIntervalArray(self.min.reshape(shape), self.max.reshape(shape), self.delta.reshape(shape)) + + def ravel(self): + return QIntervalArray(self.min.ravel(), self.max.ravel(), self.delta.ravel()) + + @property + def dtype(self): + return self.min.dtype + + def __getitem__(self, key): + return QIntervalArray(self.min[key], self.max[key], self.delta[key]) + + def __array_function__(self, func, types, args, kwargs): + if func == np.concatenate: + return QIntervalArray( + np.concatenate([a.min for a in args[0]]), + np.concatenate([a.max for a in args[0]]), + np.concatenate([a.delta for a in args[0]]), + ) + return NotImplemented + + def rmatmul(self, other: np.ndarray): + """Right matrix multiplication (other @ self), with __rmatmul__ implemented in QIntervalArray. + This is to avoid using the @ operator defined in np.ndarray. + + Parameters + ---------- + other : np.ndarray + The operand matrix multiplied from the left. + + Returns + ------- + QIntervalArray + The result + """ + return self.__rmatmul__(other) + + @classmethod + def from_kif(cls, k: np.ndarray | int | bool, i: np.ndarray | int, f: np.ndarray | int): + """Create a QIntervalArray from k, i, f values. + + Parameters + ---------- + k : np.ndarray | int | bool + keep_negative + i : np.ndarray | int + integer_bits, excluding sign bit + f : np.ndarray | int + fractional_bits + + Returns + ------- + QIntervalArray + The created QIntervalArray. 
+ """ + + _min = np.asarray(-(2.0**i) * k) + _max = np.asarray(2.0**i - 2.0**-f) + _delta = np.asarray(2.0**-f) + return cls(_min, _max, _delta) + + def sample(self, n: int | None = None): + if n is not None: + rand = np.random.rand(n, *self.min.shape) + else: + rand = np.random.rand(*self.min.shape) + v = rand * (self.max - self.min) + self.min + v = np.round(v / self.delta) * self.delta + return v + + def to_kif(self): + f = -np.log2(self.delta).astype(np.int8) + + with np.errstate(divide='ignore', invalid='ignore'): + i = np.ceil(np.log2(np.maximum(self.max + 2.0**-f, -self.min))).astype(np.int8) + k = self.min < 0 + null_mask = (self.max == 0) & (self.min == 0) + i, f = np.where(null_mask, 0, i), np.where(null_mask, 0, f) + return k, i, f + + +def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray | QIntervalArray, input1: np.ndarray | QIntervalArray, operator): + """Execute einsum operation on two input arrays + + Parameters + ---------- + recipe : EinsumRecipe + einsum recipe + input0 : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + np.ndarray + output array + """ + input0 = input0.transpose(recipe['in_transpose_idxs'][0]).ravel() + input1 = input1.transpose(recipe['in_transpose_idxs'][1]).ravel() + # output = np.zeros(recipe['L0'] * recipe['L1'] * recipe['I'], dtype=input0.dtype) + output = [] + + L0, L1, I, C = recipe['L0'], recipe['L1'], recipe['I'], recipe['C'] + + for i in range(I): + for l0 in range(L0): + A = input1[i * L1 * C : (i + 1) * L1 * C].reshape((L1, C)) + B = input0[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C] + output.append(operator(A, B)) + output = np.concatenate(output, axis=0) + + return output.reshape(recipe['out_interpert_shape']).transpose(recipe['out_transpose_idxs']) + + +@overload +def einsum(fn: str, input0: QIntervalArray, input1: QIntervalArray, operator=None) -> QIntervalArray: ... + + +@overload +def einsum(fn: str, input0: np.ndarray, input1: QIntervalArray, operator=None) -> QIntervalArray: ... + + +@overload +def einsum(fn: str, input0: QIntervalArray, input1: np.ndarray, operator=None) -> QIntervalArray: ... + + +@overload +def einsum(fn: str, input0: np.ndarray, input1: np.ndarray, operator=None) -> np.ndarray: ... + + +def einsum(fn: str, input0: np.ndarray | QIntervalArray, input1: np.ndarray | QIntervalArray) -> Any: # type: ignore + """Execute einsum operation on two input arrays + + WARNING: Order of multiplication is reversed -- watchout if you are using non-commutative operators + + Parameters + ---------- + fn : str + einsum string, e.g. 
'ij,jk->ik' + input : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + np.ndarray + output array + """ + + def operator(A, B): + if isinstance(A, np.ndarray): + return B.__rmatmul__(A) + else: + return A @ B + + recipe = parse_einsum(fn, input0.shape, input1.shape) + return _exec_einsum(recipe, input0, input1, operator) diff --git a/hls4ml/writer/catapult_writer.py b/hls4ml/writer/catapult_writer.py index 7db1063206..9a48460995 100755 --- a/hls4ml/writer/catapult_writer.py +++ b/hls4ml/writer/catapult_writer.py @@ -889,7 +889,9 @@ def keras_model_representer(dumper, keras_model): return dumper.represent_scalar('!keras_model', model_path) try: - from tensorflow.keras import Model as KerasModel + import keras + + KerasModel = keras.models.Model yaml.add_multi_representer(KerasModel, keras_model_representer) except Exception: diff --git a/hls4ml/writer/oneapi_writer.py b/hls4ml/writer/oneapi_writer.py index fe633214f6..c9af2544bd 100644 --- a/hls4ml/writer/oneapi_writer.py +++ b/hls4ml/writer/oneapi_writer.py @@ -102,9 +102,10 @@ def write_project_cpp(self, model): project_name = model.config.get_project_name() filedir = os.path.dirname(os.path.abspath(__file__)) - with open(os.path.join(filedir, '../templates/oneapi/firmware/myproject.cpp')) as f, open( - f'{model.config.get_output_dir()}/src/firmware/{project_name}.cpp', 'w' - ) as fout: + with ( + open(os.path.join(filedir, '../templates/oneapi/firmware/myproject.cpp')) as f, + open(f'{model.config.get_output_dir()}/src/firmware/{project_name}.cpp', 'w') as fout, + ): model_inputs = model.get_input_variables() model_outputs = model.get_output_variables() model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] @@ -207,9 +208,10 @@ def write_project_header(self, model): project_name = model.config.get_project_name() filedir = os.path.dirname(os.path.abspath(__file__)) - with open(os.path.join(filedir, '../templates/oneapi/firmware/myproject.h')) as f, open( - f'{model.config.get_output_dir()}/src/firmware/{project_name}.h', 'w' - ) as fout: + with ( + open(os.path.join(filedir, '../templates/oneapi/firmware/myproject.h')) as f, + open(f'{model.config.get_output_dir()}/src/firmware/{project_name}.h', 'w') as fout, + ): model_inputs = model.get_input_variables() model_outputs = model.get_output_variables() # model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] @@ -254,9 +256,10 @@ def write_defines(self, model): model (ModelGraph): the hls4ml model. """ filedir = os.path.dirname(os.path.abspath(__file__)) - with open(os.path.join(filedir, '../templates/oneapi/firmware/defines.h')) as f, open( - f'{model.config.get_output_dir()}/src/firmware/defines.h', 'w' - ) as fout: + with ( + open(os.path.join(filedir, '../templates/oneapi/firmware/defines.h')) as f, + open(f'{model.config.get_output_dir()}/src/firmware/defines.h', 'w') as fout, + ): for line in f.readlines(): # Insert numbers if '// hls-fpga-machine-learning insert numbers' in line: @@ -298,9 +301,10 @@ def write_parameters(self, model): model (ModelGraph): the hls4ml model. 
""" filedir = os.path.dirname(os.path.abspath(__file__)) - with open(os.path.join(filedir, '../templates/oneapi/firmware/parameters.h')) as f, open( - f'{model.config.get_output_dir()}/src/firmware/parameters.h', 'w' - ) as fout: + with ( + open(os.path.join(filedir, '../templates/oneapi/firmware/parameters.h')) as f, + open(f'{model.config.get_output_dir()}/src/firmware/parameters.h', 'w') as fout, + ): for line in f.readlines(): if '// hls-fpga-machine-learning insert includes' in line: newline = line @@ -376,9 +380,10 @@ def write_test_bench(self, model): output_predictions, f'{model.config.get_output_dir()}/tb_data/tb_output_predictions.dat' ) - with open(os.path.join(filedir, '../templates/oneapi/myproject_test.cpp')) as f, open( - f'{model.config.get_output_dir()}/src/{project_name}_test.cpp', 'w' - ) as fout: + with ( + open(os.path.join(filedir, '../templates/oneapi/myproject_test.cpp')) as f, + open(f'{model.config.get_output_dir()}/src/{project_name}_test.cpp', 'w') as fout, + ): for line in f.readlines(): indent = ' ' * (len(line) - len(line.lstrip(' '))) @@ -434,9 +439,10 @@ def write_bridge(self, model): indent = ' ' filedir = os.path.dirname(os.path.abspath(__file__)) - with open(os.path.join(filedir, '../templates/oneapi/myproject_bridge.cpp')) as f, open( - f'{model.config.get_output_dir()}/src/{project_name}_bridge.cpp', 'w' - ) as fout: + with ( + open(os.path.join(filedir, '../templates/oneapi/myproject_bridge.cpp')) as f, + open(f'{model.config.get_output_dir()}/src/{project_name}_bridge.cpp', 'w') as fout, + ): for line in f.readlines(): if 'MYPROJECT' in line: newline = line.replace('MYPROJECT', format(project_name.upper())) @@ -511,9 +517,10 @@ def write_build_script(self, model): # Makefile filedir = os.path.dirname(os.path.abspath(__file__)) device = model.config.get_config_value('Part') - with open(os.path.join(filedir, '../templates/oneapi/CMakeLists.txt')) as f, open( - f'{model.config.get_output_dir()}/CMakeLists.txt', 'w' - ) as fout: + with ( + open(os.path.join(filedir, '../templates/oneapi/CMakeLists.txt')) as f, + open(f'{model.config.get_output_dir()}/CMakeLists.txt', 'w') as fout, + ): for line in f.readlines(): line = line.replace('myproject', model.config.get_project_name()) line = line.replace('mystamp', model.config.get_config_value('Stamp')) diff --git a/hls4ml/writer/quartus_writer.py b/hls4ml/writer/quartus_writer.py index 932a8b6a6d..1d61bde1f4 100644 --- a/hls4ml/writer/quartus_writer.py +++ b/hls4ml/writer/quartus_writer.py @@ -1327,7 +1327,9 @@ def keras_model_representer(dumper, keras_model): return dumper.represent_scalar('!keras_model', model_path) try: - from tensorflow.keras import Model as KerasModel + import keras + + KerasModel = keras.models.Model yaml.add_multi_representer(KerasModel, keras_model_representer) except Exception: diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 0341959045..6531f9db87 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -817,7 +817,9 @@ def keras_model_representer(dumper, keras_model): return dumper.represent_scalar('!keras_model', model_path) try: - from tensorflow.keras import Model as KerasModel + import keras + + KerasModel = keras.models.Model yaml.add_multi_representer(KerasModel, keras_model_representer) except Exception: diff --git a/pyproject.toml b/pyproject.toml index 6402ab0e7a..24175c9612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,100 @@ [build-system] -# AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY 
PYSCAFFOLD! -requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5", "wheel"] build-backend = "setuptools.build_meta" +requires = [ "setuptools>=61", "setuptools-scm>=8" ] + +[project] +name = "hls4ml" +version = "1.0.0" +description = "Machine learning in FPGAs using HLS" +readme = "README.md" +license = { text = "Apache-2.0" } +authors = [ { name = "hls4ml Team" } ] +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: C++", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dependencies = [ "h5py", "numpy", "pydigitalwavetools==1.1", "pyyaml" ] + +optional-dependencies.doc = [ + "sphinx", + "sphinx-contributors", + "sphinx-github-changelog", + "sphinx-rtd-theme", +] +optional-dependencies.HGQ = [ "hgq~=0.2.0" ] +optional-dependencies.onnx = [ "onnx>=1.4" ] +optional-dependencies.optimization = [ + "keras-tuner==1.1.3", + "ortools==9.4.1874", + "packaging", +] +optional-dependencies.profiling = [ "matplotlib", "pandas", "seaborn" ] +optional-dependencies.qkeras = [ + "qkeras", + "tensorflow>=2.8,<=2.14.1", + "tensorflow-model-optimization<=0.7.5", +] +optional-dependencies.quantus_report = [ "calmjs-parse", "tabulate" ] +optional-dependencies.sr = [ "sympy" ] +optional-dependencies.testing = [ + "calmjs-parse", + "hgq~=0.2.0", + "onnx>=1.4", + "pytest", + "pytest-cov", + "pytest-randomly", + "qonnx", + "tabulate", + "torch", +] +urls.Homepage = "https://fastmachinelearning.org/hls4ml" +scripts.hls4ml = "hls4ml.cli:main" +entry-points.pytest_randomly.random_seeder = "hls4ml:reseed" + +[tool.setuptools] +packages = [ "hls4ml" ] +include-package-data = true + [tool.setuptools_scm] -# See configuration details in https://github.com/pypa/setuptools_scm + version_scheme = "release-branch-semver" -git_describe_command = "git describe --dirty --tags --long --match v* --first-parent" +git_describe_command = [ + "git", + "describe", + "--dirty", + "--tags", + "--long", + "--match", + "v*", + "--first-parent", +] write_to = "hls4ml/_version.py" + +[tool.black] +line-length = 125 +skip-string-normalization = true + +[tool.isort] +profile = "black" +line_length = 125 + +[tool.check-manifest] +ignore = [ + ".github/**", + "docs/**", + ".pre-commit-config.yaml", + "Jenkinsfile", + "hls4ml/_version.py", +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 0b81e7b592..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,70 +0,0 @@ -[metadata] -name = hls4ml -description = Machine learning in FPGAs using HLS -long_description = file: README.md -long_description_content_type = text/markdown -url = https://fastmachinelearning.org/hls4ml -author = hls4ml Team -license = Apache-2.0 -license_files = LICENSE -classifiers = - Development Status :: 4 - Beta - Intended Audience :: Developers - Intended Audience :: Science/Research - License :: OSI Approved :: Apache Software License - Programming Language :: C++ - Programming Language :: Python :: 3 - Programming Language :: Python :: 3 :: Only - Topic :: Software Development :: Libraries - Topic :: Software Development :: Libraries :: Python Modules -description_file = 
README.md - -[options] -packages = find: -install_requires = - calmjs.parse - h5py - numpy - onnx>=1.4.0 - pydigitalwavetools==1.1 - pyparsing - pyyaml - tabulate - tensorflow>=2.8.0,<=2.14.1 - tensorflow-model-optimization<=0.7.5 -python_requires = >=3.10, <3.12 -include_package_data = True -scripts = scripts/hls4ml - -[options.entry_points] -pytest_randomly.random_seeder = - hls4ml = hls4ml:reseed - -[options.extras_require] -HGQ = - HGQ~=0.2.0 -optimization = - keras-tuner==1.1.3 - ortools==9.4.1874 - packaging -profiling = - matplotlib - pandas - seaborn -sr = - sympy -testing = - HGQ~=0.2.0 - pytest - pytest-cov - pytest-randomly - qonnx - torch - -[check-manifest] -ignore = - .github/** - docs/** - .pre-commit-config.yaml - Jenkinsfile - hls4ml/_version.py diff --git a/setup.py b/setup.py deleted file mode 100644 index 1abbd068c1..0000000000 --- a/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -import setuptools - -if __name__ == "__main__": - setuptools.setup() diff --git a/test/pytest/test_einsum_dense.py b/test/pytest/test_einsum_dense.py new file mode 100644 index 0000000000..f36a319ffb --- /dev/null +++ b/test/pytest/test_einsum_dense.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import keras +import numpy as np +import pytest + +from hls4ml.converters import convert_from_keras_model + +if keras.__version__ < '3.0.0': + pytest.skip('Only keras v3 is supported for now', allow_module_level=True) + +from keras.api.layers import EinsumDense, Input + +test_root_path = Path(__file__).parent + + +@pytest.mark.parametrize('strategy', ['latency']) +@pytest.mark.parametrize('io_type', ['io_parallel']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize( + 'operation', + [ + # eq, inp, out + ('bi,j->bij', (8,), (8, 7), None), + ('bi,j->bij', (8,), (8, 7), 'i'), + ('bi,j->bij', (8,), (8, 7), 'j'), + ('bi,io->bo', (8,), 7, None), + ('...i,oi->...o', (4, 3), (5,), None), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), None), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'aeb'), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'ab'), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'a'), + ], +) +def test_einsum_dense(backend, io_type, strategy, operation): + eq, inp_shape, out_shape, bias_axes = operation + model = keras.Sequential( + [Input(inp_shape), EinsumDense(eq, output_shape=out_shape, bias_axes=bias_axes, name='einsum_dense')] + ) + + if bias_axes is not None: + layer = model.get_layer('einsum_dense') + layer.bias.assign(keras.ops.convert_to_tensor(np.random.rand(*layer.bias.shape))) + + data = np.random.rand(1000, *inp_shape) + eq_name = eq.replace(',', '_').replace('->', '_') + ('' if bias_axes is None else f'_{bias_axes}') + output_dir = str(test_root_path / f'hls4mlprj_einsum_dense_{eq_name}_{backend}_{io_type}_{strategy}') + hls_config = {'Model': {'Precision': 'ap_fixed<32,8>', 'ReuseFactor': 1}, 'Strategy': strategy} + model_hls = convert_from_keras_model( + model, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type + ) + + model_hls.compile() + r_keras = model.predict(data, verbose=0, batch_size=1000) # type: ignore + r_hls = model_hls.predict(data).reshape(r_keras.shape) # type: ignore + + np.testing.assert_allclose(r_hls, r_keras, atol=2e-6, rtol=0) diff --git a/test/pytest/test_keras_v3_api.py b/test/pytest/test_keras_v3_api.py new file mode 100644 index 0000000000..81ac5c240c --- /dev/null +++ b/test/pytest/test_keras_v3_api.py @@ -0,0 +1,516 @@ +import math +from pathlib import Path + +import keras 
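A side note on the new EinsumDense test above (not part of the patch): the simplest equation it covers, 'bi,io->bo', is just an unbiased Dense layer, which is a handy mental model for the converter. A short sketch, assuming the Keras 3 API used in the test; with `bias_axes=None` the only weight returned by `get_weights()` is the kernel.

```python
import keras
import numpy as np
from keras.api.layers import EinsumDense, Input

# 'bi,io->bo': b is the batch axis, i the input features, o the output features,
# i.e. the same contraction a Dense(7, use_bias=False) layer performs.
model = keras.Sequential([Input((8,)), EinsumDense('bi,io->bo', output_shape=7, name='ed')])
layer = model.get_layer('ed')

x = np.random.rand(16, 8).astype(np.float32)
y = model.predict(x, verbose=0)

w = layer.get_weights()[0]  # kernel of shape (8, 7); no bias since bias_axes is None
np.testing.assert_allclose(y, x @ w, rtol=1e-5)
```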
+import numpy as np +import pytest + +if keras.__version__ < '3.0': + pytest.skip('Keras API tests are only for Keras 3.0 and above', allow_module_level=True) + +from keras.api.layers import ( + ELU, + Activation, + AveragePooling1D, + AveragePooling2D, + Conv1D, + Conv2D, + Dense, + DepthwiseConv1D, + DepthwiseConv2D, + LeakyReLU, + MaxPooling1D, + MaxPooling2D, + PReLU, +) + +import hls4ml + +test_root_path = Path('/tmp/tests') + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_dense(backend, io_type): + model = keras.Sequential( + [ + Dense( + 2, + input_shape=(1,), + name='Dense', + use_bias=True, + kernel_initializer=keras.initializers.RandomUniform(minval=1, maxval=10), # type: ignore + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + ), + Activation(activation='elu', name='Activation'), + ] + ) + model.compile(optimizer='adam', loss='mse') + + X_input = np.random.rand(1000, 1) + + keras_prediction = model.predict(X_input, verbose=0) # type: ignore + + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_dense_{backend}_{io_type}') + + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + + hls_model.compile() + + hls_prediction = hls_model.predict(X_input) + + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.02) + + assert len(model.layers) + 1 == len(hls_model.get_layers()) + assert list(hls_model.get_layers())[0].attributes['class_name'] == "InputLayer" + assert list(hls_model.get_layers())[1].attributes["class_name"] == model.layers[0].name + assert list(hls_model.get_layers())[2].attributes['class_name'] == 'ELU' + + +# TODO: add ThresholdedReLU test when it can be made to pass +# https://github.com/fastmachinelearning/hls4ml/issues/376 + + +@pytest.mark.parametrize( + "activation_function", + [ + Activation(activation='relu', name='relu'), + LeakyReLU(negative_slope=0.5), + ELU(alpha=1.0), + PReLU( + alpha_initializer="zeros", + ), + Activation(activation='sigmoid', name='sigmoid'), + ], +) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_activations(activation_function, backend, io_type): + model = keras.models.Sequential() + model.add(Dense(64, input_shape=(1,), name='Dense', kernel_initializer='lecun_uniform', kernel_regularizer=None)) + model.add(activation_function) + + model.compile(optimizer='adam', loss='mse') + + model.summary() + + X_input = np.random.rand(1000, 1) + keras_prediction = model.predict(X_input, verbose=0) # type: ignore + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_activations_{activation_function.name}_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + hls_prediction = hls_model.predict(X_input) + + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.02) + + for layer in hls_model.get_layers(): + print(layer.attributes.attributes['class_name']) + assert len(model.layers) + 1 == len(hls_model.get_layers()) + + assert 
list(hls_model.get_layers())[2].attributes['class_name'] == activation_function.__class__.__name__ + + +padds_options = ['same', 'valid'] + + +@pytest.mark.parametrize('padds', padds_options) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_conv1d(padds, backend, io_type): + model = keras.models.Sequential() + input_shape = (10, 128, 4) + model.add( + Conv1D( + filters=32, + kernel_size=3, + strides=2, + padding=padds, + activation='relu', + input_shape=input_shape[1:], + kernel_initializer='normal', + use_bias=False, + data_format='channels_last', + name='conv', + ) + ) + model.add(Activation(activation='relu')) + model.compile(optimizer='adam', loss='mse') + + X_input = np.random.rand(10, 128, 4) + keras_prediction = model.predict(X_input, verbose=0) # type: ignore + + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_conv1d_{padds}_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape) # type: ignore + + # 5e-2 might be too high + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=5e-2) + + if backend in ('Vivado', 'Vitis', 'Catapult') and io_type == 'io_stream' and padds == 'same': + # Vivado/Vitis inserts and additional layer for 'same' padding in io_stream + return + + conv: keras.layers.Conv1D = model.layers[0] + ker_w, ch_in, ch_out = conv.kernel.shape + inp_shape = model.inputs[0].shape[1:] + out_shape = model.outputs[0].shape[1:] + hls_attr = hls_model.graph['conv'].attributes + _stride = conv.strides[0] + + assert len(model.layers) + 2 == len(hls_model.get_layers()) + + assert hls_attr['name'] == model.layers[0].name + assert hls_attr['class_name'] == 'Conv1D' + assert hls_attr["in_width"] == inp_shape[0] + assert hls_attr['filt_width'] == ker_w + assert hls_attr['n_chan'] == ch_in + assert hls_attr['n_filt'] == ch_out + assert hls_attr['stride_width'] == _stride + assert hls_attr['data_format'] == conv.data_format + assert hls_attr["out_width"] == out_shape[0] + + w_pad = math.ceil(inp_shape[0] / ker_w) * ker_w - inp_shape[0] + + pad_left = w_pad // 2 + pad_right = w_pad - pad_left + + if model.layers[0].padding == 'same': + assert hls_attr['pad_left'] == pad_left + assert hls_attr['pad_right'] == pad_right + elif model.layers[0].padding == 'valid': + assert hls_attr['pad_left'] == 0 + assert hls_attr['pad_right'] == 0 + + +chans_options = ['channels_last'] +padds_options = ['same', 'valid'] + + +@pytest.mark.parametrize('chans', chans_options) +@pytest.mark.parametrize('padds', padds_options) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_conv2d(chans, padds, backend, io_type): + input_shape = (32, 32, 3) + model = keras.Sequential( + [ + keras.layers.InputLayer(input_shape), + Conv2D( + filters=32, + kernel_size=(2, 3), + strides=(4, 5), + padding=padds, + kernel_initializer='normal', + use_bias=False, + data_format=chans, + name='conv', + ), + ] + ) + model.compile(optimizer='adam', loss='mse') + + X_input = np.random.rand(1000, *input_shape) + keras_prediction = model.predict(X_input) + + config = hls4ml.utils.config_from_keras_model(model) + 
output_dir = str(test_root_path / f'hls4ml_project_keras_api_conv2d_{backend}_{chans}_{padds}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape) # type: ignore + + # A high tolerance, simply to verify correct functionality + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=5e-2) + + hls_conv_attr = hls_model.graph['conv'].attributes + + conv: keras.layers.Conv2D = model.get_layer('conv') + + kh, kw, ch_in, ch_out = conv.kernel.shape # type: ignore + _stride = conv.strides + inp_shape = model.inputs[0].shape[1:] + out_shape = model.outputs[0].shape[1:] + + if io_type == 'io_stream' and padds == 'same' and backend in ('Vivado', 'Vitis', 'Catapult'): + return + + assert len(model.layers) + 1 == len(hls_model.get_layers()) + assert hls_conv_attr['name'] == conv.name + assert hls_conv_attr['class_name'] == 'Conv2D' + assert hls_conv_attr['filt_width'] == kw + assert hls_conv_attr['filt_height'] == kh + assert hls_conv_attr['n_filt'] == ch_out + assert hls_conv_attr['stride_width'] == _stride[1] + assert hls_conv_attr['stride_height'] == _stride[0] + assert hls_conv_attr['data_format'] == conv.data_format + + if conv.data_format == 'channels_first': + assert hls_conv_attr['n_chan'] == inp_shape[0] + assert hls_conv_attr['in_height'] == inp_shape[1] + assert hls_conv_attr['in_width'] == inp_shape[2] + assert hls_conv_attr['out_height'] == out_shape[1] + assert hls_conv_attr['out_width'] == out_shape[2] + elif model.layers[0].data_format == 'channels_last': + assert hls_conv_attr['n_chan'] == inp_shape[2] + assert hls_conv_attr['in_height'] == inp_shape[0] + assert hls_conv_attr['in_width'] == inp_shape[1] + assert hls_conv_attr['out_height'] == out_shape[0] + assert hls_conv_attr['out_width'] == out_shape[1] + + if conv.padding == 'same': + if conv.data_format == 'channels_first': + h_pad = math.ceil(inp_shape[1] / kh) * kh - inp_shape[1] + w_pad = math.ceil(inp_shape[2] / kw) * kw - inp_shape[2] + elif model.layers[0].data_format == 'channels_last': + h_pad = math.ceil(inp_shape[0] / kh) * kh - inp_shape[0] + w_pad = math.ceil(inp_shape[1] / kw) * kw - inp_shape[1] + else: + raise ValueError('Invalid data_format') + pad_top = h_pad // 2 + pad_bottom = h_pad - pad_top + pad_left = w_pad // 2 + pad_right = w_pad - pad_left + assert hls_conv_attr['pad_top'] == pad_top + assert hls_conv_attr['pad_bottom'] == pad_bottom + assert hls_conv_attr['pad_left'] == pad_left + assert hls_conv_attr['pad_right'] == pad_right + elif model.layers[0].padding == 'valid': + assert hls_conv_attr['pad_top'] == 0 + assert hls_conv_attr['pad_bottom'] == 0 + assert hls_conv_attr['pad_left'] == 0 + assert hls_conv_attr['pad_right'] == 0 + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel']) +def test_depthwise2d(backend, io_type): + ''' + Test proper handling of DepthwiseConv2D + ''' + X = np.random.rand(10, 32, 32, 3) + X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> + model = keras.models.Sequential([keras.layers.Input((32, 32, 3)), DepthwiseConv2D(kernel_size=(3, 3))]) + model.compile() + + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision='fixed<32,12>', backend=backend + ) + output_dir = str(test_root_path / 
f'hls4mlprj_keras_api_depthwiseconv2d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + y_qkeras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qkeras, y_hls4ml.reshape(y_qkeras.shape), rtol=1e-2, atol=0.01) # type: ignore + + +# Currently only Vivado and Vitis is supported for io_stream. +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_stream']) +def test_depthwise1d(backend, io_type): + ''' + Test proper handling of DepthwiseConv1D. + ''' + X = np.random.rand(10, 32, 3) + X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> + model = keras.Sequential([DepthwiseConv1D(kernel_size=3, input_shape=(32, 3))]) + model.compile() + + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv1d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + y_qkeras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qkeras, y_hls4ml.reshape(y_qkeras.shape), rtol=1e-2, atol=0.01) # type: ignore + + +pooling_layers = [MaxPooling1D, MaxPooling2D, AveragePooling1D, AveragePooling2D] + + +@pytest.mark.parametrize('pooling', pooling_layers) +@pytest.mark.parametrize('padds', padds_options) +@pytest.mark.parametrize('chans', chans_options) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +def test_pooling(pooling, padds, chans, backend): + assert '1D' in pooling.__name__ or '2D' in pooling.__name__ + + input_shape = (18, 15, 3) if '2D' in pooling.__name__ else (121, 3) + pool_size = (4, 2) if '2D' in pooling.__name__ else 2 + + X_input = np.random.rand(100, *input_shape) + + keras_model = keras.Sequential([pooling(pool_size, padding=padds, input_shape=input_shape)]) + keras_model.compile() + + hls_cfg = hls4ml.utils.config_from_keras_model(keras_model) + output_dir = str( + test_root_path / f'hls4mlprj_keras_api_pooling_{pooling.__name__}_channels_{chans}_padds_{padds}_backend_{backend}' + ) + hls_model = hls4ml.converters.convert_from_keras_model( + keras_model, hls_config=hls_cfg, output_dir=output_dir, backend=backend + ) + hls_model.compile() + + # Verify accuracy + keras_prediction = keras_model.predict(X_input) + hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape) # type: ignore + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=3e-2) + + # # Verify correct parsing of layer + # hls_pool = list(hls_model.get_layers())[-1] + # ker_pool = keras_model.layers[-1] + # if '2D' in pooling.__name__: + # assert hls_pool.attributes['name'] == ker_pool._name + # assert hls_pool.attributes['class_name'][-2] == str(2) + # assert hls_pool.attributes['stride_height'] == ker_pool.strides[0] + # assert hls_pool.attributes['stride_width'] == ker_pool.strides[1] + # assert hls_pool.attributes['pool_height'] == ker_pool.pool_size[1] + # assert hls_pool.attributes['pool_width'] == ker_pool.pool_size[0] + + # if hls_pool.attributes['data_format'] == 'channels_last': + # assert hls_pool.attributes['in_height'] == ker_pool.input_shape[1] + # assert hls_pool.attributes['in_width'] == ker_pool.input_shape[2] + # assert 
hls_pool.attributes['n_filt'] == ker_pool.input_shape[3] + # elif hls_pool.attributes['data_format'] == 'channels_first': + # assert hls_pool.attributes['in_height'] == ker_pool.input_shape[2] + # assert hls_pool.attributes['in_width'] == ker_pool.input_shape[3] + # assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[1] + + # if ker_pool.padding == 'same': + # # Height + # in_height = ker_pool.input_shape[1] + # if ker_pool.data_format == 'channels_first': + # in_height = ker_pool.input_shape[2] + # out_height = int(math.ceil(float(in_height) / float(ker_pool.strides[0]))) + # assert out_height == hls_pool.attributes['out_height'] + # if in_height % ker_pool.strides[0] == 0: + # pad_along_height = max(ker_pool.pool_size[1] - ker_pool.strides[0], 0) + # else: + # pad_along_height = max(ker_pool.pool_size[1] - (in_height % ker_pool.strides[0]), 0) + # pad_top = pad_along_height // 2 + # pad_bottom = pad_along_height - pad_top + # assert pad_bottom == hls_pool.attributes['pad_bottom'] + # assert pad_top == hls_pool.attributes['pad_top'] + + # # Width + # in_width = ker_pool.input_shape[2] + # if ker_pool.data_format == 'channels_first': + # in_height = keras_model.layers[1].input_shape[-1] + # out_width = int(math.ceil(float(in_width) / float(ker_pool.strides[1]))) + # assert out_width == hls_pool.attributes['out_width'] + # if in_width % ker_pool.strides[1] == 0: + # pad_along_width = max(ker_pool.pool_size[0] - ker_pool.strides[1], 0) + # else: + # pad_along_width = max(ker_pool.pool_size[0] - (in_width % ker_pool.strides[1]), 0) + # pad_left = pad_along_width // 2 + # pad_right = pad_along_width - pad_left + # assert pad_left == hls_pool.attributes['pad_left'] + # assert pad_right == hls_pool.attributes['pad_right'] + + # elif ker_pool.padding == 'valid': + # if hls_pool.attributes['data_format'] == 'channels_first': + # in_height = ker_pool.input_shape[2] + # in_width = ker_pool.input_shape[3] + # elif hls_pool.attributes['data_format'] == 'channels_last': + # in_height = ker_pool.input_shape[1] + # in_width = ker_pool.input_shape[2] + # else: + # raise ValueError('Invalid data_format') + + # out_width = int(math.ceil(float(in_width - ker_pool.pool_size[0] + 1) / float(ker_pool.strides[1]))) + # out_height = int(math.ceil(float(in_height - ker_pool.pool_size[1] + 1) / float(ker_pool.strides[0]))) + + # assert hls_pool.attributes['out_height'] == out_height + # assert hls_pool.attributes['out_width'] == out_width + # assert hls_pool.attributes['pad_top'] == 0 + # assert hls_pool.attributes['pad_bottom'] == 0 + # assert hls_pool.attributes['pad_left'] == 0 + # assert hls_pool.attributes['pad_right'] == 0 + + # elif '1D' in pooling.__name__: + # assert hls_pool.attributes['name'] == ker_pool._name + # assert hls_pool.attributes['class_name'][-2] == str(1) + # assert hls_pool.attributes['n_in'] == ker_pool.input_shape[1] + # assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[2] + # assert hls_pool.attributes['pool_width'] == ker_pool.pool_size[0] + # assert hls_pool.attributes['stride_width'] == ker_pool.strides[0] + + # out_same = math.ceil(float(ker_pool.input_shape[1]) / float(ker_pool.strides[0])) + # out_valid = math.ceil(float(ker_pool.input_shape[1] - ker_pool.pool_size[0] + 1) / ker_pool.strides[0]) + + # if ker_pool.padding == 'same': + # assert hls_pool.attributes['n_out'] == out_same + # if ker_pool.input_shape[1] % ker_pool.strides[0] == 0: + # pad_along_width = max(ker_pool.pool_size[0] - ker_pool.strides[0], 0) + # else: + # pad_along_width = 
max(ker_pool.pool_size[0] - (ker_pool.input_shape[1] % ker_pool.strides[0]), 0) + # assert hls_pool.attributes['pad_left'] == pad_along_width // 2 + # assert hls_pool.attributes['pad_right'] == pad_along_width - pad_along_width // 2 + + # elif ker_pool.padding == 'valid': + # assert hls_pool.attributes['n_out'] == out_valid + # assert hls_pool.attributes['pad_left'] == 0 + # assert hls_pool.attributes['pad_right'] == 0 + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_reused_layer(backend, io_type): + + inp1 = keras.layers.Input(shape=(10, 10)) + inp2 = keras.layers.Input(shape=(10, 10)) + + conv = keras.layers.Conv1D(2, 3, activation='relu') + + o1 = conv(inp1) + o2 = conv(inp2) + o3 = keras.layers.Add()([o1, o2]) + o4 = keras.layers.Dense(5)(o3) + + _ = keras.layers.Dense(5)(o3) + + model = keras.models.Model(inputs=[inp1, inp2], outputs=[o1, o2, o3, o4]) + + _ = model([inp1, inp1]) + + hls_config = {'Model': {'Precision': 'ap_fixed<32,8>', 'ReuseFactor': 1}} + output_dir = str(test_root_path / f'hls4mlprj_keras_api_conv1d_{backend}_{io_type}') + + model_hls = hls4ml.converters.convert_from_keras_model( + model, backend=backend, io_type=io_type, hls_config=hls_config, output_dir=output_dir + ) + + model_hls.compile() + + data = [np.random.rand(1000, 10, 10).astype(np.float32), np.random.rand(1000, 10, 10).astype(np.float32)] + keras_pred = model.predict(data) + hls_pred = model_hls.predict(data) + + np.testing.assert_allclose(keras_pred[0].reshape(hls_pred[0].shape), hls_pred[0], rtol=0, atol=1e-5) + np.testing.assert_allclose(keras_pred[1].reshape(hls_pred[1].shape), hls_pred[1], rtol=0, atol=1e-5) + np.testing.assert_allclose(keras_pred[2].reshape(hls_pred[2].shape), hls_pred[2], rtol=0, atol=1e-5) + np.testing.assert_allclose(keras_pred[3].reshape(hls_pred[3].shape), hls_pred[3], rtol=0, atol=1e-2) diff --git a/test/pytest/test_optimization/test_attributes.py b/test/pytest/test_optimization/test_attributes.py index a42d3a6751..c9e22091f2 100644 --- a/test/pytest/test_optimization/test_attributes.py +++ b/test/pytest/test_optimization/test_attributes.py @@ -1,7 +1,7 @@ from tensorflow.keras.layers import Conv2D, Dense, Flatten, ReLU from tensorflow.keras.models import Sequential -from hls4ml.optimization import get_attributes_from_keras_model_and_hls4ml_config +from hls4ml.optimization.dsp_aware_pruning import get_attributes_from_keras_model_and_hls4ml_config from hls4ml.utils.config import config_from_keras_model diff --git a/test/pytest/test_qeinsum.py b/test/pytest/test_qeinsum.py new file mode 100644 index 0000000000..fd264f23d6 --- /dev/null +++ b/test/pytest/test_qeinsum.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import keras +import numpy as np +import pytest +from keras.api.layers import Input + +from hls4ml.converters import convert_from_keras_model + +if keras.__version__ < '3.0.0': + pytest.skip('Only keras v3 is supported for now', allow_module_level=True) + +try: + from squark.layers import QEinsum + from squark.utils import trace_mode +except ImportError: + pytest.skip('s-quark is not installed', allow_module_level=True) + +test_root_path = Path(__file__).parent + + +@pytest.mark.parametrize('strategy', ['latency']) +@pytest.mark.parametrize('io_type', ['io_parallel']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize( + 'operation', + [ + # eq, inp, out + ('xbi,xj->xbij', (8, 16), (16,)), + ('xbi,xio->xbo', (7, 
8), (8, 9)), + ('xi,xoi->xo', (16,), (20, 16)), + ('xabcd,xbcde->xaeb', (2, 4, 8, 16), (4, 8, 16, 3)), + ], +) +def test_einsum_dense(backend, io_type, strategy, operation): + eq, inp0_shape, inp1_shape = operation + inp0 = Input(inp0_shape) + inp1 = Input(inp1_shape) + out = QEinsum(eq, name='einsum')([inp0, inp1]) + model = keras.Model(inputs=[inp0, inp1], outputs=out) + + data = np.random.randn(1000, *inp0_shape).astype(np.float32), np.random.randn(1000, *inp1_shape).astype(np.float32) + eq_name = eq.replace(',', '_').replace('->', '_') + output_dir = str(test_root_path / f'hls4mlprj_einsum_{eq_name}_{backend}_{io_type}_{strategy}') + hls_config = {'Model': {'Precision': 'ap_fixed<1,0>', 'ReuseFactor': 1}, 'Strategy': strategy} + + with trace_mode(model): + r_keras = model.predict(data, verbose=0, batch_size=1000) # type: ignore + + model_hls = convert_from_keras_model( + model, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type + ) + + model_hls.compile() + r_hls = model_hls.predict(data).reshape(r_keras.shape) # type: ignore + + assert np.all(r_hls.ravel() == r_keras.ravel()) diff --git a/test/pytest/test_qinterval.py b/test/pytest/test_qinterval.py new file mode 100644 index 0000000000..78f565e129 --- /dev/null +++ b/test/pytest/test_qinterval.py @@ -0,0 +1,103 @@ +import numpy as np +import pytest +from quantizers.fixed_point import get_fixed_quantizer_np + +from hls4ml.utils.qinterval import QIntervalArray, einsum, minimal_kif + + +def assert_is_represented(qinterval: QIntervalArray, data: np.ndarray): + assert np.all(data <= qinterval.max), f'{np.max(data - qinterval.max)} > 0' + assert np.all(data >= qinterval.min), f'{np.min(data - qinterval.min)} < 0' + with np.errstate(divide='ignore', invalid='ignore'): + is_zero = (qinterval.max == 0) & (qinterval.min == 0) + assert np.all((data % qinterval.delta == 0) | is_zero) + + +@pytest.fixture(scope='module') +def data(): + arr = np.random.randint(-1024, 1024, size=1000000) + arr = arr * 2.0 ** np.random.randint(-20, 20, size=1000000) + return arr + + +def test_minimal_kif(data): + k, i, f = minimal_kif(data) + q = get_fixed_quantizer_np() + assert np.all(data == q(data, k, i, f)) + assert np.all((data != q(data, k, i, f - 1)) | (data == 0)) + assert np.all((data != q(data, k, i - 1, f)) | (data == 0) | (i + f == 0)) + + +def random_arr(seed=None): + rng = np.random.default_rng(seed) + shape = (64, 64) + + _delta = 2.0 ** rng.integers(-8, 8, shape) + _min = rng.integers(-1024, 1024, shape) * _delta + _max = rng.integers(0, 4096, shape) * _delta + _min + interval_arr = QIntervalArray(_min, _max, _delta) + return interval_arr + + +@pytest.fixture(scope='module') +def qint_arr1(): + return random_arr() + + +@pytest.fixture(scope='module') +def qint_arr2(): + return random_arr() + + +@pytest.mark.parametrize('oprstr', ['__add__', '__sub__', '__mul__', '__matmul__', '__rmatmul__']) +def test_qinterval_oprs(qint_arr1, qint_arr2, oprstr): + + sampled_arr1 = qint_arr1.sample(10000) + const_arr = qint_arr2.sample() + applied_symbolic = getattr(qint_arr1, oprstr)(const_arr) + applied_sampled = getattr(sampled_arr1, oprstr)(const_arr) + + assert_is_represented(applied_symbolic, applied_sampled) + + if oprstr != '__rmatmul__': + # rmatmul is only between const and intervals. 
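For orientation, and not part of the patch: the property these qinterval tests exercise, written with only the APIs added in hls4ml/utils/qinterval.py. Concrete values drawn from a QIntervalArray and pushed through a constant matrix stay inside the symbolically propagated interval, and `to_kif` then gives a fixed-point format covering that interval. Shapes, bit widths and weights below are arbitrary choices of mine.

```python
import numpy as np

from hls4ml.utils.qinterval import QIntervalArray

rng = np.random.default_rng(0)

# an 8-element input described symbolically as signed fixed-point with 2 integer and 5 fractional bits
x_sym = QIntervalArray.from_kif(k=np.ones(8, dtype=bool), i=np.full(8, 2), f=np.full(8, 5))
w = np.round(rng.uniform(-2, 2, (8, 4)) * 16) / 16  # constant weights on a 2**-4 grid

y_sym = x_sym @ w       # symbolic propagation through the matmul
x = x_sym.sample(1000)  # concrete inputs on the same grid
y = x @ w

assert np.all((y >= y_sym.min) & (y <= y_sym.max))  # every concrete result lies inside the bounds
k, i, f = y_sym.to_kif()  # fixed-point format wide enough for the propagated range
```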
+ + sampled_arr2 = qint_arr2.sample(10000) + rapplied_symbolic = getattr(qint_arr1, oprstr)(qint_arr2) + rapplied_sampled = getattr(sampled_arr1, oprstr)(sampled_arr2) + + assert_is_represented(rapplied_symbolic, rapplied_sampled) + + +@pytest.mark.parametrize('eq', ['ij,jk->ik', 'ij,kj->ikj']) +def test_qinterval_einsum(qint_arr1, qint_arr2, eq): + + _in, out = eq.split('->', 1) + in0, in1 = _in.split(',', 1) + qint_arr1 = qint_arr1[:16, :16] + qint_arr2 = qint_arr2[:16, :16] + + sampled_arr1 = qint_arr1.sample(10000) + sampled_arr2 = qint_arr2.sample(10000) + + # symbolic - symbolic + einsum_symbolic = einsum(eq, qint_arr1, qint_arr2) + einsum_sampled = np.einsum(f'A{in0},A{in1}->A{out}', sampled_arr1, sampled_arr2) + assert_is_represented(einsum_symbolic, einsum_sampled) + + # symbolic - sampled + einsum_symbolic = einsum(eq, qint_arr1, sampled_arr2[0]) + einsum_sampled = np.einsum(f'A{in0},{in1}->A{out}', sampled_arr1, sampled_arr2[0]) + assert_is_represented(einsum_symbolic, einsum_sampled) + + # sampled - symbolic + einsum_symbolic = einsum(eq, sampled_arr1[0], qint_arr2) + einsum_sampled = np.einsum(f'{in0},A{in1}->A{out}', sampled_arr1[0], sampled_arr2) + assert_is_represented(einsum_symbolic, einsum_sampled) + + +def test_qinterval_to_kif(qint_arr1): + k, i, f = qint_arr1.to_kif() + samples = qint_arr1.sample(10000) + q = get_fixed_quantizer_np() + assert np.all(samples == q(samples, k, i, f)) diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index 048b6832ee..73c54711c8 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -22,18 +22,20 @@ def generate_data(input_shape): @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('strategy', ['stable', 'latency', 'argmax']) @pytest.mark.parametrize( - 'input_bits,input_shape,table_bits,io_type', + 'input_bits,input_shape,table_bits,io_type,custom_accum', [ - ('16,6', (8,), '18,8', 'io_parallel'), - ('16,6', (8,), '18,8', 'io_stream'), - ('16,6', (8,), '9,6', 'io_parallel'), - ('16,6', (8,), '9,6', 'io_stream'), - ('9,6', (8,), '18,8', 'io_parallel'), - ('9,6', (8,), '18,8', 'io_stream'), - ('16,6', (8, 8, 3), '18,8', 'io_stream'), + ('16,6', (8,), '18,8', 'io_parallel', False), + ('16,6', (8,), '18,8', 'io_stream', False), + ('16,6', (8,), '18,8', 'io_parallel', True), + ('16,6', (8,), '18,8', 'io_stream', True), + ('16,6', (8,), '9,6', 'io_parallel', False), + ('16,6', (8,), '9,6', 'io_stream', False), + ('9,6', (8,), '18,8', 'io_parallel', False), + ('9,6', (8,), '18,8', 'io_stream', False), + ('16,6', (8, 8, 3), '18,8', 'io_stream', False), ], ) -def test_softmax(backend, strategy, generate_data, input_bits, input_shape, table_bits, io_type): +def test_softmax(backend, strategy, generate_data, input_bits, input_shape, table_bits, io_type, custom_accum): X = generate_data model = tf.keras.models.Sequential() model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softmax', name='softmax')) @@ -45,11 +47,23 @@ def test_softmax(backend, strategy, generate_data, input_bits, input_shape, tabl cfg['LayerName']['softmax']['Strategy'] = strategy cfg['LayerName']['softmax']['inv_table_t'] = table_type cfg['LayerName']['softmax']['exp_table_t'] = table_type - cfg['LayerName']['softmax_input']['Precision']['result'] = f'fixed<{input_bits}>' + cfg['LayerName']['softmax']['accum_t'] = table_type + cfg['LayerName']['softmax']['inv_inp_t'] = table_type + if custom_accum: + if backend not in ['Vivado', 'Vitis']: + pytest.skip('Custom 
accumulators are only supported for Vivado and Vitis backends') + W, I = map(int, input_bits.split(',')) # noqa: E741 + cfg['LayerName']['softmax']['accum_t'] = f'fixed<{W+3},{I+3}>' + cfg['LayerName']['softmax']['inv_inp_t'] = f'fixed<{W+2},{I+2}>' + inp_layer_name = next(iter(cfg['LayerName'].keys())) + cfg['LayerName'][inp_layer_name]['Precision']['result'] = f'fixed<{input_bits}>' odir = str( test_root_path - / f'hls4mlprj_softmax_{backend}_{io_type}_{strategy}_{input_shape}_input-bits={input_bits}_table-bits={table_bits}' + / ( + f'hls4mlprj_softmax_{backend}_{io_type}_{strategy}_{input_shape}' + f'_input-bits={input_bits}_table-bits={table_bits}_custom-accum={custom_accum}' + ) ) hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=cfg, io_type=io_type, output_dir=odir, backend=backend