From ffc8e5ab0376da049a55a5f944f180156b122b03 Mon Sep 17 00:00:00 2001 From: pchandrasekaran Date: Thu, 18 Apr 2024 08:25:12 +0000 Subject: [PATCH] Update SimplifyReshape Pattern Callback --- python/tvm/relay/op/contrib/buda/buda_passes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/contrib/buda/buda_passes.py b/python/tvm/relay/op/contrib/buda/buda_passes.py index 3bde562e17..28bb6371e1 100644 --- a/python/tvm/relay/op/contrib/buda/buda_passes.py +++ b/python/tvm/relay/op/contrib/buda/buda_passes.py @@ -3010,9 +3010,9 @@ def callback(self, pre, post, node_map): reshape_1 = node_map[self.reshape_1][0] final_shape = list(reshape_1.attrs.newshape) - if input_shape[0] * input_shape[1] == final_shape[-1]: - reshape_0 = tvm.relay.reshape(node_map[self.reshape_0][0], newshape=[1, 1, input_shape[-2] * input_shape[-3], input_shape[-1]]) - final_transpose = tvm.relay.transpose(reshape_0, axes=[0,1,3,2]) + if input_shape == final_shape and len(input_shape) >= 3: + final_transpose_axes = np.arange(int(len(input_shape) - 2)).tolist() + np.flip(np.arange(int(len(input_shape) - 2), len(input_shape))).tolist() + final_transpose = tvm.relay.transpose(node_map[self.act][0], axes=final_transpose_axes) return final_transpose return post