Skip to content

Commit

Permalink
Added passes for QDQ reconstruction (#11)
Browse files Browse the repository at this point in the history
Duplicating dequantize nodes with forks (skip connections)

Removing unused QDQ passes
  • Loading branch information
LPanosTT authored Jul 19, 2024
2 parents 6b61b3e + 2189168 commit 05bf431
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 17 deletions.
62 changes: 62 additions & 0 deletions python/tvm/contrib/pybuda_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,67 @@ def trim_count(name):

return json_graph


def duplicate_dequantize_nodes_in_onnx_graph(onnx_module):
    """Give every consumer of a forked ``DequantizeLinear`` its own copy.

    When the output of a ``DequantizeLinear`` node is read by more than one
    node (a fork / skip connection), the node is replaced, in place, by one
    clone per consuming node.  Clone ``i`` is named ``<name>_clone<i>``,
    produces ``<output>_clone<i>``, and the i-th consumer is rewired to read
    that tensor.  Nodes keep their relative order in ``graph.node``; clones
    take the position of the node they replace.

    Args:
        onnx_module: Object exposing an ONNX-style ``graph`` with ``node`` and
            ``output`` repeated fields.  It is modified in place.

    Returns:
        None.
    """
    import copy
    from collections import defaultdict

    graph = onnx_module.graph
    nodes = list(graph.node)

    # Map tensor name -> positions (in ``nodes``) of the nodes reading it.
    # Positions are used instead of node names because ONNX node names are
    # optional and need not be unique.  A node that reads the same tensor
    # through several input slots is recorded only once; its inputs are
    # visited consecutively, so comparing with the last entry is sufficient.
    tensor_consumers = defaultdict(list)
    for node_ind, node in enumerate(nodes):
        for input_name in node.input:
            consumer_inds = tensor_consumers[input_name]
            if not consumer_inds or consumer_inds[-1] != node_ind:
                consumer_inds.append(node_ind)

    graph_output_names = {graph_output.name for graph_output in graph.output}

    rebuilt_nodes = []
    graph_changed = False
    for node in nodes:
        if node.op_type != "DequantizeLinear":
            rebuilt_nodes.append(node)
            continue

        output_name = node.output[0]
        consumer_inds = tensor_consumers.get(output_name, [])
        if len(consumer_inds) <= 1:
            # No fork: nothing to duplicate.
            rebuilt_nodes.append(node)
            continue

        graph_changed = True
        # Unnamed nodes fall back to the (unique) output tensor name so that
        # clones of different unnamed nodes do not end up with equal names.
        base_name = node.name or output_name
        for i, consumer_ind in enumerate(consumer_inds):
            new_output_name = f"{output_name}_clone{i}"

            # Deep-copy rather than rebuild the node so that attributes
            # (e.g. ``axis``), domain and doc string survive the cloning.
            cloned_node = copy.deepcopy(node)
            cloned_node.name = f"{base_name}_clone{i}"
            cloned_node.output[0] = new_output_name
            rebuilt_nodes.append(cloned_node)

            # Rewire every input slot of this consumer that read the
            # original tensor.
            consumer_node = nodes[consumer_ind]
            for j, input_name in enumerate(consumer_node.input):
                if input_name == output_name:
                    consumer_node.input[j] = new_output_name

        # The original is dropped because its clones replace it, unless the
        # tensor is also a graph output and therefore still needs a producer.
        if output_name in graph_output_names:
            rebuilt_nodes.append(node)

    if graph_changed:
        # Rewrite the node list in one go; this keeps the original ordering
        # without having to compute insertion offsets.
        del graph.node[:]
        graph.node.extend(rebuilt_nodes)


def compile_onnx_for_buda(onnx_mod, path, *inputs, graph_name, compiler_cfg, verify_cfg=None):
import onnxruntime as ort

Expand All @@ -483,6 +544,7 @@ def compile_onnx_for_buda(onnx_mod, path, *inputs, graph_name, compiler_cfg, ver

assert len(input_names) == len(inputs), "Number of input names must match number of inputs"

duplicate_dequantize_nodes_in_onnx_graph(onnx_mod)
framework_outputs = extract_framework_model_outputs(
framework="onnx",
model=onnx_mod,
Expand Down
18 changes: 2 additions & 16 deletions python/tvm/relay/op/contrib/buda/buda_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,19 +2232,6 @@ def callback(self, pre, post, node_map):
return tvm.relay.gelu(node_map[self.act][0])


class RemoveQuantDequantSequence(DFPatternCallback):
    """Replace ``qnn.dequantize(qnn.quantize(x, ...), ...)`` with ``x``.

    The pattern matches a ``qnn.dequantize`` call whose first argument is a
    ``qnn.quantize`` call; the remaining arguments of both calls are
    wildcards.

    NOTE(review): because those arguments are wildcards, the rewrite does not
    check that the quantize and dequantize parameters agree - confirm that
    this is acceptable for every graph this pass is applied to.
    """

    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)
        # Expression feeding the quantize op; this is what the rewrite keeps.
        self.act = wildcard()
        self.quant = is_op("qnn.quantize")(self.act, wildcard(), wildcard())
        self.pattern = is_op("qnn.dequantize")(self.quant, wildcard(), wildcard())

    def callback(self, pre, post, node_map):
        """Return the expression that was fed into the matched quantize op."""
        return node_map[self.act][0]


class ReconstructOnnxQuantizedGelu(DFPatternCallback):
def __init__(self):
super().__init__(rewrite_once=True, require_type=True)
Expand Down Expand Up @@ -3364,7 +3351,7 @@ def callback(self, pre, post, node_map):
if index != val:
count += 1

assert (count == 2, "Multi-axis transpose should be decomposed into single-axis transpose at this point")
assert count == 2, "Multi-axis transpose should be decomposed into single-axis transpose at this point"
is_transpose_yz = len(transpose_axes) >= 3 and len(dims) >= 3 and (transpose_axes[-2] == dims[-3] and transpose_axes[-3] == dims[-2])

if (
Expand Down Expand Up @@ -3805,7 +3792,7 @@ def _get_callback_name(callback):
elif isinstance(callback, tvm.transform.Pass):
return callback.info.name
else:
raise NotImplementedError(f"Type of callback ({type(callback)}) not implemented")
raise NotImplementedError(f"Type of callback ({(callback)}) not implemented")


def _run_pattern_callback(relay_module, callback, callback_name):
Expand Down Expand Up @@ -3902,7 +3889,6 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None,
ConvertGlobalAvgPool2dtoAvgPool2d(),
ConvertUpsampleToResize2d(),
DecomposeMultiIndexAdvIndex(),
RemoveQuantDequantSequence(),
ReconstructOnnxQuantizedGelu(),
DecomposeQnnConcat(),
# DecomposeErf(),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/contrib/buda/relay_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def fskip_eliminate(expr):
    """Return True for call expressions that should be skipped.

    A call is skipped when its operator is named ``transpose``,
    ``qnn.dequantize`` or ``qnn.quantize``.  Anything that is not a
    ``relay.expr.Call`` is never skipped.

    NOTE(review): presumably used as the ``fskip`` predicate of a
    common-subexpression-elimination pass - confirm against the caller.
    """
    if not isinstance(expr, relay.expr.Call):
        return False
    # Call targets that are not plain operators may not expose ``name``;
    # such calls are simply not skipped instead of raising AttributeError.
    op_name = getattr(expr.op, "name", None)
    return op_name in ("transpose", "qnn.dequantize", "qnn.quantize")

Expand Down

0 comments on commit 05bf431

Please sign in to comment.