diff --git a/python/tvm/relay/op/contrib/buda/buda_passes.py b/python/tvm/relay/op/contrib/buda/buda_passes.py
index 1600ab0c1c..28bb6371e1 100644
--- a/python/tvm/relay/op/contrib/buda/buda_passes.py
+++ b/python/tvm/relay/op/contrib/buda/buda_passes.py
@@ -3727,76 +3727,6 @@ def callback(self, pre, post, node_map):
 
         return unpadded_bmm
 
 
-class GQABroadcastReshape(DFPatternCallback):
-    """
-    Callback for Grouped Query Attention Pattern. When parsing a standard GQA,
-    A subpattern appears that is in the form:
-
-    (bs, n_kv_heads, seq_len, head_dim) ->[reshape0]-> (bs, n_kv_heads, 1, seq_len, head_dim) ->[bc0]->
-    (bs, n_kv_heads, 1, seq_len, head_dim) ->[bc1]-> (bs, n_kv_heads, n_kv_blocks, seq_len, head_dim) ->[reshape1]->
-    (n_query_heads, bs*seq_len, head_dim) ->[transpose]-> (n_query_heads, head_dim, bs*seq_len)
-
-    Where n_query_heads == n_kv_heads * n_kv_blocks. The problem with this subpattern is this broadcast that is
-    performed (bc1) which generates a 5D tensor with 4 dimensions that are not equal to 1.
-    That bc output is then input to reshape1 and pybuda compiler has no way to decompose a reshape that is performed
-    on such input tensor. That is why we change this pattern so that this doesn't occur.
-
-    Modification:
-
-    (bs, n_kv_heads, seq_len, head_dim) ->[transpose] -> (bs, seq_len, n_kv_heads, head_dim)
-    ->[reshape]-> (bs, n_kv_heads*seq_len, 1, head_dim) ->[bc]-> (bs, seq_len*n_kv_heads, brcst_val, head_dim)
-    ->[reshape]-> (bs*seqlen, n_kv_heads*brcst_val, head_dim) ->[transpose]-> (n_kv_heads*brcst_val, bs*seqlen, head_dim)
-    ->[transpose]-> (n_kv_heads*brcst_val, head_dim, bs*seqlen)
-
-    """
-    def __init__(self, require_type=False, rewrite_once=False):
-        super().__init__(require_type, rewrite_once)
-        self.act = wildcard()
-        self.reshape0 = is_op('reshape')(self.act)
-        self.bc0 = is_op('broadcast_to')(self.reshape0)
-        self.bc1 = is_op('broadcast_to')(self.bc0)
-        self.reshape1 = is_op('reshape')(self.bc1)
-        self.pattern = is_op('transpose')(self.reshape1)
-
-    def callback(self, pre, post, node_map):
-        act = node_map[self.act][0] # [bs, n_kv_heads, seq_len, head_dim]
-        orig_shape = act.checked_type.shape
-
-        # idea is to catch only reshapes [bs, n_kv_heads, seq_len, head_dim] -> [bs, n_kv_heads, 1, seq_len, head_dim]
-        if len(orig_shape) != 4:
-            return post
-
-        if len(node_map[self.reshape0][0].attrs.newshape) != 5:
-            return post
-
-        transpose0 = tvm.relay.transpose(act, axes=[0,2,1,3]) # [bs, seq_len, n_kv_heads, head_dim]
-        prev_shape = (orig_shape[-4], orig_shape[-2], orig_shape[-3], orig_shape[-1]) # a.k.a. transpose0 shape
-
-        new_shape = (prev_shape[-4], int(prev_shape[-3] * prev_shape[-2]), 1, prev_shape[-1])
-
-        reshape0 = tvm.relay.reshape(transpose0, newshape=new_shape) # (bs, seq_len*n_kv_heads, 1, head_dim)
-
-        if new_shape[-3] != prev_shape[-3] * prev_shape[-2]:
-            return post
-
-        bc1 = node_map[self.bc1][0]
-        pre_broadcast_shape = list(bc1.type_args[0].shape)
-        post_broadcast_shape = list(bc1.attrs.shape)
-
-        # get the value of dimension that is different after applying broadcast
-        broadcasted_value = [el for idx, el in enumerate(post_broadcast_shape) if el != pre_broadcast_shape[idx]][0]
-
-        new_broadcast_shape = [el for el in new_shape]
-        new_broadcast_shape[-2] = broadcasted_value
-
-        bc = tvm.relay.broadcast_to(reshape0, new_broadcast_shape) # (bs, seq_len*n_kv_heads, 1, head_dim) -> (bs, seq_len*n_kv_heads, brcst_val, head_dim)
-
-        new_shape = [prev_shape[-4]*prev_shape[-3], new_broadcast_shape[-3]*new_broadcast_shape[-2] // prev_shape[-3], prev_shape[-1]] # (bs*seqlen, n_kv_heads*brcst_val, head_dim)
-        reshape1 = tvm.relay.reshape(bc, new_shape)
-
-        transpose1 = tvm.relay.transpose(reshape1, axes=[1,0,2]) # (n_kv_heads*brcst_val, bs*seqlen, head_dim)
-        transpose2 = tvm.relay.transpose(transpose1, axes=[0,2,1]) # (n_kv_heads*brcst_val, head_dim, bs*seqlen)
-        return transpose2
 
 
 def _get_callback_name(callback):
@@ -3935,7 +3865,6 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None,
             # LowerSplitToStridedSlice(),
             PadSpecificBatchMatmulShapes(),
             SimplifyVITOnnxAttention(),
-            GQABroadcastReshape(),
         ],
         params=params,
        inputs=inputs,
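
For reference (not part of the patch): a minimal NumPy sketch, assuming bs == 1 and arbitrary example dimensions, that checks the rewritten sequence described in the removed GQABroadcastReshape docstring is value-equivalent to the original GQA subpattern it replaced.

import numpy as np

# Illustrative dimensions only; n_query_heads = n_kv_heads * n_kv_blocks
bs, n_kv_heads, n_kv_blocks, seq_len, head_dim = 1, 4, 2, 8, 16
act = np.random.rand(bs, n_kv_heads, seq_len, head_dim).astype(np.float32)

# Original subpattern: reshape0 -> bc0/bc1 -> reshape1 -> transpose
x = act.reshape(bs, n_kv_heads, 1, seq_len, head_dim)
x = np.broadcast_to(x, (bs, n_kv_heads, n_kv_blocks, seq_len, head_dim))
x = x.reshape(n_kv_heads * n_kv_blocks, bs * seq_len, head_dim)
original = x.transpose(0, 2, 1)  # (n_query_heads, head_dim, bs*seq_len)

# Rewritten subpattern produced by the removed callback:
# transpose -> reshape -> broadcast_to -> reshape -> transpose -> transpose
y = act.transpose(0, 2, 1, 3)  # (bs, seq_len, n_kv_heads, head_dim)
y = y.reshape(bs, seq_len * n_kv_heads, 1, head_dim)
y = np.broadcast_to(y, (bs, seq_len * n_kv_heads, n_kv_blocks, head_dim))
y = y.reshape(bs * seq_len, n_kv_heads * n_kv_blocks, head_dim)
y = y.transpose(1, 0, 2)        # (n_query_heads, bs*seq_len, head_dim)
rewritten = y.transpose(0, 2, 1)  # (n_query_heads, head_dim, bs*seq_len)

# Both graphs replicate each KV head across its block of query heads,
# so the resulting tensors must be identical.
assert np.array_equal(original, rewritten)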