def callback(self, pre, post, node_map):
    """Rewrite a matched ``repeat`` call into ops that do not use ``repeat``.

    Two strategies are used, depending on the size of the repeated axis:

    * If the input dimension along the repeat axis is 1, the repeat is
      expressed as a single ``broadcast_to``.
    * Otherwise the repeat is expressed as a matrix multiplication with a
      constant 0/1 "repetition matrix":

        1. transpose so the repeat axis becomes the last dimension,
        2. flatten all leading dimensions to get a 2D tensor,
        3. multiply by a ``(H, H * repeats)`` matrix whose row ``i`` holds
           ones in columns ``i * repeats .. i * repeats + repeats - 1``,
        4. reshape back to N-D with the last dimension scaled by
           ``repeats``,
        5. transpose back to the original axis order.

      Example for input ``(N, C, H, W)``, ``axis=2``:
      ``(N, C, H, W) -> (N, C, W, H) -> (N*C*W, H) -> (N*C*W, H*r)
      -> (N, C, W, H*r) -> (N, C, H*r, W)``.

    NOTE(review): the matmul path multiplies every input element by 0 or 1
    and sums. With IEEE arithmetic, ``inf``/``NaN`` inputs would therefore
    contaminate the other outputs of the same flattened row (``inf * 0`` is
    ``NaN``) -- confirm this is acceptable for the target backend.

    Parameters
    ----------
    pre : the matched call before rewriting; only the checked type of its
        first argument (shape and dtype) is read.
    post : the matched call after its arguments were rewritten; supplies the
        ``axis`` / ``repeats`` attributes and the data operand.
    node_map : unused.

    Returns
    -------
    The replacement Relay expression.
    """
    data = post.args[0]
    input_type = pre.args[0].checked_type
    input_shape = list(input_type.shape)
    rank = len(input_shape)
    num_repeats = int(post.attrs.repeats)

    # Normalize a negative axis once, so both strategies below see the same
    # non-negative index.
    repeat_axis = int(post.attrs.axis)
    if repeat_axis < 0:
        repeat_axis += rank

    # Cheap path: repeating a size-1 dimension is just a broadcast.
    if input_shape[repeat_axis] == 1:
        output_shape = list(input_shape)  # copy; do not mutate input_shape
        output_shape[repeat_axis] *= num_repeats
        return tvm.relay.broadcast_to(data, output_shape)

    repeat_dim = int(input_shape[repeat_axis])
    repeat_axis_is_last = repeat_axis == rank - 1

    # Permutation that moves the repeat axis to the end and keeps the
    # relative order of every other axis.
    transpose_axes = [ax for ax in range(rank) if ax != repeat_axis] + [repeat_axis]
    transposed_shape = [int(input_shape[ax]) for ax in transpose_axes]

    # Step 1: make the repeat axis the last dimension (no-op if it already is).
    # eg: (N, C, H, W) -> (N, C, W, H)
    if repeat_axis_is_last:
        moved = data
    else:
        moved = tvm.relay.transpose(data, axes=transpose_axes)

    # Step 2: flatten the leading dimensions so the data is 2D.
    # eg: (N, C, W, H) -> (N * C * W, H)
    # An explicit integer dtype keeps the product an int even when there are
    # no leading dimensions (1-D input), where np.prod would return 1.0.
    flattened_rows = int(np.prod(transposed_shape[:-1], dtype=np.int64))
    flattened = tvm.relay.reshape(moved, newshape=[flattened_rows, repeat_dim])

    # Step 3: build the (H, H * num_repeats) repetition matrix. Repeating the
    # columns of the identity puts ones at [i, i * num_repeats + j] for
    # j in [0, num_repeats), which is exactly the element-wise repeat pattern.
    repeat_matrix = np.repeat(np.eye(repeat_dim, dtype=np.float32), num_repeats, axis=1)
    repeat_matrix_constant = tvm.relay.Constant(tvm.nd.array(repeat_matrix))

    # Keep the matmul operands in the same dtype as the activation. The
    # constant is built as float32 (always representable in numpy) and cast
    # in Relay, so dtypes numpy has no name for are handled too.
    input_dtype = input_type.dtype
    if input_dtype != "float32":
        repeat_matrix_constant = tvm.relay.cast(repeat_matrix_constant, input_dtype)

    # Step 4: (N * C * W, H) x (H, H * num_repeats) -> (N * C * W, H * num_repeats)
    repeated = tvm.relay.nn.matmul(flattened, repeat_matrix_constant)

    # Step 5: restore the leading dimensions, with the last one scaled.
    # eg: (N * C * W, H * repeats) -> (N, C, W, H * repeats)
    expanded_shape = transposed_shape[:-1] + [repeat_dim * num_repeats]
    expanded = tvm.relay.reshape(repeated, newshape=expanded_shape)

    if repeat_axis_is_last:
        return expanded

    # Step 6: undo the transpose from step 1 by applying the inverse
    # permutation.
    # eg: (N, C, W, H * repeats) -> (N, C, H * repeats, W)
    inverse_axes = [0] * rank
    for position, axis in enumerate(transpose_axes):
        inverse_axes[axis] = position
    return tvm.relay.transpose(expanded, axes=inverse_axes)