From 29214a9efa750f2e1159bd700ecc2c226699d741 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 19 Mar 2024 15:27:37 -0700
Subject: [PATCH] Update on "Autoquant"

Summary:

Adding autoquantization functionality. Using the autoquant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer.

Test Plan:

python test/test.py -k "autoquant"

also tested on SAM and SDXL
https://github.com/pytorch-labs/segment-anything-fast/pull/114
https://github.com/HDCharles/sdxl-fast/commit/8d9942ab05a552f25f5bfe09da02719ce255467f

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 README.md                         | 41 ++++++++++++----
 __init__.py                       |  0
 torchao/__init__.py               | 24 ++++++++++
 torchao/quantization/__init__.py  |  2 +-
 torchao/quantization/autoquant.py | 78 ++++++++++++++++++++++---------
 torchao/quantization/quant_api.py | 13 +++---
 torchao/quantization/utils.py     | 21 ---------
 7 files changed, 120 insertions(+), 59 deletions(-)
 create mode 100644 __init__.py

diff --git a/README.md b/README.md
index cd34c0d8ac..45e51c828f 100644
--- a/README.md
+++ b/README.md
@@ -43,29 +43,50 @@ The following apis use quantized [tensor subclasses](https://pytorch.org/docs/st
 This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
 
-### A8W8 Dynamic Quantization
+### Autoquantization
 
-The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
-converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
-
-Example
+The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes
+of the activations that the different linear layers see. It then benchmarks those shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps the linear weight for that class. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer.
 
 ```
 import torch
-from torchao.quantization import quant_api
+import torchao
+
+# inductor settings which improve torch.compile runtime for quantized modules
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.use_mixed_mm = True
 
 # some user model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
-# convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_dqtensors(model)
+# perform autoquantization
+torchao.autoquant(model, input)
 
 # compile the model to improve performance
 model = torch.compile(model, mode='max-autotune')
 model(input)
 ```
+
+### A8W8 Dynamic Quantization
+
+The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
+converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
+
+Example
+
+```
+# some user model and example input
+...
+
+# convert linear modules to quantized linear modules
+torchao.change_linear_weights_to_int8_dqtensors(model)
+
+# compile the model to improve performance
+...
+```
+
 This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.
@@ -81,7 +102,7 @@ Example
 ...
 
 # convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_woqtensors(model)
+torchao.change_linear_weights_to_int8_woqtensors(model)
 
 # compile the model to improve performance
 ...
@@ -102,7 +123,7 @@ Example
 ...
 
 # convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int4_woqtensors(model)
+torchao.change_linear_weights_to_int4_woqtensors(model)
 
 # compile the model to improve performance
 ...
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchao/__init__.py b/torchao/__init__.py
index e69de29bb2..c2634c5365 100644
--- a/torchao/__init__.py
+++ b/torchao/__init__.py
@@ -0,0 +1,24 @@
+from torchao.quantization import (
+    apply_weight_only_int8_quant,
+    apply_dynamic_quant,
+    change_linear_weights_to_int8_dqtensors,
+    change_linear_weights_to_int8_woqtensors,
+    change_linear_weights_to_int4_woqtensors,
+    swap_conv2d_1x1_to_linear,
+    autoquant,
+    change_linears_to_autoquantizable,
+    change_autoquantizable_to_quantized,
+)
+
+__all__ = [
+    "apply_weight_only_int8_quant",
+    "apply_dynamic_quant",
+    "change_linear_weights_to_int8_dqtensors",
+    "change_linear_weights_to_int8_woqtensors",
+    "change_linear_weights_to_int4_woqtensors",
+    "swap_conv2d_1x1_to_linear",
+    "safe_int_mm",
+    "autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
+]
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 525008a77d..1b421ab8e4 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -25,7 +25,7 @@
     "dynamically_quantize_per_channel",
     "dequantize_per_tensor",
     "dequantize_per_channel",
-    "do_autoquant",
+    "autoquant",
     "change_linears_to_autoquantizable",
     "change_autoquantizable_to_quantized",
     "quant_int8_dynamic_linear",
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 60eb29127c..f05958c84c 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -1,6 +1,4 @@
 import torch
-import os
-from subprocess import check_output
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -79,26 +77,56 @@ def to_quantized(self, error_on_unseen, **kwargs):
             # default back to non-quantized weight if not seen
             self = AQFloatLinearWeight.from_float(self.weight)
             return self
+
+
+        # only want to do shape+final print a single time if multiple layers
+        # see/have same shapes so we gate on check_cache being empty for
+        # at least one of the class/shape combinations.
+        do_final_print = False
+        print_once = True
+
+        def count_shapes(self, do_print=True):
+            different_shape_count = 0
+            for shapes_and_dtype, times_seen in self.logged_data.items():
+                different_shape_count += 1
+                if do_print:
+                    act_shape, weight_shape, bias_shape, dtype = shapes_and_dtype
+                    print(f"activation_shapes: {act_shape}, times_seen: {times_seen}")
+            if do_print:
+                print(f"weight_shape: {weight_shape}, dtype: {dtype}, bias_shape: {bias_shape}")
+            return different_shape_count
+
+        # check each class
         best_time = torch.inf
         best_cls = None
-        do_print=False
-        # check each class
         for q_cls in self.qtensor_class_list:
             # for each logged shape+dtype, benchmark
-            cls_res=0
+            cur_time=0
+            shape_count = count_shapes(self, do_print=False)
             for shapes_and_dtype, times_seen in self.logged_data.items():
                 if check_cache(q_cls, shapes_and_dtype) is None:
-                    do_print=True
-                    self.tune_autoquant(q_cls, shapes_and_dtype, best_time)
+                    # only do final print if we have to autotune at least one cls/shape pair
+                    do_final_print=True
+
+                    # only print shapes once
+                    if print_once:
+                        print_once = False
+                        count_shapes(self, do_print=True)
+
+                    time_for_best_shape = check_cache(best_cls, shapes_and_dtype)
+                    time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape
+                    self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
                     torch._dynamo.reset()
-                cls_res += check_cache(q_cls, shapes_and_dtype) * times_seen
-            if best_time >= cls_res:
-                best_time = cls_res
+                cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
+            if shape_count is not None and shape_count > 1:
+                print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
+            if best_time >= cur_time:
+                best_time = cur_time
                 best_cls = q_cls
         # only print if this is the first time seeing some cls+shape combo,
         # otherwise we will print the same thing for every layer.
-        if do_print:
-            print(f"for {self.logged_data}, best_cls={best_cls}")
+        if do_final_print:
+            print(f"best_cls={best_cls}\n")
         # TODO handle random cls args/kwargs? or should they be curried?
self = best_cls.from_float(self.weight) return self @@ -145,6 +173,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) def do_autoquant_bench(op, *args, **kwargs): + """ + runs benchmark op(*args, **kwargs) avoiding torch.compile overhead + """ rep = kwargs.pop("rep", 100) warmup = kwargs.pop("warmup", 25) with torch.no_grad(): @@ -152,14 +183,14 @@ def do_autoquant_bench(op, *args, **kwargs): stream = torch.cuda.Stream() stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(stream): - op(*args) + op(*args, **kwargs) stream.synchronize() torch.cuda.current_stream().wait_stream(stream) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - op(*args) + op(*args, **kwargs) res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") return res @@ -180,11 +211,11 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): else: func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c)) q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs") - res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias) + res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100) if res < best_time*1.1: res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900) res=(res2*.9+res*.1) - print(f"time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") + print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") return res class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight): @@ -196,7 +227,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): if not _is_interpolate_mode(mode): return super()._autoquant_test(act_mat, weight, bias, best_time, mode) - # SAM best is between .8 to 1, SDXL also performs best in this range + # SAM best is between .8 and 1, SDXL also performs best in this range INTERPOLATION_CONSTANT = mode[1] w_qtensor = cls.from_float(weight) x_vals_int8, x_scales = quantize_activation_per_token_absmax( @@ -209,7 +240,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") with torch.no_grad(): res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) - print(f"time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") + print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op if res_matmul>=best_time: @@ -220,7 +251,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): res = super()._autoquant_test(act_mat, weight, bias, to_beat) max_int_const_win = (best_time-res_matmul)/(res-res_matmul) res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul - print(f"time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") + print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") return res_f class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin): @@ -252,6 +283,10 @@ def _autoquant_test(cls, act_mat, *args): return super()._autoquant_test(act_mat, *args) class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ 
+ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ def _quantized_op(act_mat, w_qtensor, bias): orig_shape = act_mat.shape y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales) @@ -265,7 +300,8 @@ class AQFloatLinearWeight(torch.Tensor, AQMixin): A class to be used in concert with AutoQuantizableLinearWeight to provide a default/non-quantized option. Only implements the bare minimum needed to work with the AutoQuantizableLinearWeight class using the same interfaces that would normally be - used by QTensor subclasses but for a default linear op instead. + used by QTensor subclasses but for a default linear op instead. Result of from_float + is not a tensor subclass, but rather the float tensor. """ def __init__(self): super().__init__() @@ -284,5 +320,5 @@ def from_float(cls, weight): AQWeightOnlyQuantizedLinearWeight, AQWeightOnlyQuantizedLinearWeight2, # AQWeightOnlyQuantizedLinearWeight3, - # 3rd version gets picked in situations where it is slower for the interpolation mode + # TODO this gets picked in places where it makes perf worse, why? ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 942bf043e9..06ffe21dcb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -37,7 +37,7 @@ "change_linear_weights_to_int8_woqtensors", "change_linear_weights_to_int4_woqtensors", "swap_conv2d_1x1_to_linear", - "do_autoquant", + "autoquant", "change_linears_to_autoquantizable", "change_autoquantizable_to_quantized", ] @@ -182,6 +182,9 @@ def change_autoquantizable_to_quantized(model, **kwargs): on benchmark results. Expectation is that these modules are torch.compiled afterwards. """ + hold = torch._dynamo.config.automatic_dynamic_shapes + torch._dynamo.config.automatic_dynamic_shapes = False + filter_fn = kwargs.pop( "filter_fn", lambda mod, *args: @@ -195,24 +198,22 @@ def change_autoquantizable_to_quantized(model, **kwargs): ), filter_fn, ) + torch._dynamo.config.automatic_dynamic_shapes = hold + torch._dynamo.reset() @torch.no_grad() -def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs): +def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs): """ Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer and applies that type of quantization. 
""" - hold = torch._dynamo.config.automatic_dynamic_shapes - torch._dynamo.config.automatic_dynamic_shapes = False change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs) if not isinstance(example_input, (tuple, list)): assert isinstance(example_input, torch.Tensor) example_input = [example_input] model(*example_input) change_autoquantizable_to_quantized(model, **kwargs) - torch._dynamo.config.automatic_dynamic_shapes = hold - torch._dynamo.reset() return model def swap_conv2d_1x1_to_linear(model, filter_fn=None): diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index e973ab8ca9..73621e6297 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -7,7 +7,6 @@ import torch from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils.benchmark import Timer __all__ = [ "find_multiple", @@ -87,23 +86,3 @@ def get_model_size_in_bytes(model): for b in model.buffers(): s += b.nelement() * b.element_size() return s - - -def benchmark(f, *args, **kwargs): - if "best_time" in kwargs: - best_time = kwargs.pop("best_time") - else: - best_time = torch.inf - t0 = Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - - # warmup - t0.timeit(10) - res=t0.adaptive_autorange(min_run_time=.1) - # run more if median vs median minus iqr (interpolated based on number of runs left) is lower than best_time, - # stop if good res.iqr/res.median or have 20 samples - while res.median-res.iqr+res.iqr*len(res.times)/20 < best_time * 1e-3 and not (res.iqr/res.median<.02 or len(res.times)>=20): - res2 = t0.adaptive_autorange(min_run_time=.5) - res=res.merge([res2, res])[0] - return res.median * 1e3