Update on "Autoquant"
Summary: Adding autoquantization functionality: using the do_quant api, we can
test kernel speeds and pick the best quantization type (or no quantization)
for each layer.
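
A minimal usage sketch of the new flow (the import path is an assumption; the
test diff below only shows the bare do_autoquant(model, example_input) call):

    import torch
    # assumed import path for illustration; not shown in this diff
    from torchao.quantization.autoquant import do_autoquant

    model = torch.nn.Sequential(
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 4096),
    ).to("cuda").to(torch.bfloat16)
    example_input = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)

    # benchmarks each candidate quantization subclass (or no quantization)
    # against the observed shape and keeps whichever is fastest per layer
    do_autoquant(model, example_input)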

Test Plan: python test/test.py -k "autoquant"

Also tested on SAM and SDXL
(pytorch-labs/segment-anything-fast#114,
huggingface/diffusion-fast@176e85f)

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
HDCharles committed Mar 5, 2024
1 parent 0823e95 commit c6d59e5
Showing 5 changed files with 37 additions and 107 deletions.
14 changes: 3 additions & 11 deletions test/test.py
@@ -1197,6 +1197,7 @@ def test_on_dummy_distilbert(self):
print("sqnr_pt_quant", sqnr_pt_quant)
self.assertTrue(sqnr_sq >= 8.0)

# TODO FINISH TEST CODE
class TestAutoQuant(unittest.TestCase):
def test_auto_quant(self):
torch._inductor.config.epilogue_fusion = False
@@ -1215,20 +1216,11 @@ def test_auto_quant(self):
(64, 4096, 1024),
(4096, 4096, 1024),
]:
print("testing", m, k, n)
example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
model = torch.nn.Sequential(
# torch.nn.ReLU(),
torch.nn.ReLU(),
torch.nn.Linear(k,n),
# torch.nn.ReLU(),
# torch.nn.Linear(1280,3840),
# torch.nn.ReLU(),
# torch.nn.Linear(3840,1280),
# torch.nn.ReLU(),
# torch.nn.Linear(1280,1024),
# torch.nn.ReLU(),
# torch.nn.Linear(1024,4096),
# torch.nn.ReLU(),
torch.nn.ReLU(),
).to("cuda").to(torch.bfloat16)
do_autoquant(model, example_input)

35 changes: 0 additions & 35 deletions test/test_autoquant.py

This file was deleted.

78 changes: 23 additions & 55 deletions torchao/quantization/autoquant.py
@@ -69,13 +69,13 @@ def log_shape(act_mat, w_autoquant, bias):
y += bias
return y

def tune_autoquant(self, q_cls):
def tune_autoquant(self, q_cls, best_time):
act_shape, w_shape, bias_shape = self.logged_shape
if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
with torch.no_grad():
act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
res = q_cls._autoquant_test(act_mat, self.weight, bias)
res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time)
update_cache(q_cls, self.logged_shape, self.logged_dtype, res)

def to_quantized(self, error_on_unseen, **kwargs):
@@ -91,7 +91,7 @@ def to_quantized(self, error_on_unseen, **kwargs):
for q_cls in self.qtensor_class_list:
if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
do_print=True
self.tune_autoquant(q_cls)
self.tune_autoquant(q_cls, best_time)
torch._dynamo.reset()
cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
if best_time >= cls_res:
@@ -149,14 +149,12 @@ class AQMixin():
Mixin to turn normal quantized subclasses into autoquantizable ones
"""
@classmethod
def _autoquant_test(cls, act_mat, weight, bias):
def _autoquant_test(cls, act_mat, weight, bias, best_time, *args, **kwargs):
w_qtensor = cls.from_float(weight)
func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c))
q_c_op = torch.compile(func, mode="max-autotune")
# q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
with torch.no_grad():
torch.cuda.synchronize()
res = benchmark(q_c_op, act_mat, w_qtensor, bias)
res = benchmark(q_c_op, act_mat, w_qtensor, bias, best_time=best_time)
print(cls, res)
return res

@@ -165,8 +163,9 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLi
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
"""
@classmethod
def _autoquant_test(cls, act_mat, weight, bias):
res = super()._autoquant_test(act_mat, weight, bias)
def _autoquant_test(cls, act_mat, weight, bias, best_time):
# SAM best is between .51 and .60, SDXL also performs best in this range
INTERPOLATION_CONSTANT=.55
w_qtensor = cls.from_float(weight)
x_vals_int8, x_scales = quantize_activation_per_token_absmax(
act_mat.reshape(-1, act_mat.shape[-1])
@@ -177,10 +176,18 @@ def _autoquant_test(cls, act_mat, weight, bias):
)
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune")
with torch.no_grad():
res2=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
print(cls, "matmul", res2)
# for SAM best is between .458-.499, SDXL .45=3.094 .47=2.880 .48=3.036 .5=2.930
return res
res_matmul=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data, best_time=best_time)
print(cls, "matmul", res_matmul)

# if even the (much faster) matmul kernel is already beaten by best_time, don't bother benchmarking the full op
if res_matmul>=best_time:
return res_matmul

# calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
res = super()._autoquant_test(act_mat, weight, bias, to_beat)
print(cls, "full", INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul)
return INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul


class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
@@ -205,11 +212,11 @@ def _quantized_op(act_mat, w_qtensor, bias):
return y.to(orig_dtype)

@classmethod
def _autoquant_test(cls, act_mat, weight, bias):
def _autoquant_test(cls, act_mat, weight, bias, best_time):
# if act_mat has batchsize>2 don't use this kernel
if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>2:
return torch.inf
return super()._autoquant_test(act_mat, weight, bias)
return super()._autoquant_test(act_mat, weight, bias, best_time)

class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
def _quantized_op(act_mat, w_qtensor, bias):
@@ -246,42 +253,3 @@ def from_float(cls, weight):
AQWeightOnlyQuantizedLinearWeight2,
AQWeightOnlyQuantizedLinearWeight3,
]

if False:
# def _get_to_kwargs(self, *args, **kwargs):
# device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
# device = self.device if device is None else device
# dtype = self.dtype if dtype is None else dtype
# memory_format = (
# memory_format if memory_format is not None else torch.preserve_format
# )
# kwargs = {
# "device": device,
# "dtype": dtype,
# "memory_format": memory_format,
# }
# return kwargs

# def to(self, *args, **kwargs):
# kwargs = self._get_to_kwargs(*args, **kwargs)
# return self.__class__(
# self.int_data.to(kwargs["device"]),
# self.q_scales.to(kwargs["device"]),
# self.transposed,
# self.shape,
# **kwargs,
# )

# def _apply_fn_to_data(self, fn):
# return self.__class__(
# fn(self.int_data), fn(self.q_scales), self.transposed, self.shape, dtype=self.dtype
# )

# def _change_shape(self, shape):
# return self.__class__(
# self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype
# )

# def half(self):
# return self.to(torch.float16)
pass
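
A worked pass through the interpolation scoring in AQInt8DynamicallyQuantizedLinearWeight._autoquant_test above (assumed numbers, not from a real run): if the best candidate so far has best_time = 1.0 ms and the bare int8 matmul benchmarks at res_matmul = 0.8 ms, the matmul has not already lost, so the full op is benchmarked with to_beat = 1.0 + (.55/.45)*(1.0 - 0.8) ≈ 1.24 ms as its early-exit budget. If the compiled full op then measures res = 1.1 ms, the value returned for this candidate is .55*1.1 + .45*0.8 ≈ 0.97 ms, i.e. it is scored partway between its full-op time and its raw matmul time, weighted by INTERPOLATION_CONSTANT.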
5 changes: 1 addition & 4 deletions torchao/quantization/subclass.py
@@ -13,11 +13,8 @@
groupwise_affine_quantize_tensor,
quant_int8_dynamic_per_token_linear,
unpack_tinygemm_scales_and_zeros,
quantize_activation_per_token_absmax,
quant_int8_per_token_matmul,
safe_int_mm,
)
from .utils import find_multiple, benchmark
from .utils import find_multiple
import warnings


12 changes: 10 additions & 2 deletions torchao/quantization/utils.py
@@ -90,12 +90,20 @@ def get_model_size_in_bytes(model):


def benchmark(f, *args, **kwargs):
if "best_time" in kwargs:
best_time = kwargs.pop("best_time")
else:
best_time = torch.inf
t0 = Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)

# warmup
t0.timeit(10)

res=t0.blocked_autorange(min_run_time=.5)
res=t0.adaptive_autorange(min_run_time=.1)
# keep measuring while the optimistic estimate (median minus iqr, interpolated back toward the median by the number of runs left) is still below best_time;
# stop once res.iqr/res.median is small (<.02) or we have 20 samples
while res.median-res.iqr+res.iqr*len(res.times)/20 < best_time * 1e-3 and not (res.iqr/res.median<.02 or len(res.times)>=20):
res2 = t0.adaptive_autorange(min_run_time=.5)
res=res.merge([res2, res])[0]
return res.median * 1e3
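
A worked pass through the new stopping rule (assumed numbers): with best_time = 1.0 (ms, so the comparison threshold is 1.0e-3 s), suppose the first adaptive_autorange call returns res.median = 1.05e-3 s and res.iqr = 0.2e-3 s over 5 runs. The optimistic estimate 1.05e-3 - 0.2e-3 + 0.2e-3*5/20 = 0.9e-3 s is still below the threshold, the spread res.iqr/res.median ≈ 0.19 is not yet below .02, and fewer than 20 samples have been taken, so the loop keeps measuring; it stops once the spread tightens, 20 samples accumulate, or the estimate can no longer beat best_time, and returns the median in milliseconds.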
