From 58af4ada959362cfbc7158c6b803aba67ee0337b Mon Sep 17 00:00:00 2001 From: Prem Kumar Date: Thu, 12 Dec 2024 07:29:46 +0000 Subject: [PATCH] update dc.op() with dc.op_with_named_attrs() --- forge/forge/op/eval/forge/argmax.py | 9 +- forge/forge/op/eval/forge/convolution.py | 4 +- forge/forge/op/eval/forge/eltwise_binary.py | 24 +++-- forge/forge/op/eval/forge/eltwise_unary.py | 18 +++- forge/forge/op/eval/forge/nn.py | 2 +- forge/forge/op/eval/forge/pooling.py | 38 +++++--- forge/forge/op/eval/forge/quantize.py | 26 ++++-- forge/forge/op/eval/forge/reduce.py | 2 +- forge/forge/op/eval/forge/tm.py | 64 +++++++++---- forge/test/mlir/test_forge_ops.py | 99 +++++++++++++++++++++ 10 files changed, 230 insertions(+), 56 deletions(-) create mode 100644 forge/test/mlir/test_forge_ops.py diff --git a/forge/forge/op/eval/forge/argmax.py b/forge/forge/op/eval/forge/argmax.py index 83702d9ad..f62bae08a 100644 --- a/forge/forge/op/eval/forge/argmax.py +++ b/forge/forge/op/eval/forge/argmax.py @@ -70,7 +70,12 @@ def decompose(self, dc, inputs): if axis is None: import math - inp_node = dc.op("reshape", [inp_node], (1, math.prod(inp_node.shape.as_list()))) + inp_node = dc.op_with_named_attrs( + "reshape", + [inp_node], + {"shape": (1, math.prod(inp_node.shape.as_list()))}, + (1, math.prod(inp_node.shape.as_list())), + ) axis = -1 input_shape = inp_node.shape.as_list() @@ -90,7 +95,7 @@ def decompose(self, dc, inputs): "multiply", [inp_node, factor_tensor], ) - softmax = dc.op("softmax", [mult_1], (axis, 1)) + softmax = dc.op_with_named_attrs("softmax", [mult_1], {"dimension": axis}, (axis, 1)) mult_2 = dc.op("multiply", [softmax, range_tensor]) reduce_sum = dc.op_with_named_attrs( "reduce_sum", (mult_2,), {"dim_arg": [axis], "keep_dim": True}, (axis, True) diff --git a/forge/forge/op/eval/forge/convolution.py b/forge/forge/op/eval/forge/convolution.py index a4fdcf82d..d95f56f55 100644 --- a/forge/forge/op/eval/forge/convolution.py +++ b/forge/forge/op/eval/forge/convolution.py @@ -136,7 +136,7 @@ def decompose(self, dc, inputs): if bias is not None and len(bias.shape) < len(activations.shape): while len(bias.shape) < len(activations.shape): - bias = dc.op("unsqueeze", [bias], (0, len(bias.shape))) + bias = dc.op_with_named_attrs("unsqueeze", [bias], {"dim": 0}, (0, len(bias.shape))) is_bias_unchanged = bias is None or bias == inputs[2] @@ -310,7 +310,7 @@ def decompose(self, dc, inputs): if bias is not None and len(bias.shape) < len(activations.shape): while len(bias.shape) < len(activations.shape): - bias = dc.op("unsqueeze", [bias], (0, len(bias.shape))) + bias = dc.op_with_named_attrs("unsqueeze", [bias], {"dim": 0}, (0, len(bias.shape))) is_bias_unchanged = bias is None or bias == inputs[2] diff --git a/forge/forge/op/eval/forge/eltwise_binary.py b/forge/forge/op/eval/forge/eltwise_binary.py index a9d4f98cf..220ed06bc 100644 --- a/forge/forge/op/eval/forge/eltwise_binary.py +++ b/forge/forge/op/eval/forge/eltwise_binary.py @@ -398,11 +398,11 @@ def decompose(op_type, attr, dc, inputs): ops0_dims = len(inputs[0].shape) ops1_dims = len(inputs[1].shape) if ops0_dims > ops1_dims and ops0_dims == 5: - ops1 = dc.op("reshape", [inputs[1]], list(inputs[0].shape)) + ops1 = dc.op_with_named_attrs("reshape", [inputs[1]], {"shape": list(inputs[0].shape)}, list(inputs[0].shape)) result = dc.op(op_type, [inputs[0], ops1]) dc.fuse(result) elif ops1_dims > ops0_dims and ops1_dims == 5: - ops0 = dc.op("reshape", [inputs[0]], list(inputs[1].shape)) + ops0 = dc.op_with_named_attrs("reshape", [inputs[0]], {"shape": 
list(inputs[1].shape)}, list(inputs[1].shape)) result = dc.op(op_type, [ops0, inputs[1]]) dc.fuse(result) @@ -449,13 +449,17 @@ def decompose_post_autograd(op_type, attr, dc, inputs): if slice_factor != None: concat_z = dc.op("interleave", [operand0, operand1], (-3, 1)) - result = dc.op("reduce_max", [concat_z], (-3, 2)) + result = dc.op_with_named_attrs( + "reduce_max", [concat_z], {"dim_arg": [-3], "keep_dim": True}, (-3, 2, True) + ) else: - concat_z = dc.op("concatenate", [operand0, operand1], (-3,)) - result = dc.op("reduce_max", [concat_z], (-3,)) + concat_z = dc.op_with_named_attrs("concatenate", [operand0, operand1], {"dim": -3}, (-3,)) + result = dc.op_with_named_attrs( + "reduce_max", [concat_z], {"dim_arg": [-3], "keep_dim": True}, (-3, concat_z.shape[-3], True) + ) while len(result.shape) > max_operand_nd: - result = dc.op("squeeze", [result], (0,)) + result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,)) dc.fuse(result) return @@ -463,11 +467,15 @@ def decompose_post_autograd(op_type, attr, dc, inputs): ops0_dims = len(inputs[0].shape) ops1_dims = len(inputs[1].shape) if ops0_dims > ops1_dims and ops0_dims == 5: - ops1 = dc.op("reshape", [inputs[1]], list(inputs[0].shape)) + ops1 = dc.op_with_named_attrs( + "reshape", [inputs[1]], {"shape": list(inputs[0].shape)}, list(inputs[0].shape) + ) result = dc.op(op_type, [inputs[0], ops1]) dc.fuse(result) elif ops1_dims > ops0_dims and ops1_dims == 5: - ops0 = dc.op("reshape", [inputs[0]], list(inputs[1].shape)) + ops0 = dc.op_with_named_attrs( + "reshape", [inputs[0]], {"shape": list(inputs[0].shape)}, list(inputs[1].shape) + ) result = dc.op(op_type, [ops0, inputs[1]]) dc.fuse(result) diff --git a/forge/forge/op/eval/forge/eltwise_unary.py b/forge/forge/op/eval/forge/eltwise_unary.py index 90f46fa18..4a4ff44f6 100644 --- a/forge/forge/op/eval/forge/eltwise_unary.py +++ b/forge/forge/op/eval/forge/eltwise_unary.py @@ -521,7 +521,12 @@ def decompose(type, attr, dc, inputs): if axis is None: import math - inp_node = dc.op("reshape", [inp_node], (1, math.prod(inp_node.shape.as_list()))) + inp_node = dc.op_with_named_attrs( + "reshape", + [inp_node], + {"shape": (1, math.prod(inp_node.shape.as_list()))}, + (1, math.prod(inp_node.shape.as_list())), + ) axis = -1 input_shape = inp_node.shape.as_list() @@ -559,7 +564,9 @@ def decompose(type, attr, dc, inputs): "multiply", (inp_node, factor_tensor), ) - max_1 = dc.op("reduce_max", [scaled_input], [axis]) + max_1 = dc.op_with_named_attrs( + "reduce_max", [scaled_input], {"dim_arg": [axis], "keep_dim": True}, [axis, scaled_input.shape[axis], True] + ) scaled_input = dc.op("subtract", (scaled_input, max_1)) scaled_input = dc.op( "add", @@ -582,7 +589,12 @@ def decompose(type, attr, dc, inputs): [mul_1, mul_2], ) negative_add_1 = dc.op("multiply", [add_1, negative_ones]) - negative_argmax = dc.op("reduce_max", [negative_add_1], [axis]) + negative_argmax = dc.op_with_named_attrs( + "reduce_max", + [negative_add_1], + {"dim_arg": [axis], "keep_dim": True}, + [axis, negative_add_1.shape[axis], True], + ) output_neg_ones = torch.ones((negative_argmax.shape.as_list()), dtype=data_type) * (-1) output_neg_ones_tensor = dc.tensor(output_neg_ones) diff --git a/forge/forge/op/eval/forge/nn.py b/forge/forge/op/eval/forge/nn.py index fb9397150..184ae2a47 100644 --- a/forge/forge/op/eval/forge/nn.py +++ b/forge/forge/op/eval/forge/nn.py @@ -394,7 +394,7 @@ def decompose(op_type, attr, dc, inputs): x = inputs[0] dim = attr[0] stable = attr[1] - result = dc.op("softmax", (x,), (dim, stable)) + 
result = dc.op_with_named_attrs("softmax", (x,), {"dimension": dim}, (dim, stable)) result = dc.op(Log.create(), (result,)) dc.fuse(result) return diff --git a/forge/forge/op/eval/forge/pooling.py b/forge/forge/op/eval/forge/pooling.py index e9406a3ad..7818f56a3 100644 --- a/forge/forge/op/eval/forge/pooling.py +++ b/forge/forge/op/eval/forge/pooling.py @@ -467,7 +467,9 @@ def decompose(type, attr, dc, inputs): activations = inputs[0] if kernel_size == activations.shape[-1]: - reduce_avg = dc.op_with_named_attrs("reduce_avg", [activations], {"dim": -1, "keep_dim": True}, (-1, True)) + reduce_avg = dc.op_with_named_attrs( + "reduce_avg", [activations], {"dim_arg": -1, "keep_dim": True}, (-1, True) + ) dc.fuse(reduce_avg) return else: @@ -527,14 +529,14 @@ def decompose(type, attr, dc, inputs): result = dc.op_with_named_attrs( "reshape", [activations], {"shape": (w, 1, y * x, cin)}, (w, 1, y * x, cin) ) - result = dc.op_with_named_attrs("reduce_avg", [result], {"dim": -2, "keep_dim": True}, (-2, True)) + result = dc.op_with_named_attrs("reduce_avg", [result], {"dim_arg": -2, "keep_dim": True}, (-2, True)) result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, 1, 1, cin)}, (w, 1, 1, cin)) else: result = dc.op_with_named_attrs( "reshape", [activations], {"shape": (w, 1, cin, y * x)}, (w, 1, cin, y * x) ) result = dc.op(TransposeTM.create(2, 3), [result]) - result = dc.op_with_named_attrs("reduce_avg", [result], {"dim": -2, "keep_dim": True}, (-2, True)) + result = dc.op_with_named_attrs("reduce_avg", [result], {"dim_arg": -2, "keep_dim": True}, (-2, True)) result = dc.op(TransposeTM.create(2, 3), [result]) result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, cin, 1, 1)}, (w, cin, 1, 1)) dc.fuse(result) @@ -776,9 +778,9 @@ def decompose(type, attr, dc, inputs): if channel_last: # result = dc.op("vstack", [activations], (y,)) _, yout, xout, _ = shape("max_pool2d", attr, [activations.shape])[0] - result = dc.op("reshape", [activations], (w, 1, y * x, cin)) + result = dc.op_with_named_attrs("reshape", [activations], {"shape": (w, 1, y * x, cin)}, (w, 1, y * x, cin)) else: - result = dc.op("reshape", [activations], (w, 1, cin, y * x)) + result = dc.op_with_named_attrs("reshape", [activations], {"shape": (w, 1, cin, y * x)}, (w, 1, cin, y * x)) result = dc.op(TransposeTM.create(2, 3), [result]) _, _, yout, xout = shape("max_pool2d", attr, [activations.shape])[0] result = dc.op("pad_tile", [result], (3, cin)) @@ -812,10 +814,12 @@ def decompose(type, attr, dc, inputs): pad_shape = result.shape.as_list() pad_shape[-1] = (result_c_padding[result_c] - result_c) * TILE_DIM zeros_tensor = dc.tensor(torch.zeros(pad_shape)) - result = dc.op("concatenate", [result, zeros_tensor], (-1,)) + result = dc.op_with_named_attrs("concatenate", [result, zeros_tensor], {"dim": -1}, (-1,)) result = dc.op("sparse_matmul", [picker_tensor, result]) - result = dc.op("reduce_max", [result], (1,)) # z dim + result = dc.op_with_named_attrs( + "reduce_max", [result], {"dim_arg": [1], "keep_dim": True}, (1, result.shape[1], True) + ) # z dim if pad_for_factorization: if sparse_r in sparse_r_padding: @@ -835,10 +839,10 @@ def decompose(type, attr, dc, inputs): result = dc.op("narrow", [result], (2, 0, yout * xout, result.shape[2])) result = dc.op("narrow", [result], (3, 0, cin, result.shape[3])) if channel_last: - result = dc.op("reshape", [result], (w, yout, xout, cin)) + result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, yout, xout, cin)}, (w, yout, xout, cin)) else: result = 
dc.op(TransposeTM.create(2, 3), [result]) - result = dc.op("reshape", [result], (w, cin, yout, xout)) + result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, cin, yout, xout)}, (w, cin, yout, xout)) if max_pool_add_sub_surround: add_sub_val = dc.tensor(torch.zeros((1,)) + max_pool_add_sub_surround_value) @@ -907,7 +911,9 @@ def decompose(type, attr, dc, inputs): # _, yout, xout, _ = shape('max_pool2d', attr, [activations.shape])[0] # result = dc.op("reshape", [activations], (w, 1, y * x, cin)) # else: - result = dc.op("reshape", [activations], (w, 1, cin * din, y * x)) + result = dc.op_with_named_attrs( + "reshape", [activations], {"shape": (w, 1, cin * din, y * x)}, (w, 1, cin * din, y * x) + ) result = dc.op(TransposeTM.create(-2, -1), [result]) _, cout, dout, yout, xout = shape("max_pool3d", attr, [activations.shape])[0] result = dc.op("pad_tile", [result], (-1, cin * din)) @@ -958,7 +964,9 @@ def create_conv2d_sparse_matrix( ) picker_tensor = dc.tensor(picker.unsqueeze(0)) # (1, kH*kW, yout*xout, yin*xin) result = dc.op("sparse_matmul", [picker_tensor, result]) # (1, kH*kW, yout*xout, cin*din) - result = dc.op("reduce_max", [result], (-3,)) # z dim # (1, 1, yout*xout, cin*din) + result = dc.op_with_named_attrs( + "reduce_max", [result], {"dim_arg": [-3], "keep_dim": True}, (-3, result.shape[-3], True) + ) # z dim # (1, 1, yout*xout, cin*din) # Run max pool on the depth dimension in a separate step if kD > 1: @@ -976,7 +984,9 @@ def create_conv2d_sparse_matrix( # Transpose the activations to allow sparse mm to work on the depth dim result = dc.op(TransposeTM.create(-2, -1), [result]) result = dc.op("sparse_matmul", [depth_picker, result]) # (1, kD, cout*dout, yout*xout) - result = dc.op("reduce_max", [result], (-3,)) # z dim # (1, 1, cout*dout, yout*xout) + result = dc.op_with_named_attrs( + "reduce_max", [result], {"dim_arg": [-3], "keep_dim": True}, (-3, result.shape[-3], True) + ) # z dim # (1, 1, cout*dout, yout*xout) # Transpose back result = dc.op(TransposeTM.create(-2, -1), [result]) @@ -989,7 +999,9 @@ def create_conv2d_sparse_matrix( # if channel_last: # result = dc.op("reshape", [result], (w, yout, xout, cin)) # else: - result = dc.op("reshape", [result], (w, cin, dout, yout, xout)) + result = dc.op_with_named_attrs( + "reshape", [result], {"shape": (w, cin, dout, yout, xout)}, (w, cin, dout, yout, xout) + ) # if max_pool_add_sub_surround: # add_sub_val = dc.tensor(torch.zeros((1,)) + max_pool_add_sub_surround_value) diff --git a/forge/forge/op/eval/forge/quantize.py b/forge/forge/op/eval/forge/quantize.py index 89e89ed5b..e89a6aa5b 100644 --- a/forge/forge/op/eval/forge/quantize.py +++ b/forge/forge/op/eval/forge/quantize.py @@ -202,14 +202,15 @@ def decompose(type, attr, dc, inputs): if len(inp_scale_shape) == 1: # Match ndim with actiavtion for i in range(0, left_ndim): - inp_scale = dc.op( - "unsqueeze", [inp_scale], attrs=(0, len(inp_scale_shape)), output_df=inp_scale.output_df + inp_scale = dc.op_with_named_attrs( + "unsqueeze", [inp_scale], {"dim": 0}, attrs=(0, len(inp_scale_shape)), output_df=inp_scale.output_df ) inp_scale_shape = [1] + inp_scale_shape for i in range(0, right_ndim): - inp_scale = dc.op( + inp_scale = dc.op_with_named_attrs( "unsqueeze", [inp_scale], + {"dim": (len(inp_scale_shape))}, attrs=(len(inp_scale_shape), len(inp_scale_shape)), output_df=inp_scale.output_df, ) @@ -219,14 +220,15 @@ def decompose(type, attr, dc, inputs): if len(out_scale_shape) == 1: # Match ndim with actiavtion for i in range(0, left_ndim): - out_scale = dc.op( - 
"unsqueeze", [out_scale], attrs=(0, len(out_scale_shape)), output_df=out_scale.output_df + out_scale = dc.op_with_named_attrs( + "unsqueeze", [out_scale], {"dim": 0}, attrs=(0, len(out_scale_shape)), output_df=out_scale.output_df ) out_scale_shape = [1] + out_scale_shape for i in range(0, right_ndim): - out_scale = dc.op( + out_scale = dc.op_with_named_attrs( "unsqueeze", [out_scale], + {"dim": len(out_scale_shape)}, attrs=(len(out_scale_shape), len(out_scale_shape)), output_df=out_scale.output_df, ) @@ -274,11 +276,17 @@ def decompose(type, attr, dc, inputs): if len(scale_shape) == 1: # Match ndim with actiavtion for i in range(0, left_ndim): - scale = dc.op("unsqueeze", [scale], attrs=(0, len(scale_shape)), output_df=scale.output_df) + scale = dc.op_with_named_attrs( + "unsqueeze", [scale], {"dim": 0}, attrs=(0, len(scale_shape)), output_df=scale.output_df + ) scale_shape = [1] + scale_shape for i in range(0, right_ndim): - scale = dc.op( - "unsqueeze", [scale], attrs=(len(scale_shape), len(scale_shape)), output_df=scale.output_df + scale = dc.op_with_named_attrs( + "unsqueeze", + [scale], + {"dim": (len(scale_shape))}, + attrs=(len(scale_shape), len(scale_shape)), + output_df=scale.output_df, ) scale_shape = scale_shape + [1] diff --git a/forge/forge/op/eval/forge/reduce.py b/forge/forge/op/eval/forge/reduce.py index b779d0057..29335a573 100644 --- a/forge/forge/op/eval/forge/reduce.py +++ b/forge/forge/op/eval/forge/reduce.py @@ -378,7 +378,7 @@ def decompose(type, attr, dc, inputs): if isinstance(attr[0], list): x = inputs[0] for dim in attr[0]: - x = dc.op_with_named_attrs("reduce_avg", [x], (dim,)) + x = dc.op_with_named_attrs("reduce_avg", [x], {"dim_arg": dim, "keep_dim": True}, (dim, True)) dc.fuse(x) return diff --git a/forge/forge/op/eval/forge/tm.py b/forge/forge/op/eval/forge/tm.py index 6fc35a1fb..d4ce10819 100644 --- a/forge/forge/op/eval/forge/tm.py +++ b/forge/forge/op/eval/forge/tm.py @@ -1034,7 +1034,7 @@ def decompose(type, attr, dc, inputs): if is_one_dim: # If input is a one-dimensional tensor, reshape it to a 2D tensor with one dimension equal to 1 # and the other equal to the length. Use unsqueeze to add a dimension to the tensor. 
- act = dc.op("unsqueeze", [act], (0, len(act.shape))) + act = dc.op_with_named_attrs("unsqueeze", [act], {"dim": 0}, (0, len(act.shape))) row_indices = list(range(start, stop, stride)) @@ -1163,15 +1163,28 @@ def decompose(type, attr, dc, inputs): result = dc.op(TransposeTM.create(-3, -1, result.shape[-3]), [result]) orig_shape = result.shape - result = dc.op("reshape", [result], (1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1])) + result = dc.op_with_named_attrs( + "reshape", + [result], + {"shape": (1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1])}, + (1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1]), + ) result = dc.op(TransposeTM.create(-2, -1), [result]) spm = create_pad_replicate_sparse_picker(c, r, top, bottom, left, right) spm = dc.tensor(spm) result = dc.op("sparse_matmul", [spm, result]) result = dc.op(TransposeTM.create(-2, -1), [result]) - result = dc.op( + result = dc.op_with_named_attrs( "reshape", [result], + { + "shape": ( + 1, + orig_shape[-3], + orig_shape[-1] + total_padding_r, + orig_shape[-2] + total_padding_c, + ) + }, (1, orig_shape[-3], orig_shape[-1] + total_padding_r, orig_shape[-2] + total_padding_c), ) @@ -1179,24 +1192,21 @@ def decompose(type, attr, dc, inputs): else: orig_shape = result.shape if len(orig_shape) == 2: - result = dc.op("reshape", [result], (1, orig_shape[-2] * orig_shape[-1])) + shape = (1, orig_shape[-2] * orig_shape[-1]) else: - result = dc.op("reshape", [result], (1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1])) + shape = (1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1]) + result = dc.op_with_named_attrs("reshape", [result], {"shape": shape}, shape) result = dc.op(TransposeTM.create(-2, -1), [result]) spm = create_pad_replicate_sparse_picker(r, c, left, right, top, bottom) spm = dc.tensor(spm) result = dc.op("sparse_matmul", [spm, result]) result = dc.op(TransposeTM.create(-2, -1), [result]) if len(orig_shape) == 2: - result = dc.op( - "reshape", [result], (orig_shape[-2] + total_padding_r, orig_shape[-1] + total_padding_c) - ) + shape = (orig_shape[-2] + total_padding_r, orig_shape[-1] + total_padding_c) else: - result = dc.op( - "reshape", - [result], - (1, orig_shape[-3], orig_shape[-2] + total_padding_r, orig_shape[-1] + total_padding_c), - ) + shape = (1, orig_shape[-3], orig_shape[-2] + total_padding_r, orig_shape[-1] + total_padding_c) + + result = dc.op_with_named_attrs("reshape", [result], {"shape": shape}, shape) dc.fuse(result) return @@ -1212,7 +1222,12 @@ def decompose(type, attr, dc, inputs): pad_shape[c_dim_axis] = left tensor = torch.zeros(pad_shape) const_tensor = dc.tensor(tensor) - result = dc.op("concatenate", [const_tensor, result], [c_dim_axis]) + result = dc.op_with_named_attrs( + "concatenate", [const_tensor, result], {"dim": c_dim_axis}, [c_dim_axis] + ) + result = dc.op_with_named_attrs( + "concatenate", [const_tensor, result], {"dim": c_dim_axis}, [c_dim_axis] + ) if right > 0: pad_shape = result.shape.as_list().copy() @@ -1221,14 +1236,24 @@ def decompose(type, attr, dc, inputs): ) tensor = torch.zeros(pad_shape) const_tensor = dc.tensor(tensor) - result = dc.op("concatenate", [result, const_tensor], [c_dim_axis]) + result = dc.op_with_named_attrs( + "concatenate", [result, const_tensor], {"dim": c_dim_axis}, [c_dim_axis] + ) + result = dc.op_with_named_attrs( + "concatenate", [result, const_tensor], {"dim": c_dim_axis}, [c_dim_axis] + ) if top > 0: pad_shape = result.shape.as_list().copy() pad_shape[r_dim_axis] = top tensor = torch.zeros(pad_shape) const_tensor = dc.tensor(tensor) 
- result = dc.op("concatenate", [const_tensor, result], [r_dim_axis]) + result = dc.op_with_named_attrs( + "concatenate", [const_tensor, result], {"dim": r_dim_axis}, [r_dim_axis] + ) + result = dc.op_with_named_attrs( + "concatenate", [const_tensor, result], {"dim": r_dim_axis}, [r_dim_axis] + ) if bottom > 0: pad_shape = result.shape.as_list().copy() @@ -1237,7 +1262,12 @@ def decompose(type, attr, dc, inputs): ) tensor = torch.zeros(pad_shape) const_tensor = dc.tensor(tensor) - result = dc.op("concatenate", [result, const_tensor], [r_dim_axis]) + result = dc.op_with_named_attrs( + "concatenate", [result, const_tensor], {"dim": r_dim_axis}, [r_dim_axis] + ) + result = dc.op_with_named_attrs( + "concatenate", [result, const_tensor], {"dim": r_dim_axis}, [r_dim_axis] + ) result = dc.op("narrow", [result], (c_dim_axis, 0, total_padding_c + c, result.shape[c_dim_axis])) if channel_last: diff --git a/forge/test/mlir/test_forge_ops.py b/forge/test/mlir/test_forge_ops.py new file mode 100644 index 000000000..0c2be1466 --- /dev/null +++ b/forge/test/mlir/test_forge_ops.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +import torch + +from forge import compile +import pytest + + +@pytest.mark.parametrize( + "shapes_dtypes", + [ + (torch.tensor([1, 1, 1, 1, 32], dtype=torch.float32), torch.tensor([32], dtype=torch.float32)), + (torch.tensor([32], dtype=torch.float32), torch.tensor([1, 1, 1, 1, 32], dtype=torch.float32)), + ], +) +def test_add(shapes_dtypes): + class AddOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input_1, input_2): + output = torch.add(input_1, input_2) + return output + + inputs = shapes_dtypes + + framework_model = AddOp() + framework_model.eval() + + compile(framework_model, sample_inputs=inputs) + + +def test_argmax(): + + inputs = [torch.rand((1, 32, 64))] + + class ArgmaxOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input_1): + output = torch.argmax(input_1, dim=-1) + return output + + framework_model = ArgmaxOp() + framework_model.eval() + + compile(framework_model, sample_inputs=inputs) + + +def test_logsoftmax_torch(): + class LogSoftmaxOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input_1): + output = torch.nn.functional.log_softmax(input_1, dim=-1) + return output + + inputs = [torch.rand((1, 32, 64))] + + framework_model = LogSoftmaxOp() + framework_model.eval() + + compile(framework_model, sample_inputs=inputs) + + +def test_maximum(): + class MaximumOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input_1, input_2): + output = torch.maximum(input_1, input_2) + return output + + inputs = [torch.randn((2, 3, 4)), torch.randn((2, 3, 4))] + + framework_model = MaximumOp() + framework_model.eval() + + compile(framework_model, sample_inputs=inputs) + + +def test_avg_pool1d(): + class AvgPool1d(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.avg_pool1d( + x, kernel_size=[7], stride=[7], padding=0, ceil_mode=False, count_include_pad=True + ) + + inputs = [torch.rand(1, 2048, 7)] + + framework_model = AvgPool1d() + compile(framework_model, sample_inputs=inputs)
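The hunks above all apply the same mechanical change: a positional-attribute `dc.op(...)` call is rewritten as `dc.op_with_named_attrs(...)`, which takes the op name, the operand list, a dict of named attributes, and the original positional-attribute tuple. The dict keys follow the per-op convention visible in the diff: "shape" for reshape, "dim" for unsqueeze/squeeze/concatenate, "dim_arg" and "keep_dim" for the reduce ops, and "dimension" for softmax. A minimal sketch of the before/after call shape is shown below; it assumes the DecomposingContext signatures inferred from these hunks, and the op name, operand, and shape are illustrative only, not taken from any specific file.

    # Illustrative sketch only -- "dc" is the DecomposingContext passed to the
    # decompose() hooks in this patch; the reshape target below is made up to
    # show the call pattern, not copied from a particular hunk.
    def decompose_example(dc, inputs):
        act = inputs[0]

        # Old form: attributes passed positionally only.
        # act = dc.op("reshape", [act], (1, 32, 64))

        # New form: the positional tuple is kept as the last argument, and the
        # same attributes are also spelled out by name in a dict.
        act = dc.op_with_named_attrs(
            "reshape",
            [act],
            {"shape": (1, 32, 64)},
            (1, 32, 64),
        )
        dc.fuse(act)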