Skip to content

Commit

Permalink
Avoid quantizing the same weight repeatedly (#47587)
Browse files: browse the repository at this point in the history
  • Loading branch information
RachelXu7 authored Nov 3, 2022
1 parent b3d52d4 commit 66a1df3
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,7 @@ def __init__(
self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict()
self._quant_var_scale_map = collections.OrderedDict()
self._quantized_ops = set()

def apply(self, graph):
"""
Expand Down Expand Up @@ -1173,24 +1174,27 @@ def apply(self, graph):
quant_axis = 1
else:
quant_axis = 0
quantized_param_v = utils.quant_tensor(
param_v.copy(),
scale_v,
quant_axis,
self._weight_bits,
)
quantized_param_v = np.round(quantized_param_v)
# Weight bias correction
if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w(
param_v,
quantized_param_v,
if input_arg_name not in self._quantized_ops:
self._quantized_ops.add(input_arg_name)
quantized_param_v = utils.quant_tensor(
param_v.copy(),
scale_v,
quant_axis,
weight_bits=self._weight_bits,
self._weight_bits,
)
quantized_param_v = np.round(quantized_param_v)
self._restore_var(input_arg_name, quantized_param_v)
# Weight bias correction
if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w(
param_v,
quantized_param_v,
scale_v,
quant_axis,
weight_bits=self._weight_bits,
)
quantized_param_v = np.round(quantized_param_v)
self._restore_var(input_arg_name, quantized_param_v)

self._remove_fake_quant_and_dequant_op(graph, op_node)

# Remove all fake dequant op
Expand Down Expand Up @@ -3029,6 +3033,7 @@ def __init__(
self._save_int_weight = save_int_weight
assert self._scope is not None, "scope must not be None."
assert self._place is not None, "place must not be None."
self._quantized_ops = set()

def apply(self, graph):
assert isinstance(
Expand Down Expand Up @@ -3066,29 +3071,31 @@ def apply(self, graph):
param_v = self._load_var(x_node.name())
quant_axis = _op.op().attr("quant_axis")
bits_length = _op.op().attr("bit_length")
quantized_param_v = utils.quant_tensor(
param_v.copy(),
scale_v,
quant_axis,
bits_length,
onnx_format=True,
)
if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w(
param_v,
quantized_param_v,
if x_node.name() not in self._quantized_ops:
self._quantized_ops.add(x_node.name())
quantized_param_v = utils.quant_tensor(
param_v.copy(),
scale_v,
quant_axis,
weight_bits=bits_length,
bits_length,
onnx_format=True,
)
if self._save_int_weight:
# cast weight type to int
if self._quant_bits == 8:
save_weight_dtype = np.int8
quantized_param_v = quantized_param_v.astype(
save_weight_dtype
)
self._restore_var(x_node.name(), quantized_param_v)
if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w(
param_v,
quantized_param_v,
scale_v,
quant_axis,
weight_bits=bits_length,
)
if self._save_int_weight:
# cast weight type to int
if self._quant_bits == 8:
save_weight_dtype = np.int8
quantized_param_v = quantized_param_v.astype(
save_weight_dtype
)
self._restore_var(x_node.name(), quantized_param_v)

for next_op_node in out_node.outputs:
graph.update_input_link(out_node, x_node, next_op_node)
Expand Down

0 comments on commit 66a1df3

Please sign in to comment.