diff --git a/python/tvm/relay/op/contrib/buda/buda_passes.py b/python/tvm/relay/op/contrib/buda/buda_passes.py
index 1600ab0c1c..f5fd4dad62 100644
--- a/python/tvm/relay/op/contrib/buda/buda_passes.py
+++ b/python/tvm/relay/op/contrib/buda/buda_passes.py
@@ -2232,6 +2232,167 @@ def callback(self, pre, post, node_map):
         return tvm.relay.gelu(node_map[self.act][0])
 
 
+class ReconstructQDQConvSequence(DFPatternCallback):
+    def __init__(self):
+        super().__init__(rewrite_once=True, require_type=True)
+
+        self.weight = wildcard()
+        self.act = wildcard()
+        self.bias = wildcard()
+
+        self.quant_weight = is_op("qnn.quantize")(self.weight, wildcard(), wildcard(),)
+        self.dequant_weight = is_op("qnn.dequantize")(self.quant_weight, wildcard(), wildcard(),)
+
+        self.quant_act = is_op("qnn.quantize")(self.act, wildcard(), wildcard(),)
+        self.dequant_act = is_op("qnn.dequantize")(self.quant_act, wildcard(), wildcard(),)
+
+        self.conv2d = is_op('nn.conv2d')(self.dequant_act, self.dequant_weight)
+        self.add = is_op("nn.bias_add")(self.conv2d, self.bias)
+        self.relu = is_op("nn.relu")(self.add)
+        self.maxpool = is_op("nn.max_pool2d")(self.relu)
+        self.pattern = self.maxpool
+
+    def callback(self, pre, post, node_map):
+        act = node_map[self.act][0]
+        weight = node_map[self.weight][0]
+
+        # quantization activations
+        orig_quant_act = node_map[self.quant_act][0]
+        scale_act = orig_quant_act.args[1]
+        zp_act = orig_quant_act.args[2]
+        quant_axis_act = orig_quant_act.attrs.axis
+        quant_act = tvm.relay.qnn.op.quantize(act, scale_act, zp_act, axis=quant_axis_act)
+
+        # quantization weights
+        orig_quant_weights = node_map[self.quant_weight][0]
+        scale_weights = orig_quant_weights.args[1]
+        zp_weights = orig_quant_weights.args[2]
+        quant_axis_weights = orig_quant_weights.attrs.axis
+        quant_weight = tvm.relay.qnn.op.quantize(weight, scale_weights, zp_weights, axis=quant_axis_weights)
+
+        # conv
+        conv_attrs = node_map[self.conv2d][0].attrs
+        conv2d_act = tvm.relay.op.nn.conv2d(
+            quant_act,
+            quant_weight,
+            strides=conv_attrs.strides,
+            padding=conv_attrs.padding,
+            groups=conv_attrs.groups,
+            channels=conv_attrs.channels,
+            kernel_size=conv_attrs.kernel_size,
+            out_dtype='int32'
+        )
+
+        # bias add
+        bias = node_map[self.bias][0]
+        # reshaping bias to match conv2d output
+        new_shape = list(node_map[self.bias][0].checked_type.concrete_shape) + [1, 1]
+        bias = tvm.relay.reshape(bias, new_shape)
+        # quantizing bias using act and weights scale and zp=0 as per https://arxiv.org/pdf/1712.05877
+        scale_bias = scale_act * scale_weights
+        zp_bias = tvm.relay.expr.const(0)
+        bias = tvm.relay.qnn.op.quantize(bias, scale_bias, zp_bias, axis=0, out_dtype='int32')
+        bias = tvm.relay.cast(bias, "int32")
+        add_act = tvm.relay.add(conv2d_act, bias)
+
+        # relu
+        relu_act = tvm.relay.nn.relu(add_act)
+
+        # dequant
+        orig_dequant = node_map[self.dequant_act][0]
+        scale_dequant = scale_weights * scale_act
+        zp_dequant = orig_dequant.args[2]
+        dequant_axis = orig_dequant.attrs.axis
+        dequant_act = tvm.relay.qnn.op.dequantize(relu_act, scale_dequant, zp_dequant, axis=dequant_axis)
+
+        # maxpool
+        maxpool_attrs = node_map[self.maxpool][0].attrs
+        maxpool_act = tvm.relay.nn.max_pool2d(
+            dequant_act,
+            pool_size=maxpool_attrs.pool_size,
+            strides=maxpool_attrs.strides,
+            dilation=maxpool_attrs.dilation,
+            padding=maxpool_attrs.padding,
+            layout=maxpool_attrs.layout
+        )
+
+        return maxpool_act
+
+
+class ReconstructQDQGemmSequence(DFPatternCallback):
+    def __init__(self):
+        super().__init__(rewrite_once=True, require_type=True)
+
+        self.bias = wildcard()
+        self.weight = wildcard()
+        self.act = wildcard()
+
+        self.quant_w = is_op("qnn.quantize")(self.weight, wildcard(), wildcard(),)
+        self.dequant_w = is_op("qnn.dequantize")(self.quant_w, wildcard(), wildcard(),)
+        self.transpose1 = is_op("transpose")(self.dequant_w)
+        self.transpose2 = is_op("transpose")(self.transpose1)
+
+        self.quant_act = is_op("qnn.quantize")(self.act, wildcard(), wildcard(),)
+        self.dequant_act = is_op("qnn.dequantize")(self.quant_act, wildcard(), wildcard(),)
+
+        self.gemm = is_op('nn.dense')(self.dequant_act, self.transpose2)
+        self.add = is_op("add")(self.gemm, self.bias)
+        self.pattern = self.add
+
+    def callback(self, pre, post, node_map):
+        act = node_map[self.act][0]
+        weight = node_map[self.weight][0]
+
+        # quantization activations
+        orig_quant = node_map[self.quant_act][0]
+        scale_activations = orig_quant.args[1]
+        zp_activations = orig_quant.args[2]
+        quant_axis_act = orig_quant.attrs.axis
+        quant_act = tvm.relay.qnn.op.quantize(act, scale_activations, zp_activations, axis=quant_axis_act)
+
+        # quantization weights
+        orig_quant = node_map[self.quant_w][0]
+        scale_weights = orig_quant.args[1]
+        zp_weights = orig_quant.args[2]
+        quant_axis = orig_quant.attrs.axis
+        quant_weight = tvm.relay.qnn.op.quantize(weight, scale_weights, zp_weights, axis=0)
+
+        transpose1_axes = node_map[self.transpose1][0].attrs.axes
+        transpose1_weights = tvm.relay.transpose(quant_weight, axes=transpose1_axes)
+
+        # gemm
+        gemm_act = tvm.relay.nn.matmul(quant_act, transpose1_weights, transpose_a=False, transpose_b=False, out_dtype='int32')
+
+        # bias add
+        bias = node_map[self.bias][0]
+        # quantizing bias using act and weights scale and zp=0 as per https://arxiv.org/pdf/1712.05877
+        scale_bias = scale_activations * scale_weights
+        zp_bias = tvm.relay.expr.const(0)
+        bias = tvm.relay.qnn.op.quantize(bias, scale_bias, zp_bias, axis=quant_axis, out_dtype='int32')
+        add_act = tvm.relay.add(gemm_act, bias)
+
+        # dequant
+        orig_dequant = node_map[self.dequant_act][0]
+        scale_dequant = scale_weights * scale_activations
+        zp_dequant = orig_dequant.args[2]
+        dequant_axis = orig_dequant.attrs.axis
+        dequant_act = tvm.relay.qnn.op.dequantize(add_act, scale_dequant, zp_dequant, axis=dequant_axis)
+
+        return dequant_act
+
+
 class RemoveQuantDequantSequence(DFPatternCallback):
     def __init__(self):
         super().__init__(rewrite_once=True, require_type=True)
@@ -3902,6 +4063,8 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None,
             ConvertGlobalAvgPool2dtoAvgPool2d(),
             ConvertUpsampleToResize2d(),
             DecomposeMultiIndexAdvIndex(),
+            ReconstructQDQConvSequence(),
+            ReconstructQDQGemmSequence(),
             RemoveQuantDequantSequence(),
             ReconstructOnnxQuantizedGelu(),
             DecomposeQnnConcat(),
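
Usage note (not part of the patch): both callbacks pull the conv2d/dense out of the dequantized float domain and re-express it on the quantized tensors with int32 accumulation, quantizing the bias with scale_act * scale_weights and zero point 0 as described in the linked paper (arXiv:1712.05877). The snippet below is a minimal, hypothetical sketch of how ReconstructQDQConvSequence could be exercised in isolation via tvm.relay.dataflow_pattern.rewrite; the shapes, scales, and zero points are made up for illustration, and the import assumes the patched buda_passes module is on the Python path.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import rewrite
from tvm.relay.op.contrib.buda.buda_passes import ReconstructQDQConvSequence

# Build a small QDQ conv -> bias_add -> relu -> max_pool2d graph (illustrative values only).
act = relay.var("act", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.const(np.random.randn(8, 3, 3, 3).astype("float32"))
bias = relay.const(np.zeros(8, dtype="float32"))
scale = relay.const(0.02, "float32")
zp = relay.const(0, "int32")

dq_act = relay.qnn.op.dequantize(relay.qnn.op.quantize(act, scale, zp), scale, zp)
dq_w = relay.qnn.op.dequantize(relay.qnn.op.quantize(weight, scale, zp), scale, zp)
conv = relay.nn.conv2d(dq_act, dq_w, channels=8, kernel_size=(3, 3), padding=(1, 1))
out = relay.nn.max_pool2d(relay.nn.relu(relay.nn.bias_add(conv, bias)), pool_size=(2, 2))

mod = tvm.IRModule.from_expr(relay.Function([act], out))
mod = relay.transform.InferType()(mod)  # the callback reads checked_type, so infer types first

# Apply the pattern callback; the rewritten main should contain an int32 conv2d
# followed by a single dequantize before the max_pool2d.
mod["main"] = rewrite(ReconstructQDQConvSequence(), mod["main"])
print(mod)

In the compile pipeline itself, the same effect comes from registering the two callbacks in run_buda_compile_passes ahead of RemoveQuantDequantSequence, as the second hunk of the patch does, so the quantize/dequantize pairs they introduce can be folded by the later passes.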