Skip to content

Commit

Permalink
Remove Repeat op decomposition
Browse files Browse the repository at this point in the history
This op with directly tracked to ttnn.repeat / ttnn.repeat_interleave
  • Loading branch information
ashokkumarkannan1 committed Nov 22, 2024
1 parent de355f9 commit c6aa189
Showing 1 changed file with 0 additions and 71 deletions.
71 changes: 0 additions & 71 deletions python/tvm/relay/op/contrib/forge/forge_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2123,76 +2123,6 @@ def callback(self, pre, post, node_map):

return tvm.relay.reshape(act, newshape=target_shape)

class DecomposeRepeat(DFPatternCallback):
def __init__(self):
super().__init__(rewrite_once=True, require_type=True)
self.pattern = is_op("repeat")(wildcard())

def callback(self, pre, post, node_map):
repeat_axis = int(post.attrs.axis)
num_repeats = int(post.attrs.repeats)
input_shape = list(pre.args[0].checked_type.shape)
if input_shape[repeat_axis] == 1:
output_shape = input_shape
output_shape[repeat_axis] *= num_repeats
result = tvm.relay.broadcast_to(post.args[0], output_shape)
else:
if repeat_axis < 0:
repeat_axis = len(input_shape) + repeat_axis

# Step 1: If the repeat axis is not last dimension, transpose the act
# to make repeat axis as the last dimension
# Eg:
# act_shape = (1, 1, 3, 3)
# num_repeats = 2
# repeat_axis = 2
# eg: (N, C, H, W) -> (N, C, W, H)
transpose_1 = post.args[0]
transpose_1_output_shape = input_shape
if int(len(input_shape) - 1) != int(repeat_axis):
for t_axes in range(int(repeat_axis), int(len(input_shape) - 1)):
transpose_1_axes = list(range(len(input_shape)))
transpose_1_axes[t_axes], transpose_1_axes[t_axes + 1] = transpose_1_axes[t_axes + 1], transpose_1_axes[t_axes]
transpose_1 = tvm.relay.transpose(transpose_1, axes=transpose_1_axes)
transpose_1_output_shape = [transpose_1_output_shape[i_axes] for i_axes in transpose_1_axes]

# Step 2: Reshape the act to 2D for matrix multiplication
# eg: (N, C, W, H) -> (N * C * W, H)
reshape_1_new_shape = [np.prod(transpose_1_output_shape[:-1]), transpose_1_output_shape[-1]]
reshape_1 = tvm.relay.reshape(transpose_1, newshape=reshape_1_new_shape)


# Step 3: Create a repetition matrix of shape (input_shape[repeat_axis], input_shape[repeat_axis] * num_repeats)
# eg: (H, H * num_repeats)
repeat_matrix = np.zeros((int(input_shape[repeat_axis]), (int(input_shape[repeat_axis]) * num_repeats)))
for i in range(int(input_shape[repeat_axis])):
for j in range(num_repeats):
repeat_matrix[i, i * num_repeats + j] = 1.0
repeat_matrix_constant = tvm.relay.Constant(tvm.nd.array(repeat_matrix.astype(np.float32)))

# Step 4: Perform matrix multiplication (reshape_1 x repeat_matrix_constant)
# eg: (N * C * W, H) x (H, H * num_repeats) -> (N * C * W, H * num_repeats)
matmul_1 = tvm.relay.nn.matmul(reshape_1, repeat_matrix_constant)


# Step 5: Reshape back to original dimensions with repeated dimension
# eg: (N * C * W, H * repeats) -> (N, C, W, H * repeats)
final_reshape_new_shape = list(transpose_1_output_shape)
final_reshape_new_shape[-1] = final_reshape_new_shape[-1] * num_repeats
reshape_2 = tvm.relay.reshape(matmul_1, newshape=final_reshape_new_shape)

# Step 6: If the repeat axis is not last dimension, transpose back to original axes order
# eg: (N, C, W, H * repeats) => (N, C, H * repeats, W)
result = reshape_2
if int(len(input_shape) - 1) != int(repeat_axis):
for t_axes in range(int(len(input_shape) - 1), int(repeat_axis), -1):
reverse_transpose_axes = list(range(len(input_shape)))
reverse_transpose_axes[t_axes], reverse_transpose_axes[t_axes - 1] = reverse_transpose_axes[t_axes - 1], reverse_transpose_axes[t_axes]
result = tvm.relay.transpose(result, axes=reverse_transpose_axes)

return result


class ConvertGlobalAvgPool2dtoAvgPool2d(DFPatternCallback):
def __init__(self):
super().__init__(rewrite_once=True, require_type=True)
Expand Down Expand Up @@ -4121,7 +4051,6 @@ def run_forge_compile_passes(relay_module, params=None, inputs=None, target=None
LowerTakeToStridedSlice(),
ConvertAddToBiasAddAfterConv2d(),
DecomposeBatchFlatten(),
DecomposeRepeat(),
ConvertGlobalAvgPool2dtoAvgPool2d(),
ConvertUpsampleToResize2d(),
DecomposeMultiIndexAdvIndex(),
Expand Down

0 comments on commit c6aa189

Please sign in to comment.