update dc.op() with dc.op_with_named_attrs() #942

Draft · wants to merge 1 commit into main
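Every hunk in this PR makes the same mechanical change: call sites that passed op attributes as an opaque positional tuple through dc.op(...) now go through dc.op_with_named_attrs(...), which takes a dict of named attributes ("shape", "dim", "dim_arg"/"keep_dim", "dimension") ahead of the legacy positional tuple. A minimal runnable sketch of the call-site shape, using a mock context object; the mock and the assumed signature of op_with_named_attrs are inferred from the call sites below, not taken from forge's API:

```python
# Mock of the decomposition context, for illustration only; the real `dc`
# is supplied by forge's decompose pass, so this stand-in is an assumption.
class MockDecomposingContext:
    def op(self, name, operands, attrs=None, **kwargs):
        # Legacy form: `attrs` is a positional tuple whose meaning depends on the op.
        return ("op", name, operands, attrs, kwargs)

    def op_with_named_attrs(self, name, operands, named_attrs, attrs=None, **kwargs):
        # New form: the same attributes spelled out by name, with the legacy
        # tuple still passed alongside for backward compatibility.
        return ("op_with_named_attrs", name, operands, named_attrs, attrs, kwargs)


dc = MockDecomposingContext()

# Before (the form removed throughout this PR):
old = dc.op("reshape", ["inp_node"], (1, 64))

# After (the form added): named attrs first, legacy tuple kept alongside.
new = dc.op_with_named_attrs("reshape", ["inp_node"], {"shape": (1, 64)}, (1, 64))
```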
9 changes: 7 additions & 2 deletions forge/forge/op/eval/forge/argmax.py
@@ -70,7 +70,12 @@ def decompose(self, dc, inputs):
         if axis is None:
             import math

-            inp_node = dc.op("reshape", [inp_node], (1, math.prod(inp_node.shape.as_list())))
+            inp_node = dc.op_with_named_attrs(
+                "reshape",
+                [inp_node],
+                {"shape": (1, math.prod(inp_node.shape.as_list()))},
+                (1, math.prod(inp_node.shape.as_list())),
+            )
             axis = -1

         input_shape = inp_node.shape.as_list()
@@ -90,7 +95,7 @@ def decompose(self, dc, inputs):
             "multiply",
             [inp_node, factor_tensor],
         )
-        softmax = dc.op("softmax", [mult_1], (axis, 1))
+        softmax = dc.op_with_named_attrs("softmax", [mult_1], {"dimension": axis}, (axis, 1))
         mult_2 = dc.op("multiply", [softmax, range_tensor])
         reduce_sum = dc.op_with_named_attrs(
             "reduce_sum", (mult_2,), {"dim_arg": [axis], "keep_dim": True}, (axis, True)
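Context for the hunks above: argmax is decomposed as a differentiable soft-argmax, i.e. scale the input by a large factor, softmax along the axis (with its "dimension" now named), weight by an index tensor, and reduce_sum. A self-contained torch sketch of that trick; the 1e4 factor here is an illustrative stand-in, not the value forge's factor_tensor uses:

```python
import torch

x = torch.tensor([[0.1, 2.0, 0.3, 1.5]])
axis, factor = -1, 1e4  # a large factor pushes softmax toward a one-hot

one_hot = torch.softmax(x * factor, dim=axis)         # ~1.0 at the max position
indices = torch.arange(x.shape[axis], dtype=x.dtype)  # plays the role of range_tensor
soft_argmax = (one_hot * indices).sum(dim=axis, keepdim=True)

assert round(soft_argmax.item()) == torch.argmax(x, dim=axis).item()
```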
4 changes: 2 additions & 2 deletions forge/forge/op/eval/forge/convolution.py
@@ -136,7 +136,7 @@ def decompose(self, dc, inputs):

         if bias is not None and len(bias.shape) < len(activations.shape):
             while len(bias.shape) < len(activations.shape):
-                bias = dc.op("unsqueeze", [bias], (0, len(bias.shape)))
+                bias = dc.op_with_named_attrs("unsqueeze", [bias], {"dim": 0}, (0, len(bias.shape)))

         is_bias_unchanged = bias is None or bias == inputs[2]

@@ -310,7 +310,7 @@ def decompose(self, dc, inputs):

         if bias is not None and len(bias.shape) < len(activations.shape):
             while len(bias.shape) < len(activations.shape):
-                bias = dc.op("unsqueeze", [bias], (0, len(bias.shape)))
+                bias = dc.op_with_named_attrs("unsqueeze", [bias], {"dim": 0}, (0, len(bias.shape)))

         is_bias_unchanged = bias is None or bias == inputs[2]

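Both conv decompositions above left-extend the bias rank until it matches the activations so the later bias add broadcasts; {"dim": 0} names what the positional (0, len(bias.shape)) encoded. A plain-torch sketch of the same loop, with illustrative assumed shapes:

```python
import torch

activations = torch.randn(1, 8, 4, 4)  # assumed NCHW activations
bias = torch.randn(8, 4, 4)            # bias with fewer dims than activations

# Same idea as the decompose loop: unsqueeze at dim 0 until ranks match.
while bias.dim() < activations.dim():
    bias = bias.unsqueeze(0)  # named-attr equivalent: {"dim": 0}

assert bias.shape == (1, 8, 4, 4)
out = activations + bias  # now broadcasts cleanly
```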
24 changes: 16 additions & 8 deletions forge/forge/op/eval/forge/eltwise_binary.py
@@ -398,11 +398,11 @@ def decompose(op_type, attr, dc, inputs):
     ops0_dims = len(inputs[0].shape)
     ops1_dims = len(inputs[1].shape)
     if ops0_dims > ops1_dims and ops0_dims == 5:
-        ops1 = dc.op("reshape", [inputs[1]], list(inputs[0].shape))
+        ops1 = dc.op_with_named_attrs("reshape", [inputs[1]], {"shape": list(inputs[0].shape)}, list(inputs[0].shape))
         result = dc.op(op_type, [inputs[0], ops1])
         dc.fuse(result)
     elif ops1_dims > ops0_dims and ops1_dims == 5:
-        ops0 = dc.op("reshape", [inputs[0]], list(inputs[1].shape))
+        ops0 = dc.op_with_named_attrs("reshape", [inputs[0]], {"shape": list(inputs[1].shape)}, list(inputs[1].shape))
         result = dc.op(op_type, [ops0, inputs[1]])
         dc.fuse(result)

@@ -449,25 +449,33 @@ def decompose_post_autograd(op_type, attr, dc, inputs):

         if slice_factor != None:
             concat_z = dc.op("interleave", [operand0, operand1], (-3, 1))
-            result = dc.op("reduce_max", [concat_z], (-3, 2))
+            result = dc.op_with_named_attrs(
+                "reduce_max", [concat_z], {"dim_arg": [-3], "keep_dim": True}, (-3, 2, True)
+            )
         else:
-            concat_z = dc.op("concatenate", [operand0, operand1], (-3,))
-            result = dc.op("reduce_max", [concat_z], (-3,))
+            concat_z = dc.op_with_named_attrs("concatenate", [operand0, operand1], {"dim": -3}, (-3,))
+            result = dc.op_with_named_attrs(
+                "reduce_max", [concat_z], {"dim_arg": [-3], "keep_dim": True}, (-3, concat_z.shape[-3], True)
+            )

         while len(result.shape) > max_operand_nd:
-            result = dc.op("squeeze", [result], (0,))
+            result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,))

         dc.fuse(result)
         return
     else:
         ops0_dims = len(inputs[0].shape)
         ops1_dims = len(inputs[1].shape)
         if ops0_dims > ops1_dims and ops0_dims == 5:
-            ops1 = dc.op("reshape", [inputs[1]], list(inputs[0].shape))
+            ops1 = dc.op_with_named_attrs(
+                "reshape", [inputs[1]], {"shape": list(inputs[0].shape)}, list(inputs[0].shape)
+            )
             result = dc.op(op_type, [inputs[0], ops1])
             dc.fuse(result)
         elif ops1_dims > ops0_dims and ops1_dims == 5:
-            ops0 = dc.op("reshape", [inputs[0]], list(inputs[1].shape))
+            ops0 = dc.op_with_named_attrs(
+                "reshape", [inputs[0]], {"shape": list(inputs[1].shape)}, list(inputs[1].shape)
+            )
             result = dc.op(op_type, [ops0, inputs[1]])
             dc.fuse(result)

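One behavioral detail in the hunk above: the old reduce_max calls passed only the dim (e.g. (-3,) or (-3, 2)), while the replacements request {"dim_arg": [-3], "keep_dim": True} and extend the positional tuple with the reduced dim's extent and the flag; the squeeze loop that follows still trims leading dims. The keep_dim shape semantics being assumed match PyTorch's:

```python
import torch

t = torch.randn(2, 3, 4)

# keep_dim=True keeps the reduced axis as size 1, as the new
# {"dim_arg": [-3], "keep_dim": True} calls request.
kept = torch.max(t, dim=-3, keepdim=True).values
assert kept.shape == (1, 3, 4)

# Without keepdim the axis disappears entirely.
dropped = torch.max(t, dim=-3, keepdim=False).values
assert dropped.shape == (3, 4)
```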
18 changes: 15 additions & 3 deletions forge/forge/op/eval/forge/eltwise_unary.py
@@ -521,7 +521,12 @@ def decompose(type, attr, dc, inputs):
         if axis is None:
             import math

-            inp_node = dc.op("reshape", [inp_node], (1, math.prod(inp_node.shape.as_list())))
+            inp_node = dc.op_with_named_attrs(
+                "reshape",
+                [inp_node],
+                {"shape": (1, math.prod(inp_node.shape.as_list()))},
+                (1, math.prod(inp_node.shape.as_list())),
+            )
             axis = -1

         input_shape = inp_node.shape.as_list()
@@ -559,7 +564,9 @@ def decompose(type, attr, dc, inputs):
             "multiply",
             (inp_node, factor_tensor),
         )
-        max_1 = dc.op("reduce_max", [scaled_input], [axis])
+        max_1 = dc.op_with_named_attrs(
+            "reduce_max", [scaled_input], {"dim_arg": [axis], "keep_dim": True}, [axis, scaled_input.shape[axis], True]
+        )
         scaled_input = dc.op("subtract", (scaled_input, max_1))
         scaled_input = dc.op(
             "add",
@@ -582,7 +589,12 @@ def decompose(type, attr, dc, inputs):
             [mul_1, mul_2],
         )
         negative_add_1 = dc.op("multiply", [add_1, negative_ones])
-        negative_argmax = dc.op("reduce_max", [negative_add_1], [axis])
+        negative_argmax = dc.op_with_named_attrs(
+            "reduce_max",
+            [negative_add_1],
+            {"dim_arg": [axis], "keep_dim": True},
+            [axis, negative_add_1.shape[axis], True],
+        )

         output_neg_ones = torch.ones((negative_argmax.shape.as_list()), dtype=data_type) * (-1)
         output_neg_ones_tensor = dc.tensor(output_neg_ones)
2 changes: 1 addition & 1 deletion forge/forge/op/eval/forge/nn.py
@@ -394,7 +394,7 @@ def decompose(op_type, attr, dc, inputs):
         x = inputs[0]
         dim = attr[0]
         stable = attr[1]
-        result = dc.op("softmax", (x,), (dim, stable))
+        result = dc.op_with_named_attrs("softmax", (x,), {"dimension": dim}, (dim, stable))
         result = dc.op(Log.create(), (result,))
         dc.fuse(result)
         return
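The nn.py hunk is a log-softmax-style decomposition: softmax (with its "dimension" now named) followed by Log. The identity being relied on, checked in plain torch:

```python
import torch

x = torch.randn(4, 10)
dim = -1

# log(softmax(x)) == log_softmax(x), up to floating-point error.
manual = torch.log(torch.softmax(x, dim=dim))
assert torch.allclose(manual, torch.log_softmax(x, dim=dim), atol=1e-6)
```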
38 changes: 25 additions & 13 deletions forge/forge/op/eval/forge/pooling.py
@@ -467,7 +467,9 @@ def decompose(type, attr, dc, inputs):

         activations = inputs[0]
         if kernel_size == activations.shape[-1]:
-            reduce_avg = dc.op_with_named_attrs("reduce_avg", [activations], {"dim": -1, "keep_dim": True}, (-1, True))
+            reduce_avg = dc.op_with_named_attrs(
+                "reduce_avg", [activations], {"dim_arg": -1, "keep_dim": True}, (-1, True)
+            )
             dc.fuse(reduce_avg)
             return
         else:
@@ -527,14 +529,14 @@ def decompose(type, attr, dc, inputs):
             result = dc.op_with_named_attrs(
                 "reshape", [activations], {"shape": (w, 1, y * x, cin)}, (w, 1, y * x, cin)
             )
-            result = dc.op_with_named_attrs("reduce_avg", [result], {"dim": -2, "keep_dim": True}, (-2, True))
+            result = dc.op_with_named_attrs("reduce_avg", [result], {"dim_arg": -2, "keep_dim": True}, (-2, True))
             result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, 1, 1, cin)}, (w, 1, 1, cin))
         else:
             result = dc.op_with_named_attrs(
                 "reshape", [activations], {"shape": (w, 1, cin, y * x)}, (w, 1, cin, y * x)
             )
             result = dc.op(TransposeTM.create(2, 3), [result])
-            result = dc.op_with_named_attrs("reduce_avg", [result], {"dim": -2, "keep_dim": True}, (-2, True))
+            result = dc.op_with_named_attrs("reduce_avg", [result], {"dim_arg": -2, "keep_dim": True}, (-2, True))
             result = dc.op(TransposeTM.create(2, 3), [result])
             result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, cin, 1, 1)}, (w, cin, 1, 1))
         dc.fuse(result)
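The two branches above look like the global-average-pooling fast path: flatten the spatial extent into one axis and average it, with {"dim": ...} renamed to {"dim_arg": ...} to match the reduce ops' attribute name. The shape equivalence being exploited, in plain torch for the channel-first branch:

```python
import torch

w, cin, y, x = 2, 8, 4, 4
act = torch.randn(w, cin, y, x)  # channel-first activations (assumed layout)

# Mirror the reshape -> transpose -> reduce_avg -> transpose -> reshape chain.
flat = act.reshape(w, 1, cin, y * x)
pooled = flat.transpose(2, 3).mean(dim=-2, keepdim=True).transpose(2, 3)
pooled = pooled.reshape(w, cin, 1, 1)

assert torch.allclose(pooled, act.mean(dim=(-2, -1), keepdim=True))
```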
@@ -776,9 +778,9 @@ def decompose(type, attr, dc, inputs):
         if channel_last:
             # result = dc.op("vstack", [activations], (y,))
             _, yout, xout, _ = shape("max_pool2d", attr, [activations.shape])[0]
-            result = dc.op("reshape", [activations], (w, 1, y * x, cin))
+            result = dc.op_with_named_attrs("reshape", [activations], {"shape": (w, 1, y * x, cin)}, (w, 1, y * x, cin))
         else:
-            result = dc.op("reshape", [activations], (w, 1, cin, y * x))
+            result = dc.op_with_named_attrs("reshape", [activations], {"shape": (w, 1, cin, y * x)}, (w, 1, cin, y * x))
             result = dc.op(TransposeTM.create(2, 3), [result])
             _, _, yout, xout = shape("max_pool2d", attr, [activations.shape])[0]
         result = dc.op("pad_tile", [result], (3, cin))
@@ -812,10 +814,12 @@ def decompose(type, attr, dc, inputs):
             pad_shape = result.shape.as_list()
             pad_shape[-1] = (result_c_padding[result_c] - result_c) * TILE_DIM
             zeros_tensor = dc.tensor(torch.zeros(pad_shape))
-            result = dc.op("concatenate", [result, zeros_tensor], (-1,))
+            result = dc.op_with_named_attrs("concatenate", [result, zeros_tensor], {"dim": -1}, (-1,))

         result = dc.op("sparse_matmul", [picker_tensor, result])
-        result = dc.op("reduce_max", [result], (1,))  # z dim
+        result = dc.op_with_named_attrs(
+            "reduce_max", [result], {"dim_arg": [1], "keep_dim": True}, (1, result.shape[1], True)
+        )  # z dim

         if pad_for_factorization:
             if sparse_r in sparse_r_padding:
@@ -835,10 +839,10 @@ def decompose(type, attr, dc, inputs):
         result = dc.op("narrow", [result], (2, 0, yout * xout, result.shape[2]))
         result = dc.op("narrow", [result], (3, 0, cin, result.shape[3]))
         if channel_last:
-            result = dc.op("reshape", [result], (w, yout, xout, cin))
+            result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, yout, xout, cin)}, (w, yout, xout, cin))
         else:
             result = dc.op(TransposeTM.create(2, 3), [result])
-            result = dc.op("reshape", [result], (w, cin, yout, xout))
+            result = dc.op_with_named_attrs("reshape", [result], {"shape": (w, cin, yout, xout)}, (w, cin, yout, xout))

         if max_pool_add_sub_surround:
             add_sub_val = dc.tensor(torch.zeros((1,)) + max_pool_add_sub_surround_value)
@@ -907,7 +911,9 @@ def decompose(type, attr, dc, inputs):
         #     _, yout, xout, _ = shape('max_pool2d', attr, [activations.shape])[0]
         #     result = dc.op("reshape", [activations], (w, 1, y * x, cin))
         # else:
-        result = dc.op("reshape", [activations], (w, 1, cin * din, y * x))
+        result = dc.op_with_named_attrs(
+            "reshape", [activations], {"shape": (w, 1, cin * din, y * x)}, (w, 1, cin * din, y * x)
+        )
         result = dc.op(TransposeTM.create(-2, -1), [result])
         _, cout, dout, yout, xout = shape("max_pool3d", attr, [activations.shape])[0]
         result = dc.op("pad_tile", [result], (-1, cin * din))
@@ -958,7 +964,9 @@ def create_conv2d_sparse_matrix(
         )
         picker_tensor = dc.tensor(picker.unsqueeze(0))  # (1, kH*kW, yout*xout, yin*xin)
         result = dc.op("sparse_matmul", [picker_tensor, result])  # (1, kH*kW, yout*xout, cin*din)
-        result = dc.op("reduce_max", [result], (-3,))  # z dim # (1, 1, yout*xout, cin*din)
+        result = dc.op_with_named_attrs(
+            "reduce_max", [result], {"dim_arg": [-3], "keep_dim": True}, (-3, result.shape[-3], True)
+        )  # z dim # (1, 1, yout*xout, cin*din)

         # Run max pool on the depth dimension in a separate step
         if kD > 1:
@@ -976,7 +984,9 @@ def create_conv2d_sparse_matrix(
             # Transpose the activations to allow sparse mm to work on the depth dim
             result = dc.op(TransposeTM.create(-2, -1), [result])
             result = dc.op("sparse_matmul", [depth_picker, result])  # (1, kD, cout*dout, yout*xout)
-            result = dc.op("reduce_max", [result], (-3,))  # z dim # (1, 1, cout*dout, yout*xout)
+            result = dc.op_with_named_attrs(
+                "reduce_max", [result], {"dim_arg": [-3], "keep_dim": True}, (-3, result.shape[-3], True)
+            )  # z dim # (1, 1, cout*dout, yout*xout)

             # Transpose back
             result = dc.op(TransposeTM.create(-2, -1), [result])
@@ -989,7 +999,9 @@ def create_conv2d_sparse_matrix(
         # if channel_last:
         #     result = dc.op("reshape", [result], (w, yout, xout, cin))
         # else:
-        result = dc.op("reshape", [result], (w, cin, dout, yout, xout))
+        result = dc.op_with_named_attrs(
+            "reshape", [result], {"shape": (w, cin, dout, yout, xout)}, (w, cin, dout, yout, xout)
+        )

         # if max_pool_add_sub_surround:
         #     add_sub_val = dc.tensor(torch.zeros((1,)) + max_pool_add_sub_surround_value)
26 changes: 17 additions & 9 deletions forge/forge/op/eval/forge/quantize.py
@@ -202,14 +202,15 @@ def decompose(type, attr, dc, inputs):
         if len(inp_scale_shape) == 1:
             # Match ndim with activation
             for i in range(0, left_ndim):
-                inp_scale = dc.op(
-                    "unsqueeze", [inp_scale], attrs=(0, len(inp_scale_shape)), output_df=inp_scale.output_df
+                inp_scale = dc.op_with_named_attrs(
+                    "unsqueeze", [inp_scale], {"dim": 0}, attrs=(0, len(inp_scale_shape)), output_df=inp_scale.output_df
                 )
                 inp_scale_shape = [1] + inp_scale_shape
             for i in range(0, right_ndim):
-                inp_scale = dc.op(
+                inp_scale = dc.op_with_named_attrs(
                     "unsqueeze",
                     [inp_scale],
+                    {"dim": (len(inp_scale_shape))},
                     attrs=(len(inp_scale_shape), len(inp_scale_shape)),
                     output_df=inp_scale.output_df,
                 )
@@ -219,14 +220,15 @@ def decompose(type, attr, dc, inputs):
         if len(out_scale_shape) == 1:
             # Match ndim with activation
             for i in range(0, left_ndim):
-                out_scale = dc.op(
-                    "unsqueeze", [out_scale], attrs=(0, len(out_scale_shape)), output_df=out_scale.output_df
+                out_scale = dc.op_with_named_attrs(
+                    "unsqueeze", [out_scale], {"dim": 0}, attrs=(0, len(out_scale_shape)), output_df=out_scale.output_df
                 )
                 out_scale_shape = [1] + out_scale_shape
             for i in range(0, right_ndim):
-                out_scale = dc.op(
+                out_scale = dc.op_with_named_attrs(
                     "unsqueeze",
                     [out_scale],
+                    {"dim": len(out_scale_shape)},
                     attrs=(len(out_scale_shape), len(out_scale_shape)),
                     output_df=out_scale.output_df,
                 )
@@ -274,11 +276,17 @@ def decompose(type, attr, dc, inputs):
         if len(scale_shape) == 1:
             # Match ndim with activation
             for i in range(0, left_ndim):
-                scale = dc.op("unsqueeze", [scale], attrs=(0, len(scale_shape)), output_df=scale.output_df)
+                scale = dc.op_with_named_attrs(
+                    "unsqueeze", [scale], {"dim": 0}, attrs=(0, len(scale_shape)), output_df=scale.output_df
+                )
                 scale_shape = [1] + scale_shape
             for i in range(0, right_ndim):
-                scale = dc.op(
-                    "unsqueeze", [scale], attrs=(len(scale_shape), len(scale_shape)), output_df=scale.output_df
+                scale = dc.op_with_named_attrs(
+                    "unsqueeze",
+                    [scale],
+                    {"dim": (len(scale_shape))},
+                    attrs=(len(scale_shape), len(scale_shape)),
+                    output_df=scale.output_df,
+                )
                 scale_shape = scale_shape + [1]

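The quantize hunks all perform the same rank alignment for a 1-D scale: unsqueeze at the front left_ndim times and at the back right_ndim times so a per-channel scale broadcasts against the activation. A torch sketch of that alignment; left_ndim and right_ndim are derived here from an assumed channel axis, which is an illustration rather than forge's actual derivation:

```python
import torch

act = torch.randn(1, 8, 4, 4)  # activation
scale = torch.randn(8)         # per-channel scale, 1-D
axis = 1                       # assumed channel axis

left_ndim = axis                   # dims to add in front
right_ndim = act.dim() - axis - 1  # dims to add behind

for _ in range(left_ndim):
    scale = scale.unsqueeze(0)            # named form: {"dim": 0}
for _ in range(right_ndim):
    scale = scale.unsqueeze(scale.dim())  # named form: {"dim": len(shape)}

assert scale.shape == (1, 8, 1, 1)
quantized = act / scale  # broadcasts per channel
```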
2 changes: 1 addition & 1 deletion forge/forge/op/eval/forge/reduce.py
@@ -378,7 +378,7 @@ def decompose(type, attr, dc, inputs):
     if isinstance(attr[0], list):
         x = inputs[0]
         for dim in attr[0]:
-            x = dc.op_with_named_attrs("reduce_avg", [x], (dim,))
+            x = dc.op_with_named_attrs("reduce_avg", [x], {"dim_arg": dim, "keep_dim": True}, (dim, True))
         dc.fuse(x)
         return

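The reduce.py change also makes the multi-dim decomposition explicit: each dim in the list is reduced one at a time with keep_dim=True, which composes to the same result as a single mean over all dims. A quick torch check of that equivalence, assuming reduce_avg matches torch.mean semantics:

```python
import torch

x = torch.randn(2, 3, 4, 5)
dims = [1, 2]

# Iterative per-dim mean with keepdim, as the decompose loop now does.
y = x
for d in dims:
    y = y.mean(dim=d, keepdim=True)

# Equivalent single reduction over both dims.
assert torch.allclose(y, x.mean(dim=dims, keepdim=True))
```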