diff --git a/python/tvm/relay/op/contrib/forge/forge_passes.py b/python/tvm/relay/op/contrib/forge/forge_passes.py index a794ae57c0..45f811aa50 100644 --- a/python/tvm/relay/op/contrib/forge/forge_passes.py +++ b/python/tvm/relay/op/contrib/forge/forge_passes.py @@ -2123,76 +2123,6 @@ def callback(self, pre, post, node_map): return tvm.relay.reshape(act, newshape=target_shape) -class DecomposeRepeat(DFPatternCallback): - def __init__(self): - super().__init__(rewrite_once=True, require_type=True) - self.pattern = is_op("repeat")(wildcard()) - - def callback(self, pre, post, node_map): - repeat_axis = int(post.attrs.axis) - num_repeats = int(post.attrs.repeats) - input_shape = list(pre.args[0].checked_type.shape) - if input_shape[repeat_axis] == 1: - output_shape = input_shape - output_shape[repeat_axis] *= num_repeats - result = tvm.relay.broadcast_to(post.args[0], output_shape) - else: - if repeat_axis < 0: - repeat_axis = len(input_shape) + repeat_axis - - # Step 1: If the repeat axis is not last dimension, transpose the act - # to make repeat axis as the last dimension - # Eg: - # act_shape = (1, 1, 3, 3) - # num_repeats = 2 - # repeat_axis = 2 - # eg: (N, C, H, W) -> (N, C, W, H) - transpose_1 = post.args[0] - transpose_1_output_shape = input_shape - if int(len(input_shape) - 1) != int(repeat_axis): - for t_axes in range(int(repeat_axis), int(len(input_shape) - 1)): - transpose_1_axes = list(range(len(input_shape))) - transpose_1_axes[t_axes], transpose_1_axes[t_axes + 1] = transpose_1_axes[t_axes + 1], transpose_1_axes[t_axes] - transpose_1 = tvm.relay.transpose(transpose_1, axes=transpose_1_axes) - transpose_1_output_shape = [transpose_1_output_shape[i_axes] for i_axes in transpose_1_axes] - - # Step 2: Reshape the act to 2D for matrix multiplication - # eg: (N, C, W, H) -> (N * C * W, H) - reshape_1_new_shape = [np.prod(transpose_1_output_shape[:-1]), transpose_1_output_shape[-1]] - reshape_1 = tvm.relay.reshape(transpose_1, newshape=reshape_1_new_shape) - - - # Step 3: Create a repetition matrix of shape (input_shape[repeat_axis], input_shape[repeat_axis] * num_repeats) - # eg: (H, H * num_repeats) - repeat_matrix = np.zeros((int(input_shape[repeat_axis]), (int(input_shape[repeat_axis]) * num_repeats))) - for i in range(int(input_shape[repeat_axis])): - for j in range(num_repeats): - repeat_matrix[i, i * num_repeats + j] = 1.0 - repeat_matrix_constant = tvm.relay.Constant(tvm.nd.array(repeat_matrix.astype(np.float32))) - - # Step 4: Perform matrix multiplication (reshape_1 x repeat_matrix_constant) - # eg: (N * C * W, H) x (H, H * num_repeats) -> (N * C * W, H * num_repeats) - matmul_1 = tvm.relay.nn.matmul(reshape_1, repeat_matrix_constant) - - - # Step 5: Reshape back to original dimensions with repeated dimension - # eg: (N * C * W, H * repeats) -> (N, C, W, H * repeats) - final_reshape_new_shape = list(transpose_1_output_shape) - final_reshape_new_shape[-1] = final_reshape_new_shape[-1] * num_repeats - reshape_2 = tvm.relay.reshape(matmul_1, newshape=final_reshape_new_shape) - - # Step 6: If the repeat axis is not last dimension, transpose back to original axes order - # eg: (N, C, W, H * repeats) => (N, C, H * repeats, W) - result = reshape_2 - if int(len(input_shape) - 1) != int(repeat_axis): - for t_axes in range(int(len(input_shape) - 1), int(repeat_axis), -1): - reverse_transpose_axes = list(range(len(input_shape))) - reverse_transpose_axes[t_axes], reverse_transpose_axes[t_axes - 1] = reverse_transpose_axes[t_axes - 1], reverse_transpose_axes[t_axes] - result = tvm.relay.transpose(result, axes=reverse_transpose_axes) - - return result - - class ConvertGlobalAvgPool2dtoAvgPool2d(DFPatternCallback): def __init__(self): super().__init__(rewrite_once=True, require_type=True) @@ -4121,7 +4051,6 @@ def run_forge_compile_passes(relay_module, params=None, inputs=None, target=None LowerTakeToStridedSlice(), ConvertAddToBiasAddAfterConv2d(), DecomposeBatchFlatten(), - DecomposeRepeat(), ConvertGlobalAvgPool2dtoAvgPool2d(), ConvertUpsampleToResize2d(), DecomposeMultiIndexAdvIndex(),