Skip to content

Commit

Permalink
Merge branch 'pchandrasekaran/budapasses_fix' into 'main'
Browse files Browse the repository at this point in the history
Add validation for RemoveRedundantReshapeTransposeReshape Pattern Callback

See merge request tenstorrent/tvm!54
  • Loading branch information
chandrasekaranpradeep committed Mar 19, 2024
2 parents 99235d6 + 18695b9 commit dfda382
Showing 1 changed file with 38 additions and 3 deletions.
41 changes: 38 additions & 3 deletions python/tvm/relay/op/contrib/buda/buda_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2913,10 +2913,45 @@ def callback(self, pre, post, node_map):
reshape2 = node_map[self.reshape_2][0]
final_shape = reshape2.attrs.newshape

if list(final_shape) == list(input_shape)[-2:]:
return tvm.relay.reshape(node_map[self.act][0], newshape=final_shape)
reshape_1_new_shape = list(pre_node_map[self.reshape_1][0].checked_type.shape)
transpose_axis_list = list(pre_node_map[self.transpose][0].attrs.axes)

# Extract transpose dimensions(i.e dim0 and dim1) from transpose op axes parameter
# which is used for verification of data movement in tranpose op for reduction
# into single reshape op.
# Eg: If the input shape for tranpose op is (1,4,1,1,4,9) and tranpose axes is
# (0,2,1,3,4,5), then the tranpose dimemsions dim0 = 1 and dim1 = 2 are
# extracted by taking subtraction of transpose axes (0,2,1,3,4,5) from
# list (0,1,2,3,4,5) which is of length of original tranpose op input minus one
# and then tranpose dimensions are filtered out by taking non-zero index values
# from subraction array (0,-1,1,0,0,0).
dim0, dim1 = np.nonzero(np.subtract(np.arange(len(reshape_1_new_shape)), np.array(transpose_axis_list)))[0].tolist()
is_reshapeable = False

# If both the transpose dims values are 1, it will be reshapable
# Eg: newshape = (1,4,1,1,4,9), dim0 = 0, dim1 = 2
if reshape_1_new_shape[dim0] == 1 and reshape_1_new_shape[dim1] == 1:
is_reshapeable = True

# If the dim0 value is 1 and the dim0 is ahead or behind the dim1, it will be reshapeable
# Eg: newshape = (1,4,1,1,4,9), dim0 = 0, dim1 = 1
elif reshape_1_new_shape[dim0] == 1 and (dim0 - 1 == dim1 or dim0 + 1 == dim1):
is_reshapeable = True

# If the dim1 value is 1 and the dim1 value is ahead or behind the dim0, it will be reshapeable
# Eg: newshape = (1,4,1,1,4,9), dim0 = 2, dim1 = 1
elif reshape_1_new_shape[dim1] == 1 and (dim1 - 1 == dim0 or dim1 + 1 == dim0 ):
is_reshapeable = True

# If the dim0 or dim1 value is 1 and the intermediate tranpose dims values are 1, it will reshapeable
# Eg: newshape = (1,4,1,1,4,9), dim0 = 2, dim1 = 4
elif (reshape_1_new_shape[dim0] == 1 or reshape_1_new_shape[dim1] == 1) and np.all(np.array(reshape_1_new_shape[(min(dim0, dim1)+1):max(dim0, dim1)]) == 1):
is_reshapeable = True

else:
is_reshapeable = False

if list(final_shape) == [1] + list(input_shape):
if is_reshapeable and (list(final_shape) == list(input_shape)[-2:] or list(final_shape) == [1] + list(input_shape) or list(final_shape) == list(input_shape)):
return tvm.relay.reshape(node_map[self.act][0], newshape=final_shape)

return post
Expand Down

0 comments on commit dfda382

Please sign in to comment.