diff --git a/python/tvm/contrib/pybuda_compile.py b/python/tvm/contrib/pybuda_compile.py index 29991a2691..c0b72047d4 100644 --- a/python/tvm/contrib/pybuda_compile.py +++ b/python/tvm/contrib/pybuda_compile.py @@ -459,6 +459,67 @@ def trim_count(name): return json_graph + +def duplicate_dequantize_nodes_in_onnx_graph(onnx_module): + from collections import defaultdict + + # Create a dictionary to store the consumers of each tensor + tensor_consumers = defaultdict(list) + + graph = onnx_module.graph + # Populate the tensor_consumers dictionary + for node in graph.node: + for input_name in node.input: + tensor_consumers[input_name].append(node.name) + + # Find and duplicate nodes with output branches + nodes_to_add = [] + nodes_to_remove = [] + indices_for_adding = [] + for node_ind, node in enumerate(graph.node): + + if node.op_type != "DequantizeLinear": + continue + + output_name = node.output[0] + consumers = tensor_consumers[output_name] + + if len(consumers) > 1: + # Duplicate the node for each consumer + for i, consumer_name in enumerate(consumers): + new_node_name = node.name + f"_clone{i}" + new_output_name = output_name + f"_clone{i}" + + cloned_node = onnx.helper.make_node( + node.op_type, + node.input, + [new_output_name], + name=new_node_name + ) + + # Add the cloned node to the list of nodes to add + nodes_to_add.append(cloned_node) + indices_for_adding.append((cloned_node, node_ind)) + + # Update the consumer to use the cloned node's output + consumer_node = next(n for n in graph.node if n.name == consumer_name) + for j, input_name in enumerate(consumer_node.input): + if input_name == output_name: + consumer_node.input[j] = new_output_name + + # Remove the original node since it will be replaced by its clones + nodes_to_remove.append(node) + + + # This is needed to maintain the order of the nodes in graph + # since graph is not put in topsort order when visiting nodes + for i, (node, insertion_index) in enumerate(indices_for_adding): +
graph.node.insert(insertion_index + i, node) + + for node in nodes_to_remove: + graph.node.remove(node) + + def compile_onnx_for_buda(onnx_mod, path, *inputs, graph_name, compiler_cfg, verify_cfg=None): import onnxruntime as ort @@ -483,6 +544,7 @@ def compile_onnx_for_buda(onnx_mod, path, *inputs, graph_name, compiler_cfg, ver assert len(input_names) == len(inputs), "Number of input names must match number of inputs" + duplicate_dequantize_nodes_in_onnx_graph(onnx_mod) framework_outputs = extract_framework_model_outputs( framework="onnx", model=onnx_mod, diff --git a/python/tvm/relay/op/contrib/buda/buda_passes.py b/python/tvm/relay/op/contrib/buda/buda_passes.py index 1600ab0c1c..22c6d695b4 100644 --- a/python/tvm/relay/op/contrib/buda/buda_passes.py +++ b/python/tvm/relay/op/contrib/buda/buda_passes.py @@ -2232,19 +2232,6 @@ def callback(self, pre, post, node_map): return tvm.relay.gelu(node_map[self.act][0]) -class RemoveQuantDequantSequence(DFPatternCallback): - def __init__(self): - super().__init__(rewrite_once=True, require_type=True) - self.act = wildcard() - self.quant = is_op("qnn.quantize")(self.act, wildcard(), wildcard(),) - self.pattern = is_op("qnn.dequantize")(self.quant, wildcard(), wildcard(),) - - def callback(self, pre, post, node_map): - act = node_map[self.act][0] - quant = node_map[self.quant][0] - return node_map[self.act][0] - - class ReconstructOnnxQuantizedGelu(DFPatternCallback): def __init__(self): super().__init__(rewrite_once=True, require_type=True) @@ -3364,7 +3351,7 @@ def callback(self, pre, post, node_map): if index != val: count += 1 - assert (count == 2, "Multi-axis transpose should be decomposed into single-axis transpose at this point") + assert count == 2, "Multi-axis transpose should be decomposed into single-axis transpose at this point" is_transpose_yz = len(transpose_axes) >= 3 and len(dims) >= 3 and (transpose_axes[-2] == dims[-3] and transpose_axes[-3] == dims[-2]) if ( @@ -3805,7 +3792,7 @@ def 
_get_callback_name(callback): elif isinstance(callback, tvm.transform.Pass): return callback.info.name else: - raise NotImplementedError(f"Type of callback ({type(callback)}) not implemented") + raise NotImplementedError(f"Type of callback ({type(callback)}) not implemented") def _run_pattern_callback(relay_module, callback, callback_name): @@ -3902,7 +3889,6 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None, ConvertGlobalAvgPool2dtoAvgPool2d(), ConvertUpsampleToResize2d(), DecomposeMultiIndexAdvIndex(), - RemoveQuantDequantSequence(), ReconstructOnnxQuantizedGelu(), DecomposeQnnConcat(), # DecomposeErf(), diff --git a/python/tvm/relay/op/contrib/buda/relay_passes.py b/python/tvm/relay/op/contrib/buda/relay_passes.py index a876ea9e88..ca6e068e4d 100644 --- a/python/tvm/relay/op/contrib/buda/relay_passes.py +++ b/python/tvm/relay/op/contrib/buda/relay_passes.py @@ -22,7 +22,7 @@ def fskip_eliminate(expr): - if isinstance(expr, relay.expr.Call) and expr.op.name == "transpose": + if isinstance(expr, relay.expr.Call) and (expr.op.name == "transpose" or expr.op.name == "qnn.dequantize" or expr.op.name == "qnn.quantize"): return True return False