diff --git a/hls4ml/model/optimizer/passes/bit_exact.py b/hls4ml/model/optimizer/passes/bit_exact.py
index 084b67360..9b16c72cc 100644
--- a/hls4ml/model/optimizer/passes/bit_exact.py
+++ b/hls4ml/model/optimizer/passes/bit_exact.py
@@ -11,6 +11,7 @@
 from hls4ml.model.layers import (
     Activation,
     BatchNormalization,
+    Concatenate,
     Conv1D,
     Conv2D,
     Dense,
@@ -126,6 +127,20 @@ def _(layer: Activation):
     return (_maximum_kif_at_shape(inp_shape),)
 
 
+@request_kif.register
+def _(layer: Concatenate):
+    inp_shape0, inp_shape1 = get_input_shapes(layer)
+    k, i, f = requested_kif(layer)
+    ax = layer.attributes['axis']
+    n_split = inp_shape0[ax]
+
+    k0, k1 = np.split(k, [n_split], axis=ax)
+    i0, i1 = np.split(i, [n_split], axis=ax)
+    f0, f1 = np.split(f, [n_split], axis=ax)
+
+    return ((k0, i0, f0), (k1, i1, f1))
+
+
 def requested_kif(layer: Layer) -> KIF_t:
     out_layers = get_output_layers(layer)
     out_shape = get_output_shape(layer)
@@ -403,6 +418,17 @@ def _(layer: Softmax):
     return k, i, f
 
 
+@produce_kif.register
+def _(layer: Concatenate):
+    kifs_in = get_input_kifs(layer)
+    ks, is_, fs = zip(*kifs_in)
+    ax = layer.attributes.attributes['axis']
+    k = np.concatenate(ks, axis=ax)
+    i = np.concatenate(is_, axis=ax)
+    f = np.concatenate(fs, axis=ax)
+    return k, i, f
+
+
 @produce_kif.register
 def _(layer: Activation):
     fn_name = layer.attributes.attributes['activation']
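
For reference, a minimal sketch (not part of the patch) of what the new Concatenate handlers do with the KIF arrays, using hypothetical values: request_kif splits the precision requested for the concatenated output back into per-input requests at n_split = inp_shape0[ax], and produce_kif is the inverse, joining the per-input arrays with np.concatenate along the same axis.

    import numpy as np

    # Hypothetical integer-bit request for a concatenated output of shape (4,),
    # where the first input contributed 3 elements along axis 0.
    i = np.array([8, 8, 8, 4])
    i0, i1 = np.split(i, [3], axis=0)   # i0 -> [8, 8, 8], i1 -> [4]

    # produce_kif direction: per-input arrays are concatenated back together.
    i_out = np.concatenate([i0, i1], axis=0)  # -> [8, 8, 8, 4]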