Added passes for QDQ reconstruction
acvetkovic-tt committed May 28, 2024
1 parent 99f3160 commit 3071de5
Showing 1 changed file with 163 additions and 0 deletions.
163 changes: 163 additions & 0 deletions python/tvm/relay/op/contrib/buda/buda_passes.py
@@ -2232,6 +2232,167 @@ def callback(self, pre, post, node_map):
        return tvm.relay.gelu(node_map[self.act][0])


class ReconstructQDQConvSequence(DFPatternCallback):
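    """Reconstruct a quantized convolution block from its QDQ form.

    Matches dequantize -> conv2d -> bias_add -> relu -> max_pool2d where both
    the activations and the weights come from qnn.quantize/qnn.dequantize
    pairs, and rewrites the block so that conv2d (int32 accumulator), bias add
    and relu run on the quantized values, with a single dequantize inserted
    just before the max pool.
    """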
    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)

        self.weight = wildcard()
        self.act = wildcard()
        self.bias = wildcard()

        self.quant_weight = is_op("qnn.quantize")(self.weight, wildcard(), wildcard(),)
        self.dequant_weight = is_op("qnn.dequantize")(self.quant_weight, wildcard(), wildcard(),)

        self.quant_act = is_op("qnn.quantize")(self.act, wildcard(), wildcard(),)
        self.dequant_act = is_op("qnn.dequantize")(self.quant_act, wildcard(), wildcard(),)

        self.conv2d = is_op('nn.conv2d')(self.dequant_act, self.dequant_weight)
        self.add = is_op("nn.bias_add")(self.conv2d, self.bias)
        self.relu = is_op("nn.relu")(self.add)
        self.maxpool = is_op("nn.max_pool2d")(self.relu)
        self.pattern = self.maxpool


    def callback(self, pre, post, node_map):

        act = node_map[self.act][0]
        weight = node_map[self.weight][0]


        # quantization activations
        orig_quant_act = node_map[self.quant_act][0]
        scale_act = orig_quant_act.args[1]
        zp_act = orig_quant_act.args[2]
        quant_axis_act = orig_quant_act.attrs.axis
        quant_act = tvm.relay.qnn.op.quantize(act, scale_act, zp_act, axis=quant_axis_act)


        # quantization weights
        orig_quant_weights = node_map[self.quant_weight][0]
        scale_weights = orig_quant_weights.args[1]
        zp_weights = orig_quant_weights.args[2]
        quant_axis_weights = orig_quant_weights.attrs.axis
        quant_weight = tvm.relay.qnn.op.quantize(weight, scale_weights, zp_weights, axis=quant_axis_weights)


        # conv
        conv_attrs = node_map[self.conv2d][0].attrs
        conv2d_act = tvm.relay.op.nn.conv2d(
            quant_act,
            quant_weight,
            strides=conv_attrs.strides,
            padding=conv_attrs.padding,
            groups=conv_attrs.groups,
            channels=conv_attrs.channels,
            kernel_size=conv_attrs.kernel_size,
            out_dtype='int32'
        )

        # bias add
        bias = node_map[self.bias][0]
        # reshaping bias to match conv2d output
        new_shape = list(node_map[self.bias][0].checked_type.concrete_shape) + [1, 1]
        bias = tvm.relay.reshape(bias, new_shape)
        # quantizing bias using act and weights scale and zp=0 as per https://arxiv.org/pdf/1712.05877
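        # The int32 conv accumulator above represents values on the combined
        # scale scale_act * scale_weights (see the paper linked above), so the
        # bias is quantized to that same scale with zero point 0 before it is
        # added; the dequantize further down uses the same combined scale.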
        scale_bias = scale_act * scale_weights
        zp_bias = tvm.relay.expr.const(0)
        bias = tvm.relay.qnn.op.quantize(bias, scale_bias, zp_bias, axis=0, out_dtype='int32')
        bias = tvm.relay.cast(bias, "int32")
        add_act = tvm.relay.add(conv2d_act, bias)


        # relu
        relu_act = tvm.relay.nn.relu(add_act)


        # dequant
        orig_dequant = node_map[self.dequant_act][0]
        scale_dequant = scale_weights * scale_act
        zp_dequant = orig_dequant.args[2]
        dequant_axis = orig_dequant.attrs.axis
        dequant_act = tvm.relay.qnn.op.dequantize(relu_act, scale_dequant, zp_dequant, axis=dequant_axis)


        # maxpool
        maxpool_attrs = node_map[self.maxpool][0].attrs
        maxpool_act = tvm.relay.nn.max_pool2d(
            dequant_act,
            pool_size=maxpool_attrs.pool_size,
            strides=maxpool_attrs.strides,
            dilation=maxpool_attrs.dilation,
            padding=maxpool_attrs.padding,
            layout=maxpool_attrs.layout
        )

        return maxpool_act


class ReconstructQDQGemmSequence(DFPatternCallback):
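    """Reconstruct a quantized dense (GEMM) block from its QDQ form.

    Matches nn.dense fed by QDQ'd activations and doubly transposed QDQ'd
    weights, followed by a bias add, and rewrites it as an int32 matmul on the
    quantized values, with the bias quantized to the combined scale and a
    dequantize emitted at the end.
    """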
    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)
        self.bias = wildcard()
        self.weight = wildcard()
        self.act = wildcard()

        self.quant_w = is_op("qnn.quantize")(self.weight, wildcard(), wildcard(),)
        self.dequant_w = is_op("qnn.dequantize")(self.quant_w, wildcard(), wildcard(),)
        self.transpose1 = is_op("transpose")(self.dequant_w)
        self.transpose2 = is_op("transpose")(self.transpose1)

        self.quant_act = is_op("qnn.quantize")(self.act, wildcard(), wildcard(),)
        self.dequant_act = is_op("qnn.dequantize")(self.quant_act, wildcard(), wildcard(),)

        self.gemm = is_op('nn.dense')(self.dequant_act, self.transpose2)
        self.add = is_op("add")(self.gemm, self.bias)
        self.pattern = self.add


    def callback(self, pre, post, node_map):

        act = node_map[self.act][0]
        weight = node_map[self.weight][0]

        # quantization activations
        orig_quant = node_map[self.quant_act][0]
        scale_activations = orig_quant.args[1]
        zp_activations = orig_quant.args[2]
        quant_axis_act = orig_quant.attrs.axis
        quant_act = tvm.relay.qnn.op.quantize(act, scale_activations, zp_activations, axis=quant_axis_act)


        # quantization weights
        orig_quant = node_map[self.quant_w][0]
        scale_weights = orig_quant.args[1]
        zp_weights = orig_quant.args[2]
        quant_axis = orig_quant.attrs.axis
        quant_weight = tvm.relay.qnn.op.quantize(weight, scale_weights, zp_weights, axis=0)

        transpose1_axes = node_map[self.transpose1][0].attrs.axes
        transpose1_weights = tvm.relay.transpose(quant_weight, axes=transpose1_axes)

        # gemm
        gemm_act = tvm.relay.nn.matmul(quant_act, transpose1_weights, transpose_a=False, transpose_b=False, out_dtype='int32')


        # bias add
        bias = node_map[self.bias][0]
        # quantizing bias using act and weights scale and zp=0 as per https://arxiv.org/pdf/1712.05877
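        # Same scale argument as in ReconstructQDQConvSequence above: the int32
        # matmul accumulator carries scale_activations * scale_weights, so the
        # bias is quantized to that combined scale with zero point 0.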
        scale_bias = scale_activations * scale_weights
        zp_bias = tvm.relay.expr.const(0)
        bias = tvm.relay.qnn.op.quantize(bias, scale_bias, zp_bias, axis=quant_axis, out_dtype='int32')
        add_act = tvm.relay.add(gemm_act, bias)


        # dequant
        orig_dequant = node_map[self.dequant_act][0]
        scale_dequant = scale_weights * scale_activations
        zp_dequant = orig_dequant.args[2]
        dequant_axis = orig_dequant.attrs.axis
        dequant_act = tvm.relay.qnn.op.dequantize(add_act, scale_dequant, zp_dequant, axis=dequant_axis)

        return dequant_act


class RemoveQuantDequantSequence(DFPatternCallback):
    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)
@@ -3902,6 +4063,8 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None,
ConvertGlobalAvgPool2dtoAvgPool2d(),
ConvertUpsampleToResize2d(),
DecomposeMultiIndexAdvIndex(),
ReconstructQDQConvSequence(),
ReconstructQDQGemmSequence(),
RemoveQuantDequantSequence(),
ReconstructOnnxQuantizedGelu(),
DecomposeQnnConcat(),
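A note on exercising the new passes in isolation: both callbacks are ordinary Relay DFPatternCallback rewrites, so they can be run directly with tvm.relay.dataflow_pattern.rewrite before (or instead of) going through run_buda_compile_passes. The sketch below is illustrative only: it assumes a TVM build in which this contrib module is importable as tvm.relay.op.contrib.buda.buda_passes (as in the tt-buda fork of TVM), and all shapes, scales and zero points are made-up example values.

# Illustrative sketch, not part of the commit. Assumes the module path
# tvm.relay.op.contrib.buda.buda_passes is importable in this TVM build.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import rewrite
from tvm.relay.op.contrib.buda.buda_passes import ReconstructQDQConvSequence

# Build the float QDQ chain the conv pass matches:
# dequantize(quantize(act)) -> conv2d -> bias_add -> relu -> max_pool2d,
# with the weights wrapped in their own quantize/dequantize pair.
act = relay.var("act", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.const(np.random.randn(8, 3, 3, 3).astype("float32"))
bias = relay.const(np.random.randn(8).astype("float32"))

s_act, z_act = relay.const(0.02, "float32"), relay.const(0, "int32")
s_w, z_w = relay.const(0.01, "float32"), relay.const(0, "int32")

dq_act = relay.qnn.op.dequantize(relay.qnn.op.quantize(act, s_act, z_act), s_act, z_act)
dq_w = relay.qnn.op.dequantize(relay.qnn.op.quantize(weight, s_w, z_w, axis=0), s_w, z_w, axis=0)

conv = relay.nn.conv2d(dq_act, dq_w, channels=8, kernel_size=(3, 3), padding=(1, 1))
out = relay.nn.max_pool2d(relay.nn.relu(relay.nn.bias_add(conv, bias)), pool_size=(2, 2))

# The callbacks are built with require_type=True, so infer types first, then
# run the rewrite and inspect the reconstructed quantized graph.
mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([act], out)))
print(rewrite(ReconstructQDQConvSequence(), mod["main"]))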
