
Decompose repeat_interleave pytorch op in buda passes
chandrasekaranpradeep committed Jul 22, 2024
1 parent 9c685f9 commit f81b478
Showing 1 changed file with 61 additions and 6 deletions.
python/tvm/relay/op/contrib/buda/buda_passes.py (+61 −6)
@@ -2055,14 +2055,69 @@ def __init__(self):
         self.pattern = is_op("repeat")(wildcard())
 
     def callback(self, pre, post, node_map):
-        axis = int(post.attrs.axis)
+        repeat_axis = int(post.attrs.axis)
         num_repeats = int(post.attrs.repeats)
         input_shape = list(pre.args[0].checked_type.shape)
-        assert input_shape[axis] == 1, "Cannot decompose repeat to broadcast when input dim != 1"
-        output_shape = input_shape
-        output_shape[axis] *= num_repeats
+        if input_shape[repeat_axis] == 1:
+            output_shape = input_shape
+            output_shape[repeat_axis] *= num_repeats
+            result = tvm.relay.broadcast_to(post.args[0], output_shape)
+        else:
+            if repeat_axis < 0:
+                repeat_axis = len(input_shape) + repeat_axis
+
+            input_axes = list(range(len(input_shape)))
+            transpose_axes = input_axes[:repeat_axis] + input_axes[repeat_axis + 1:] + [repeat_axis]
+            transpose_1_output_shape = [input_shape[t_axis] for t_axis in transpose_axes]
+
+            # Step 1: If the repeat axis is not the last dimension, transpose the
+            # activation so that the repeat axis becomes the last dimension.
+            # e.g. act_shape = (1, 1, 3, 3), num_repeats = 2, repeat_axis = 2:
+            #      (N, C, H, W) -> (N, C, W, H)
+            if len(input_shape) - 1 != repeat_axis:
+                transpose_1 = tvm.relay.transpose(post.args[0], axes=transpose_axes)
+            else:
+                transpose_1 = post.args[0]
+
+            # Step 2: Reshape the activation to 2D for matrix multiplication.
+            # e.g. (N, C, W, H) -> (N * C * W, H)
+            reshape_1_new_shape = [np.prod(transpose_1_output_shape[:-1]), transpose_1_output_shape[-1]]
+            reshape_1 = tvm.relay.reshape(transpose_1, newshape=reshape_1_new_shape)
+
+            # Step 3: Create a repetition matrix of shape
+            # (input_shape[repeat_axis], input_shape[repeat_axis] * num_repeats).
+            # e.g. (H, H * num_repeats)
+            repeat_matrix = np.zeros((int(input_shape[repeat_axis]), int(input_shape[repeat_axis]) * num_repeats))
+            for i in range(int(input_shape[repeat_axis])):
+                for j in range(num_repeats):
+                    repeat_matrix[i, i * num_repeats + j] = 1.0
+            repeat_matrix_constant = tvm.relay.Constant(tvm.nd.array(repeat_matrix.astype(np.float32)))
+
+            # Step 4: Perform matrix multiplication (reshape_1 x repeat_matrix_constant).
+            # e.g. (N * C * W, H) x (H, H * num_repeats) -> (N * C * W, H * num_repeats)
+            matmul_1 = tvm.relay.nn.matmul(reshape_1, repeat_matrix_constant)
+
+            # Step 5: Reshape back to the original dimensions, with the repeated axis last.
+            # e.g. (N * C * W, H * num_repeats) -> (N, C, W, H * num_repeats)
+            final_reshape_new_shape = list(transpose_1_output_shape)
+            final_reshape_new_shape[-1] = final_reshape_new_shape[-1] * num_repeats
+            reshape_2 = tvm.relay.reshape(matmul_1, newshape=final_reshape_new_shape)
+
+            # Step 6: If the repeat axis is not the last dimension, transpose back
+            # to the original axis order.
+            # e.g. (N, C, W, H * num_repeats) -> (N, C, H * num_repeats, W)
+            if len(input_shape) - 1 != repeat_axis:
+                reverse_transpose_axes = [transpose_axes.index(i) for i in range(len(transpose_axes))]
+                result = tvm.relay.transpose(reshape_2, axes=reverse_transpose_axes)
+            else:
+                result = reshape_2
 
-        result = tvm.relay.broadcast_to(post.args[0], output_shape)
         return result
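
For intuition, here is a minimal NumPy sketch of the same six-step decomposition (not part of the commit; decompose_repeat and the test tensor are illustrative assumptions), checked against np.repeat, whose semantics match torch.repeat_interleave along a single axis:

import numpy as np

def decompose_repeat(act, repeat_axis, num_repeats):
    # Illustrative helper mirroring the Relay pass above in plain NumPy.
    if repeat_axis < 0:
        repeat_axis += act.ndim
    # Step 1: move the repeat axis to the last position.
    axes = [a for a in range(act.ndim) if a != repeat_axis] + [repeat_axis]
    transposed = np.transpose(act, axes)
    # Step 2: flatten everything but the repeat axis into rows of a 2D matrix.
    flat = transposed.reshape(-1, transposed.shape[-1])
    # Step 3: repetition matrix of shape (H, H * num_repeats).
    h = act.shape[repeat_axis]
    repeat_matrix = np.zeros((h, h * num_repeats), dtype=act.dtype)
    for i in range(h):
        repeat_matrix[i, i * num_repeats:(i + 1) * num_repeats] = 1.0
    # Steps 4-5: matmul, then restore the leading dimensions.
    out = (flat @ repeat_matrix).reshape(transposed.shape[:-1] + (h * num_repeats,))
    # Step 6: invert the Step 1 transpose (argsort yields the inverse permutation).
    return np.transpose(out, np.argsort(axes))

act = np.arange(9, dtype=np.float32).reshape(1, 1, 3, 3)
assert np.array_equal(decompose_repeat(act, repeat_axis=2, num_repeats=2),
                      np.repeat(act, 2, axis=2))

For H = 3 and num_repeats = 2, repeat_matrix is the 3x6 matrix [[1,1,0,0,0,0], [0,0,1,1,0,0], [0,0,0,0,1,1]]: each input element is repeated in place, which is exactly repeat_interleave semantics rather than repeat/tile semantics.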


@@ -3862,7 +3917,6 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None,
         DecomposeConv1DToConv2D(),
         PopulateReduceAxes(),
         DecomposeMultiAxisMax(),
-        DecomposeMultiAxisTranspose(),
         EstimateWhereInCausalMask(),
         CastWhereConditionToBool(),
         LowerAdaptiveAvgPool(),
@@ -3886,6 +3940,7 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None,
         ConvertAddToBiasAddAfterConv2d(),
         DecomposeBatchFlatten(),
         DecomposeRepeat(),
+        DecomposeMultiAxisTranspose(),
         ConvertGlobalAvgPool2dtoAvgPool2d(),
         ConvertUpsampleToResize2d(),
         DecomposeMultiIndexAdvIndex(),
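
Ordering note (an inference from the diff, not stated in the commit message): DecomposeMultiAxisTranspose() now runs after DecomposeRepeat(), presumably so that the multi-axis transposes emitted by the new repeat decomposition (Steps 1 and 6 above) are themselves lowered by the later pass.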
