From a14c698efbfda311b8d387330fb73d723bfea5ab Mon Sep 17 00:00:00 2001 From: Lewis Panos Date: Thu, 1 Aug 2024 16:04:31 +0000 Subject: [PATCH 1/2] keep axis attribute in inserted onnx dequantize clone --- python/tvm/contrib/pybuda_compile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/pybuda_compile.py b/python/tvm/contrib/pybuda_compile.py index c0b72047d..a54718bed 100644 --- a/python/tvm/contrib/pybuda_compile.py +++ b/python/tvm/contrib/pybuda_compile.py @@ -489,12 +489,12 @@ def duplicate_dequantize_nodes_in_onnx_graph(onnx_module): for i, consumer_name in enumerate(consumers): new_node_name = node.name + f"_clone{i}" new_output_name = output_name + f"_clone{i}" - cloned_node = onnx.helper.make_node( node.op_type, node.input, [new_output_name], - name=new_node_name + name=new_node_name, + axis=node.attribute[0].i ) # Add the cloned node to the list of nodes to add From 78e10bb85bc170cac9297a0cfe68468793d27594 Mon Sep 17 00:00:00 2001 From: Lewis Panos Date: Thu, 1 Aug 2024 17:38:04 +0000 Subject: [PATCH 2/2] Conditionally add dequant axis --- python/tvm/contrib/pybuda_compile.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/pybuda_compile.py b/python/tvm/contrib/pybuda_compile.py index a54718bed..260867fa4 100644 --- a/python/tvm/contrib/pybuda_compile.py +++ b/python/tvm/contrib/pybuda_compile.py @@ -489,12 +489,13 @@ def duplicate_dequantize_nodes_in_onnx_graph(onnx_module): for i, consumer_name in enumerate(consumers): new_node_name = node.name + f"_clone{i}" new_output_name = output_name + f"_clone{i}" + attrs = {"axis": node.attribute[0].i} if len(node.attribute) > 0 else {} cloned_node = onnx.helper.make_node( node.op_type, node.input, [new_output_name], name=new_node_name, - axis=node.attribute[0].i + **attrs ) # Add the cloned node to the list of nodes to add