[DRAFT] s-quark support #1154

Open: wants to merge 69 commits into base: main

Commits (69)
ef20647  import converter dependencies lazily (calad0i, Oct 26, 2024)
028b4d0  make tf and qkeras optionl, stop assuming keras is tf.keras (calad0i, Oct 26, 2024)
72eb053  less mandatory dependency (calad0i, Oct 26, 2024)
63af2ac  fix dsp_aware_pruning test import path (calad0i, Oct 26, 2024)
c11dddb  fix broken setup.cfg after rebase, rm pyparsing (calad0i, Dec 15, 2024)
d9aaa1a  purge qkeras workaround (calad0i, Dec 15, 2024)
4854423  switch to pyproject.toml (calad0i, Dec 15, 2024)
06f9cda  format (calad0i, Dec 15, 2024)
014c1db  rm useless flake8 config in pyprject.toml (calad0i, Dec 15, 2024)
d3c8881  Add hint on import failure (calad0i, Dec 16, 2024)
738e5b0  leftover (calad0i, Dec 16, 2024)
bc7778b  rm setup.py from manifest (calad0i, Dec 16, 2024)
b76b5cb  manifest fix 2 (calad0i, Dec 16, 2024)
b7f60f5  keras v3 object based parser (calad0i, Nov 7, 2024)
a7206b4  sequential and i/o tensor name parsing fix (calad0i, Nov 8, 2024)
1605f96  support activation layers (calad0i, Nov 8, 2024)
a8aa489  consistent v2 weight reader behavior (calad0i, Nov 8, 2024)
eafe8b9  add v3 conv handlers (calad0i, Nov 8, 2024)
6b8a44c  add test (calad0i, Nov 8, 2024)
3f8acb5  pre-commit fix (calad0i, Dec 17, 2024)
d2ccfb4  revert keras v2 converter (calad0i, Dec 6, 2024)
0334960  make reshape handler compatiable with keras v3 (calad0i, Nov 13, 2024)
074b4b6  add general transpose for vivado/vitis (calad0i, Nov 13, 2024)
29674db  general einsum support for io_parallel and latency (calad0i, Nov 15, 2024)
1fb23b9  add tests for einsumdense (calad0i, Nov 15, 2024)
5489803  keras v3 converter clean-up (calad0i, Nov 19, 2024)
5e18781  add symbolic quantized interval (calad0i, Dec 2, 2024)
02ff0c3  preliminary bit-exact precision derivation opt pass (calad0i, Dec 4, 2024)
7c47be9  squark layer support start (calad0i, Dec 4, 2024)
43847c4  fix einsum_dense precision computation (calad0i, Dec 4, 2024)
afdaf21  add leftover (calad0i, Dec 4, 2024)
0da5cd0  qdense fix (calad0i, Dec 4, 2024)
6b73774  support batch_norm (calad0i, Dec 4, 2024)
93043de  support merge layers (calad0i, Dec 4, 2024)
d8708f5  support bit-exact q_einsum and fix precision trace for multi inp layers (calad0i, Dec 5, 2024)
cba1411  add einsum test (calad0i, Dec 5, 2024)
f8ae929  declare all softmax attrs in layer class (calad0i, Dec 6, 2024)
9326ad5  fix lazy import in handler (calad0i, Dec 6, 2024)
0cde312  cleanup einsum handler (calad0i, Dec 6, 2024)
b97d01e  cleanup einsum handler (calad0i, Dec 6, 2024)
c34abbe  more granular control over softmax for vivado (calad0i, Dec 6, 2024)
7ea6310  properly propagate inv/exp_table_size (calad0i, Dec 7, 2024)
0ecd12e  support bit-exact softmax for stable impl (calad0i, Dec 7, 2024)
fdfaac5  bit-exact softmax fix and leftovers (calad0i, Dec 7, 2024)
3f4c642  softmax table fixer update (calad0i, Dec 7, 2024)
bf99e83  support input scaler in softmax (calad0i, Dec 8, 2024)
b925bc8  support multidim parallel softmax (calad0i, Dec 8, 2024)
c611c77  fuse quantizer when possible (calad0i, Dec 8, 2024)
b7975fa  partial activation, fix input precision in SAT mode (calad0i, Dec 9, 2024)
3d1431e  fix padded convXd precition derivation rule (calad0i, Dec 9, 2024)
f97d4d8  add unary lut support (calad0i, Dec 9, 2024)
61e76a2  fix bit-exact corner case introduced by reverse flow (calad0i, Dec 10, 2024)
e50e731  general data_t inference (calad0i, Dec 10, 2024)
4a6b0b5  softmax compatbility (calad0i, Dec 11, 2024)
a6128ae  fix typo in einsum handler (calad0i, Dec 11, 2024)
5190c33  fix more typos (calad0i, Dec 11, 2024)
9cdb67c  MHA :tada: (calad0i, Dec 11, 2024)
5bcae96  fix einsum and softmax template typos (calad0i, Dec 11, 2024)
d780de2  assert einsum ops doesnot include direct sum operation (calad0i, Dec 12, 2024)
e3cef20  style (calad0i, Dec 13, 2024)
2bcf9e7  fix mha layer indexing (calad0i, Dec 13, 2024)
c426ddc  switch to model opt (calad0i, Dec 14, 2024)
a749c27  pooling layers (calad0i, Dec 15, 2024)
0317b5b  handle stray inputs (calad0i, Dec 15, 2024)
b38420d  fix pooling layer accum_t (calad0i, Dec 15, 2024)
a2d6e1a  bit-exact concatenate (calad0i, Dec 15, 2024)
af5c798  rm np.float_ in favor of numpy >=2.0 (calad0i, Jan 17, 2025)
c32df4b  add comments (calad0i, Jan 18, 2025)
fe0ff2f  skip non-bit-exact compatiable softmax in bit-exact pass (calad0i, Jan 18, 2025)
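
Taken together, the commits add a Keras v3 object-based front end, s-quark layer support, a bit-exact precision derivation pass, and new einsum/softmax kernels. A hedged sketch of how such a model might be run through hls4ml, assuming the existing config_from_keras_model and convert_from_keras_model entry points are reused by this draft (exact option names may differ):

import keras  # Keras v3; per the commits, keras is no longer assumed to be tf.keras

import hls4ml

# Assumption: a quantized (s-quark) Keras v3 model saved elsewhere.
model = keras.models.load_model('squark_model.keras')

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    backend='Vitis',
    io_type='io_parallel',  # the einsum/softmax additions below target io_parallel
    output_dir='hls4ml_prj',
)
hls_model.compile()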
Files changed
18 changes: 11 additions & 7 deletions .pre-commit-config.yaml
@@ -9,13 +9,19 @@ repos:
args: ['--line-length=125',
'--skip-string-normalization']

- repo: https://github.com/tox-dev/pyproject-fmt
rev: v2.5.0
hooks:
- id: pyproject-fmt

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: check-toml
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
@@ -27,27 +33,25 @@ repos:
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", --line-length=125]

- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
hooks:
- id: pyupgrade
args: ["--py36-plus"]

- repo: https://github.com/asottile/setup-cfg-fmt
rev: v2.7.0
hooks:
- id: setup-cfg-fmt

- repo: https://github.com/pycqa/flake8
rev: 7.1.1
hooks:
- id: flake8
exclude: docs/conf.py
additional_dependencies: [flake8-bugbear, flake8-print]
args: ['--max-line-length=125', # github viewer width
'--extend-ignore=E203,T201'] # E203 is not PEP8 compliant
'--extend-ignore=E203,T201', # E203 is not PEP8 compliant
'--per-file-ignores=hls4ml/model/optimizer/passes/bit_exact.py:E741,hls4ml/converters/keras_v3/squark/_base.py:E741,__init__.py:F401',
# i for #int w/o sign, I for #int w/ sign when massively processing bw conversions ......
# ignore unused imports in __init__.py .....
]

- repo: https://github.com/mgedmin/check-manifest
rev: "0.50"
5 changes: 3 additions & 2 deletions MANIFEST.in
@@ -1,7 +1,8 @@
include LICENSE README.md CONTRIBUTING.md CITATION.cff pyproject.toml setup.py setup.cfg .clang-format
include LICENSE README.md CONTRIBUTING.md CITATION.cff pyproject.toml .clang-format
graft example-models
graft test
graft contrib
recursive-include hls4ml/templates *
global-exclude .git .gitmodules .gitlab-ci.yml
recursive-include hls4ml *.py
global-exclude .git .gitmodules .gitlab-ci.yml *.pyc
include hls4ml/backends/vivado_accelerator/supported_boards.json
30 changes: 0 additions & 30 deletions hls4ml/__init__.py
@@ -1,33 +1,3 @@
# Temporary workaround for QKeras installation requirement, will be removed after 1.0.0
def maybe_install_qkeras():
import subprocess
import sys

QKERAS_PKG_NAME = 'QKeras'
# QKERAS_PKG_SOURCE = QKERAS_PKG_NAME
QKERAS_PKG_SOURCE = 'qkeras@git+https://github.com/fastmachinelearning/qkeras.git'

def pip_list():
p = subprocess.run([sys.executable, '-m', 'pip', 'list'], check=True, capture_output=True)
return p.stdout.decode()

def pip_install(package):
subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])

all_pkgs = pip_list()
if QKERAS_PKG_NAME not in all_pkgs:
print('QKeras installation not found, installing one...')
pip_install(QKERAS_PKG_SOURCE)
print('QKeras installed.')


try:
maybe_install_qkeras()
except Exception:
print('Could not find QKeras installation, make sure you have QKeras installed.')

# End of workaround

from hls4ml import converters, report, utils # noqa: F401, E402

try:
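
The auto-install shim removed here is superseded elsewhere in the PR by lazy imports in the converter front ends plus a hint on import failure (commits ef20647 and d3c8881). A minimal sketch of that pattern, with a hypothetical function name rather than the PR's exact code:

def parse_qkeras_model(model):
    # Heavy optional dependencies are imported only when the converter is used.
    try:
        import qkeras  # noqa: F401
    except ImportError as e:
        # Hint the user instead of auto-installing at import time.
        raise ImportError(
            'QKeras is required to convert this model; install it with `pip install qkeras`.'
        ) from e
    # ... conversion logic goes here ...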
33 changes: 1 addition & 32 deletions hls4ml/backends/fpga/fpga_backend.py
@@ -7,7 +7,7 @@
import numpy as np

from hls4ml.backends.backend import Backend
from hls4ml.model.attributes import ChoiceAttribute, ConfigurableAttribute, TypeAttribute
from hls4ml.model.attributes import ConfigurableAttribute, TypeAttribute
from hls4ml.model.layers import (
GRU,
LSTM,
@@ -32,16 +32,13 @@
SeparableConv1D,
SeparableConv2D,
SimpleRNN,
Softmax,
)
from hls4ml.model.optimizer import model_optimizer
from hls4ml.model.types import (
ExponentPrecisionType,
FixedPrecisionType,
IntegerPrecisionType,
PrecisionType,
RoundingMode,
SaturationMode,
UnspecifiedPrecisionType,
XnorPrecisionType,
)
@@ -109,34 +106,6 @@ def __init__(self, name):
act_attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8), description=descriptions.table_type))
self.attribute_map[Activation] = act_attrs

softmax_attrs = self.attribute_map.get(Softmax, [])
softmax_attrs.append(
ChoiceAttribute(
'implementation',
['latency', 'stable', 'argmax', 'legacy'],
default='stable',
description=descriptions.softmax_implementation,
)
)
softmax_attrs.append(
ConfigurableAttribute('skip', value_type=bool, default=False, description=descriptions.softmax_skip)
)
softmax_attrs.append(
TypeAttribute(
'exp_table',
default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
description=descriptions.table_type,
)
)
softmax_attrs.append(
TypeAttribute(
'inv_table',
default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
description=descriptions.table_type,
)
)
self.attribute_map[Softmax] = softmax_attrs

def create_layer_class(self, layer_class):
new_attrubutes = []
for cls, attributes in self.attribute_map.items():
6 changes: 5 additions & 1 deletion hls4ml/backends/fpga/passes/fix_softmax_table_size.py
@@ -6,7 +6,11 @@

class FixSoftmaxTableSize(OptimizerPass):
def match(self, node):
return isinstance(node, Softmax)
if not isinstance(node, Softmax):
return False
if 'inv_table_size' in node.attributes:
return False # handler generating inv_table_size sets it properly
return True

def transform(self, model, node: Layer):
inp_layer = node.get_input_node() # type: ignore
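
With this guard, the legacy table-size fixup only runs for Softmax nodes whose handler did not already choose table sizes. A hedged illustration using a stand-in node object (not the real Layer API):

class FakeSoftmaxNode:
    """Hypothetical stand-in: a node exposing only its attribute dict."""

    def __init__(self, attributes):
        self.attributes = attributes

legacy = FakeSoftmaxNode({'table_size': 1024})  # old front end: no inv_table_size
squark = FakeSoftmaxNode({'exp_table_size': 256, 'inv_table_size': 256})  # new handler set it

for name, node in (('legacy', legacy), ('squark', squark)):
    fires = 'inv_table_size' not in node.attributes
    print(f'{name}: FixSoftmaxTableSize applies -> {fires}')  # legacy: True, squark: False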
5 changes: 0 additions & 5 deletions hls4ml/backends/fpga/passes/hgq_proxy_model.py
@@ -52,10 +52,6 @@ def match(self, node: Layer):
return isinstance(node, FixedPointQuantizer)

def transform(self, model, node: FixedPointQuantizer):
if node.fusible:
model.remove_node(node, rewire=True)
return True

if model.config.config['IOType'] != 'io_parallel':
raise NotImplementedError('Heterogenous quantization for activations is only supported with IOType=io_parallel')

@@ -94,7 +90,6 @@ def __init__(self):

def format(self, node):
params = self._default_function_params(node)
node.attributes['result_t'].precision = node.attributes['table_t'].precision
params['config'] = f'unary_lut_config{node.index}'
params['table'] = node.get_weights('table').name

46 changes: 44 additions & 2 deletions hls4ml/backends/vivado/passes/core_templates.py
@@ -150,13 +150,21 @@ def format(self, node):

softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{
static const unsigned n_in = {n_in};
static const unsigned table_size = {table_size};
static const unsigned n_outer = {n_outer};
static const unsigned n_inner = {n_inner};
static const unsigned parallelization_factor = {parallelization_factor};
static const unsigned exp_table_size = {exp_table_size};
static const unsigned inv_table_size = {inv_table_size};
static const unsigned io_type = nnet::{iotype};
static const unsigned reuse_factor = {reuse};
static const unsigned axis = {axis};
static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation};
static constexpr float exp_scale = {exp_scale};
typedef {exp_table_t.name} exp_table_t;
typedef {inv_table_t.name} inv_table_t;
typedef {accum_t.name} accum_t;
typedef {inv_inp_t.name} inv_inp_t;
typedef {inp_norm_t_str} inp_norm_t;
}};\n"""

activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});'
@@ -208,10 +216,44 @@ def __init__(self):
super(ActivationConfigTemplate, self).__init__(Softmax) # Skip ActivationConfigTemplate's __init__
self.template = softmax_config_template

def format(self, node):
params = self._default_config_params(node)
params['type'] = node.get_attr('activation')
params.setdefault('exp_table_size', params['table_size'])
params.setdefault('inv_table_size', params['table_size'])
params.setdefault('n_inner', 1)
params.setdefault('n_outer', 1)
params.setdefault('exp_scale', 1.0)
params.setdefault('parallelization_factor', -1)

if 'inp_norm_t' not in params:
input_t = node.get_input_variable().type.precision
width, iwidth = input_t.width, input_t.integer
params['inp_norm_t_str'] = f'ap_fixed<{width}, {iwidth}, AP_RND, AP_SAT>'
else:
params['inp_norm_t_str'] = params['inp_norm_t'].name # type: ignore

return self.template.format(**params)


class SoftmaxFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(Softmax, include_header=activ_include_list)
self.template = activ_function_template

def format(self, node):
params = self._default_function_params(node)
use_multidim = node.get_attr('n_inner', 1) > 1 or node.get_attr('n_outer', 1) > 1
use_multidim = use_multidim and node.model.config.get_config_value('IOType') == 'io_parallel'
params['activation'] = 'softmax' if not use_multidim else 'softmax_multidim'
params['config'] = f'softmax_config{node.index}'

return self.template.format(**params)


class ActivationFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__((Activation, HardActivation, Softmax), include_header=activ_include_list)
super().__init__((Activation, HardActivation), include_header=activ_include_list)
self.template = activ_function_template

def format(self, node):
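
The extended softmax config stays backward compatible with front ends that only provide the legacy table_size: SoftmaxConfigTemplate.format() falls back via setdefault and derives inp_norm_t from the input precision when a handler did not supply one. A small sketch of those fallbacks with made-up numbers (not the real node object):

params = {'n_in': 10, 'table_size': 1024}  # what an older handler might supply
params.setdefault('exp_table_size', params['table_size'])  # -> 1024
params.setdefault('inv_table_size', params['table_size'])  # -> 1024
params.setdefault('n_outer', 1)  # plain 1D softmax
params.setdefault('n_inner', 1)
params.setdefault('exp_scale', 1.0)
params.setdefault('parallelization_factor', -1)

# Fallback inp_norm_t for an assumed ap_fixed<16, 6> input precision:
width, iwidth = 16, 6
inp_norm_t_str = f'ap_fixed<{width}, {iwidth}, AP_RND, AP_SAT>'
print(inp_norm_t_str)  # ap_fixed<16, 6, AP_RND, AP_SAT>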
105 changes: 105 additions & 0 deletions hls4ml/backends/vivado/passes/einsum.py
@@ -0,0 +1,105 @@
from math import ceil

from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import Einsum

from .reshaping_templates import transpose_config_gen

# Shared Dense template
# Einsum template

einsum_config_template = '''
struct config{index} {{
typedef config{index}_tpose_inp0 tpose_inp0_conf;
typedef config{index}_tpose_inp1 tpose_inp1_conf;
typedef config{index}_tpose_out tpose_out_conf;

typedef {accum_t.name} accum_t;

// Layer Sizes
static const unsigned n_free0 = {n_free0};
static const unsigned n_free1 = {n_free1};
static const unsigned n_contract = {n_contract};
static const unsigned n_inplace = {n_inplace};

// Resource reuse info
static const unsigned io_type = nnet::{iotype};
static const unsigned strategy = nnet::{strategy};
static const unsigned reuse_factor = {reuse_factor};
static const unsigned multiplier_limit = {multiplier_limit};
static const bool store_weights_in_bram = false; // NOT USED

template <class x_T, class y_T>
using product = nnet::product::{product_type}<x_T, y_T>;
}};
'''

einsum_function_template = 'nnet::einsum<{input0_t}, {input1_t}, {output_t}, {config}>({input0}, {input1}, {output});'

einsum_include_list = ['nnet_utils/nnet_einsum.h']


class EinsumConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(Einsum)
self.template = einsum_config_template

def format(self, node: Einsum):
default_params = self._default_config_params(node)

strategy = node.model.config.get_strategy(node)
io_type = node.model.config.get_config_value('IOType')

assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now'
assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now'

# EinsumDense config
params = default_params.copy()
params['strategy'] = strategy
params['n_free0'] = node.attributes.attributes['n_free0']
params['n_free1'] = node.attributes.attributes['n_free1']
params['n_contract'] = node.attributes.attributes['n_contract']
params['n_inplace'] = node.attributes.attributes['n_inplace']
inp0_t = node.get_input_variable(node.inputs[0]).type.precision
inp1_t = node.get_input_variable(node.inputs[1]).type.precision
params['product_type'] = get_backend('vivado').product_type(inp0_t, inp1_t)

total_mults = params['n_free0'] * params['n_free1'] * params['n_contract'] * params['n_inplace']
params['multiplier_limit'] = ceil(total_mults / params['reuse_factor'])

einsum_conf = self.template.format(**params)

# inp/out transpose config
inp0_shape = node.attributes.attributes['inp0_shape']
inp1_shape = node.attributes.attributes['inp1_shape']
out_interpert_shape = node.attributes.attributes['out_interpert_shape']
inp0_tpose_idxs = node.attributes.attributes['inp0_tpose_idxs']
inp1_tpose_idxs = node.attributes.attributes['inp1_tpose_idxs']
out_tpose_idxs = node.attributes.attributes['out_tpose_idxs']
tpose_inp0_conf_name = f'config{node.index}_tpose_inp0'
tpose_inp1_conf_name = f'config{node.index}_tpose_inp1'
tpose_out_conf_name = f'config{node.index}_tpose_out'

inp0_tpose_conf = transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs)
inp1_tpose_conf = transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs)
out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)

return '\n\n'.join((inp0_tpose_conf, inp1_tpose_conf, out_tpose_conf, einsum_conf))


class EinsumFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(Einsum, include_header=einsum_include_list)
self.template = einsum_function_template

def format(self, node: Einsum):
params = {}
params['config'] = f'config{node.index}'
params['input0_t'] = node.get_input_variable(node.inputs[0]).type.name
params['input1_t'] = node.get_input_variable(node.inputs[1]).type.name
params['output_t'] = node.get_output_variable().type.name
params['input0'] = node.get_input_variable(node.inputs[0]).name
params['input1'] = node.get_input_variable(node.inputs[1]).name
params['output'] = node.get_output_variable().name
return self.template.format(**params)
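
For reference, the multiplier_limit computed above is simply the total number of multiplications divided by the reuse factor, rounded up. A worked example with made-up dimensions:

from math import ceil

# Assumed einsum dimensions (illustrative only, not from a real model).
n_free0, n_free1, n_contract, n_inplace = 4, 8, 16, 2
reuse_factor = 4

total_mults = n_free0 * n_free1 * n_contract * n_inplace  # 4 * 8 * 16 * 2 = 1024 multiplications
multiplier_limit = ceil(total_mults / reuse_factor)  # ceil(1024 / 4) = 256 parallel multipliers
print(multiplier_limit)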