Skip to content

Commit

Permalink
bit-exact concatenate
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Dec 17, 2024
1 parent b38420d commit a2d6e1a
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions hls4ml/model/optimizer/passes/bit_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from hls4ml.model.layers import (
Activation,
BatchNormalization,
Concatenate,
Conv1D,
Conv2D,
Dense,
Expand Down Expand Up @@ -126,6 +127,20 @@ def _(layer: Activation):
return (_maximum_kif_at_shape(inp_shape),)


@request_kif.register
def _(layer: Concatenate):
inp_shape0, inp_shape1 = get_input_shapes(layer)
k, i, f = requested_kif(layer)
ax = layer.attributes['axis']
n_split = inp_shape0[ax]

k0, k1 = np.split(k, [n_split], axis=ax)
i0, i1 = np.split(i, [n_split], axis=ax)
f0, f1 = np.split(f, [n_split], axis=ax)

return ((k0, i0, f0), (k1, i1, f1))


def requested_kif(layer: Layer) -> KIF_t:
out_layers = get_output_layers(layer)
out_shape = get_output_shape(layer)
Expand Down Expand Up @@ -403,6 +418,17 @@ def _(layer: Softmax):
return k, i, f


@produce_kif.register
def _(layer: Concatenate):
kifs_in = get_input_kifs(layer)
ks, is_, fs = zip(*kifs_in)
ax = layer.attributes.attributes['axis']
k = np.concatenate(ks, axis=ax)
i = np.concatenate(is_, axis=ax)
f = np.concatenate(fs, axis=ax)
return k, i, f


@produce_kif.register
def _(layer: Activation):
fn_name = layer.attributes.attributes['activation']
Expand Down

0 comments on commit a2d6e1a

Please sign in to comment.